Fix: mitm proxy should forward websocket
This commit is contained in:
parent
7c50c068f5
commit
30025c0241
@ -215,11 +215,10 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica
|
|||||||
BasicConstraintsValid: true,
|
BasicConstraintsValid: true,
|
||||||
NotBefore: time.Now().Add(-c.validity),
|
NotBefore: time.Now().Add(-c.validity),
|
||||||
NotAfter: time.Now().Add(c.validity),
|
NotAfter: time.Now().Add(c.validity),
|
||||||
|
DNSNames: dnsNames,
|
||||||
|
IPAddresses: ips,
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpl.DNSNames = dnsNames
|
|
||||||
tmpl.IPAddresses = ips
|
|
||||||
|
|
||||||
raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.privateKey.Public(), c.caPrivateKey)
|
raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.privateKey.Public(), c.caPrivateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -62,7 +62,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
|
|||||||
request.RequestURI = ""
|
request.RequestURI = ""
|
||||||
|
|
||||||
if isUpgradeRequest(request) {
|
if isUpgradeRequest(request) {
|
||||||
if resp = handleUpgrade(conn, request, in); resp == nil {
|
if resp = HandleUpgrade(conn, conn.RemoteAddr(), request, in); resp == nil {
|
||||||
return // hijack connection
|
return // hijack connection
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@ -16,13 +18,17 @@ func isUpgradeRequest(req *http.Request) bool {
|
|||||||
return strings.EqualFold(req.Header.Get("Connection"), "Upgrade")
|
return strings.EqualFold(req.Header.Get("Connection"), "Upgrade")
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) {
|
func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) {
|
||||||
removeProxyHeaders(request.Header)
|
removeProxyHeaders(request.Header)
|
||||||
RemoveExtraHTTPHostPort(request)
|
RemoveExtraHTTPHostPort(request)
|
||||||
|
|
||||||
address := request.Host
|
address := request.Host
|
||||||
if _, _, err := net.SplitHostPort(address); err != nil {
|
if _, _, err := net.SplitHostPort(address); err != nil {
|
||||||
address = net.JoinHostPort(address, "80")
|
port := "80"
|
||||||
|
if request.TLS != nil {
|
||||||
|
port = "443"
|
||||||
|
}
|
||||||
|
address = net.JoinHostPort(address, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
dstAddr := socks5.ParseAddr(address)
|
dstAddr := socks5.ParseAddr(address)
|
||||||
@ -32,38 +38,58 @@ func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext
|
|||||||
|
|
||||||
left, right := net.Pipe()
|
left, right := net.Pipe()
|
||||||
|
|
||||||
in <- inbound.NewHTTP(dstAddr, conn.RemoteAddr(), right)
|
in <- inbound.NewMitm(dstAddr, source, request.Header.Get("User-Agent"), right)
|
||||||
|
|
||||||
bufferedLeft := N.NewBufferedConn(left)
|
var remoteServer *N.BufferedConn
|
||||||
defer func() {
|
if request.TLS != nil {
|
||||||
_ = bufferedLeft.Close()
|
tlsConn := tls.Client(left, &tls.Config{
|
||||||
}()
|
ServerName: request.URL.Hostname(),
|
||||||
|
})
|
||||||
|
|
||||||
err := request.Write(bufferedLeft)
|
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
|
||||||
if err != nil {
|
defer cancel()
|
||||||
|
if tlsConn.HandshakeContext(ctx) != nil {
|
||||||
|
_ = localConn.Close()
|
||||||
|
_ = left.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err = http.ReadResponse(bufferedLeft.Reader(), request)
|
remoteServer = N.NewBufferedConn(tlsConn)
|
||||||
|
} else {
|
||||||
|
remoteServer = N.NewBufferedConn(left)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = remoteServer.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := request.Write(remoteServer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = localConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = http.ReadResponse(remoteServer.Reader(), request)
|
||||||
|
if err != nil {
|
||||||
|
_ = localConn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusSwitchingProtocols {
|
if resp.StatusCode == http.StatusSwitchingProtocols {
|
||||||
removeProxyHeaders(resp.Header)
|
removeProxyHeaders(resp.Header)
|
||||||
|
|
||||||
err = conn.SetReadDeadline(time.Time{})
|
err = localConn.SetReadDeadline(time.Time{}) // set to not time out
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = resp.Write(conn)
|
err = resp.Write(localConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
N.Relay(bufferedLeft, conn)
|
N.Relay(remoteServer, localConn) // blocking here
|
||||||
_ = conn.Close()
|
_ = localConn.Close()
|
||||||
resp = nil
|
resp = nil
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -45,12 +45,12 @@ readLoop:
|
|||||||
for {
|
for {
|
||||||
// use SetReadDeadline instead of Proxy-Connection keep-alive
|
// use SetReadDeadline instead of Proxy-Connection keep-alive
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(65 * time.Second)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(65 * time.Second)); err != nil {
|
||||||
break readLoop
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
request, err := H.ReadRequest(conn.Reader())
|
request, err := H.ReadRequest(conn.Reader())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break readLoop
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
var response *http.Response
|
var response *http.Response
|
||||||
@ -71,7 +71,7 @@ readLoop:
|
|||||||
// Manual writing to support CONNECT for http 1.0 (workaround for uplay client)
|
// Manual writing to support CONNECT for http 1.0 (workaround for uplay client)
|
||||||
if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil {
|
if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil {
|
||||||
handleError(opt, session, err)
|
handleError(opt, session, err)
|
||||||
break readLoop // close connection
|
break // close connection
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasSuffix(session.request.URL.Host, ":80") {
|
if strings.HasSuffix(session.request.URL.Host, ":80") {
|
||||||
@ -81,7 +81,7 @@ readLoop:
|
|||||||
b, err := conn.Peek(1)
|
b, err := conn.Peek(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(opt, session, err)
|
handleError(opt, session, err)
|
||||||
break readLoop // close connection
|
break // close connection
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLS handshake.
|
// TLS handshake.
|
||||||
@ -92,7 +92,7 @@ readLoop:
|
|||||||
if err = tlsConn.Handshake(); err != nil {
|
if err = tlsConn.Handshake(); err != nil {
|
||||||
session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err))
|
session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err))
|
||||||
_ = writeResponse(session, false)
|
_ = writeResponse(session, false)
|
||||||
break readLoop // close connection
|
break // close connection
|
||||||
}
|
}
|
||||||
|
|
||||||
c = tlsConn
|
c = tlsConn
|
||||||
@ -105,20 +105,27 @@ readLoop:
|
|||||||
|
|
||||||
prepareRequest(c, session.request)
|
prepareRequest(c, session.request)
|
||||||
|
|
||||||
H.RemoveHopByHopHeaders(session.request.Header)
|
|
||||||
H.RemoveExtraHTTPHostPort(session.request)
|
|
||||||
|
|
||||||
// hijack api
|
// hijack api
|
||||||
if session.request.URL.Hostname() == opt.ApiHost {
|
if session.request.URL.Hostname() == opt.ApiHost {
|
||||||
if err = handleApiRequest(session, opt); err != nil {
|
if err = handleApiRequest(session, opt); err != nil {
|
||||||
handleError(opt, session, err)
|
handleError(opt, session, err)
|
||||||
break readLoop
|
|
||||||
}
|
}
|
||||||
return
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// forward websocket
|
||||||
|
if isWebsocketRequest(request) {
|
||||||
|
session.request.RequestURI = ""
|
||||||
|
if session.response = H.HandleUpgrade(conn, source, request, in); session.response == nil {
|
||||||
|
return // hijack connection
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
H.RemoveHopByHopHeaders(session.request.Header)
|
||||||
|
H.RemoveExtraHTTPHostPort(session.request)
|
||||||
|
|
||||||
// hijack custom request and write back custom response if necessary
|
// hijack custom request and write back custom response if necessary
|
||||||
if opt.Handler != nil {
|
if opt.Handler != nil && session.response == nil {
|
||||||
newReq, newRes := opt.Handler.HandleRequest(session)
|
newReq, newRes := opt.Handler.HandleRequest(session)
|
||||||
if newReq != nil {
|
if newReq != nil {
|
||||||
session.request = newReq
|
session.request = newReq
|
||||||
@ -128,12 +135,13 @@ readLoop:
|
|||||||
|
|
||||||
if err = writeResponse(session, false); err != nil {
|
if err = writeResponse(session, false); err != nil {
|
||||||
handleError(opt, session, err)
|
handleError(opt, session, err)
|
||||||
break readLoop
|
break
|
||||||
}
|
}
|
||||||
return
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if session.response == nil {
|
||||||
session.request.RequestURI = ""
|
session.request.RequestURI = ""
|
||||||
|
|
||||||
if session.request.URL.Host == "" {
|
if session.request.URL.Host == "" {
|
||||||
@ -146,10 +154,11 @@ readLoop:
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(opt, session, err)
|
handleError(opt, session, err)
|
||||||
session.response = session.NewErrorResponse(err)
|
session.response = session.NewErrorResponse(fmt.Errorf("request failed: %w", err))
|
||||||
if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") {
|
if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") {
|
||||||
_ = writeResponse(session, false)
|
_ = writeResponse(session, false)
|
||||||
break readLoop
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -157,7 +166,7 @@ readLoop:
|
|||||||
|
|
||||||
if err = writeResponseWithHandler(session, opt); err != nil {
|
if err = writeResponseWithHandler(session, opt); err != nil {
|
||||||
handleError(opt, session, err)
|
handleError(opt, session, err)
|
||||||
break readLoop // close connection
|
break // close connection
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,6 +20,10 @@ var (
|
|||||||
ErrInvalidURL = errors.New("invalid URL")
|
ErrInvalidURL = errors.New("invalid URL")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func isWebsocketRequest(req *http.Request) bool {
|
||||||
|
return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket"
|
||||||
|
}
|
||||||
|
|
||||||
func NewResponse(code int, body io.Reader, req *http.Request) *http.Response {
|
func NewResponse(code int, body io.Reader, req *http.Request) *http.Response {
|
||||||
if body == nil {
|
if body == nil {
|
||||||
body = &bytes.Buffer{}
|
body = &bytes.Buffer{}
|
||||||
|
@ -2,7 +2,6 @@ package nat
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/Dreamacro/clash/common/generics/list"
|
"github.com/Dreamacro/clash/common/generics/list"
|
||||||
)
|
)
|
||||||
@ -25,7 +24,6 @@ type binding struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type table struct {
|
type table struct {
|
||||||
mu sync.Mutex
|
|
||||||
tuples map[tuple]*list.Element[*binding]
|
tuples map[tuple]*list.Element[*binding]
|
||||||
ports [portLength]*list.Element[*binding]
|
ports [portLength]*list.Element[*binding]
|
||||||
available *list.List[*binding]
|
available *list.List[*binding]
|
||||||
@ -39,13 +37,13 @@ func (t *table) tupleOf(port uint16) tuple {
|
|||||||
|
|
||||||
elm := t.ports[offset]
|
elm := t.ports[offset]
|
||||||
|
|
||||||
|
t.available.MoveToFront(elm)
|
||||||
|
|
||||||
return elm.Value.tuple
|
return elm.Value.tuple
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *table) portOf(tuple tuple) uint16 {
|
func (t *table) portOf(tuple tuple) uint16 {
|
||||||
t.mu.Lock()
|
|
||||||
elm := t.tuples[tuple]
|
elm := t.tuples[tuple]
|
||||||
t.mu.Unlock()
|
|
||||||
if elm == nil {
|
if elm == nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@ -59,11 +57,8 @@ func (t *table) newConn(tuple tuple) uint16 {
|
|||||||
elm := t.available.Back()
|
elm := t.available.Back()
|
||||||
b := elm.Value
|
b := elm.Value
|
||||||
|
|
||||||
t.mu.Lock()
|
|
||||||
delete(t.tuples, b.tuple)
|
delete(t.tuples, b.tuple)
|
||||||
t.tuples[tuple] = elm
|
t.tuples[tuple] = elm
|
||||||
t.mu.Unlock()
|
|
||||||
|
|
||||||
b.tuple = tuple
|
b.tuple = tuple
|
||||||
|
|
||||||
t.available.MoveToFront(elm)
|
t.available.MoveToFront(elm)
|
||||||
@ -71,19 +66,6 @@ func (t *table) newConn(tuple tuple) uint16 {
|
|||||||
return portBegin + b.offset
|
return portBegin + b.offset
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *table) delete(tup tuple) {
|
|
||||||
t.mu.Lock()
|
|
||||||
elm := t.tuples[tup]
|
|
||||||
if elm == nil {
|
|
||||||
t.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(t.tuples, tup)
|
|
||||||
t.mu.Unlock()
|
|
||||||
|
|
||||||
t.available.MoveToBack(elm)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTable() *table {
|
func newTable() *table {
|
||||||
result := &table{
|
result := &table{
|
||||||
tuples: make(map[tuple]*list.Element[*binding], portLength),
|
tuples: make(map[tuple]*list.Element[*binding], portLength),
|
||||||
|
@ -16,8 +16,6 @@ type conn struct {
|
|||||||
net.Conn
|
net.Conn
|
||||||
|
|
||||||
tuple tuple
|
tuple tuple
|
||||||
|
|
||||||
close func(tuple tuple)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCP) Accept() (net.Conn, error) {
|
func (t *TCP) Accept() (net.Conn, error) {
|
||||||
@ -39,9 +37,6 @@ func (t *TCP) Accept() (net.Conn, error) {
|
|||||||
return &conn{
|
return &conn{
|
||||||
Conn: c,
|
Conn: c,
|
||||||
tuple: tup,
|
tuple: tup,
|
||||||
close: func(tuple tuple) {
|
|
||||||
t.table.delete(tuple)
|
|
||||||
},
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,11 +52,6 @@ func (t *TCP) SetDeadline(time time.Time) error {
|
|||||||
return t.listener.SetDeadline(time)
|
return t.listener.SetDeadline(time)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) Close() error {
|
|
||||||
c.close(c.tuple)
|
|
||||||
return c.Conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *conn) LocalAddr() net.Addr {
|
func (c *conn) LocalAddr() net.Addr {
|
||||||
return &net.TCPAddr{
|
return &net.TCPAddr{
|
||||||
IP: c.tuple.SourceAddr.Addr().AsSlice(),
|
IP: c.tuple.SourceAddr.Addr().AsSlice(),
|
||||||
|
Reference in New Issue
Block a user