From 30025c02414af8b2cfa697ddfad5e69915f315ae Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Wed, 27 Apr 2022 05:14:03 +0800 Subject: [PATCH] Fix: mitm proxy should forward websocket --- common/cert/cert.go | 5 +- listener/http/proxy.go | 2 +- listener/http/upgrade.go | 48 ++++++++++---- listener/mitm/proxy.go | 63 +++++++++++-------- listener/mitm/utils.go | 4 ++ listener/tun/ipstack/system/mars/nat/table.go | 22 +------ listener/tun/ipstack/system/mars/nat/tcp.go | 10 --- 7 files changed, 82 insertions(+), 72 deletions(-) diff --git a/common/cert/cert.go b/common/cert/cert.go index 274ef55e..29bec9de 100644 --- a/common/cert/cert.go +++ b/common/cert/cert.go @@ -215,11 +215,10 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica BasicConstraintsValid: true, NotBefore: time.Now().Add(-c.validity), NotAfter: time.Now().Add(c.validity), + DNSNames: dnsNames, + IPAddresses: ips, } - tmpl.DNSNames = dnsNames - tmpl.IPAddresses = ips - raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.privateKey.Public(), c.caPrivateKey) if err != nil { return nil, err diff --git a/listener/http/proxy.go b/listener/http/proxy.go index 3bc92e13..0ec43dc7 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -62,7 +62,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, request.RequestURI = "" if isUpgradeRequest(request) { - if resp = handleUpgrade(conn, request, in); resp == nil { + if resp = HandleUpgrade(conn, conn.RemoteAddr(), request, in); resp == nil { return // hijack connection } } diff --git a/listener/http/upgrade.go b/listener/http/upgrade.go index f770eb25..7e53eecf 100644 --- a/listener/http/upgrade.go +++ b/listener/http/upgrade.go @@ -1,6 +1,8 @@ package http import ( + "context" + "crypto/tls" "net" "net/http" "strings" @@ -16,13 +18,17 @@ func isUpgradeRequest(req *http.Request) bool { return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") } -func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { +func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { removeProxyHeaders(request.Header) RemoveExtraHTTPHostPort(request) address := request.Host if _, _, err := net.SplitHostPort(address); err != nil { - address = net.JoinHostPort(address, "80") + port := "80" + if request.TLS != nil { + port = "443" + } + address = net.JoinHostPort(address, port) } dstAddr := socks5.ParseAddr(address) @@ -32,38 +38,58 @@ func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext left, right := net.Pipe() - in <- inbound.NewHTTP(dstAddr, conn.RemoteAddr(), right) + in <- inbound.NewMitm(dstAddr, source, request.Header.Get("User-Agent"), right) + + var remoteServer *N.BufferedConn + if request.TLS != nil { + tlsConn := tls.Client(left, &tls.Config{ + ServerName: request.URL.Hostname(), + }) + + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + defer cancel() + if tlsConn.HandshakeContext(ctx) != nil { + _ = localConn.Close() + _ = left.Close() + return + } + + remoteServer = N.NewBufferedConn(tlsConn) + } else { + remoteServer = N.NewBufferedConn(left) + } - bufferedLeft := N.NewBufferedConn(left) defer func() { - _ = bufferedLeft.Close() + _ = remoteServer.Close() }() - err := request.Write(bufferedLeft) + err := request.Write(remoteServer) if err != nil { + _ = localConn.Close() return } - resp, err = http.ReadResponse(bufferedLeft.Reader(), request) + resp, err = http.ReadResponse(remoteServer.Reader(), request) if err != nil { + _ = localConn.Close() return } if resp.StatusCode == http.StatusSwitchingProtocols { removeProxyHeaders(resp.Header) - err = conn.SetReadDeadline(time.Time{}) + err = localConn.SetReadDeadline(time.Time{}) // set to not time out if err != nil { return } - err = resp.Write(conn) + err = resp.Write(localConn) if err != nil { return } - N.Relay(bufferedLeft, conn) - _ = conn.Close() + N.Relay(remoteServer, localConn) // blocking here + _ = localConn.Close() resp = nil } return diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index eb8876c0..09aecbb7 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -45,12 +45,12 @@ readLoop: for { // use SetReadDeadline instead of Proxy-Connection keep-alive if err := conn.SetReadDeadline(time.Now().Add(65 * time.Second)); err != nil { - break readLoop + break } request, err := H.ReadRequest(conn.Reader()) if err != nil { - break readLoop + break } var response *http.Response @@ -71,7 +71,7 @@ readLoop: // Manual writing to support CONNECT for http 1.0 (workaround for uplay client) if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil { handleError(opt, session, err) - break readLoop // close connection + break // close connection } if strings.HasSuffix(session.request.URL.Host, ":80") { @@ -81,7 +81,7 @@ readLoop: b, err := conn.Peek(1) if err != nil { handleError(opt, session, err) - break readLoop // close connection + break // close connection } // TLS handshake. @@ -92,7 +92,7 @@ readLoop: if err = tlsConn.Handshake(); err != nil { session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err)) _ = writeResponse(session, false) - break readLoop // close connection + break // close connection } c = tlsConn @@ -105,20 +105,27 @@ readLoop: prepareRequest(c, session.request) - H.RemoveHopByHopHeaders(session.request.Header) - H.RemoveExtraHTTPHostPort(session.request) - // hijack api if session.request.URL.Hostname() == opt.ApiHost { if err = handleApiRequest(session, opt); err != nil { handleError(opt, session, err) - break readLoop } - return + break } + // forward websocket + if isWebsocketRequest(request) { + session.request.RequestURI = "" + if session.response = H.HandleUpgrade(conn, source, request, in); session.response == nil { + return // hijack connection + } + } + + H.RemoveHopByHopHeaders(session.request.Header) + H.RemoveExtraHTTPHostPort(session.request) + // hijack custom request and write back custom response if necessary - if opt.Handler != nil { + if opt.Handler != nil && session.response == nil { newReq, newRes := opt.Handler.HandleRequest(session) if newReq != nil { session.request = newReq @@ -128,28 +135,30 @@ readLoop: if err = writeResponse(session, false); err != nil { handleError(opt, session, err) - break readLoop + break } - return + continue } } - session.request.RequestURI = "" + if session.response == nil { + session.request.RequestURI = "" - if session.request.URL.Host == "" { - session.response = session.NewErrorResponse(ErrInvalidURL) - } else { - client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) + if session.request.URL.Host == "" { + session.response = session.NewErrorResponse(ErrInvalidURL) + } else { + client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) - // send the request to remote server - session.response, err = client.Do(session.request) + // send the request to remote server + session.response, err = client.Do(session.request) - if err != nil { - handleError(opt, session, err) - session.response = session.NewErrorResponse(err) - if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { - _ = writeResponse(session, false) - break readLoop + if err != nil { + handleError(opt, session, err) + session.response = session.NewErrorResponse(fmt.Errorf("request failed: %w", err)) + if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { + _ = writeResponse(session, false) + break + } } } } @@ -157,7 +166,7 @@ readLoop: if err = writeResponseWithHandler(session, opt); err != nil { handleError(opt, session, err) - break readLoop // close connection + break // close connection } } diff --git a/listener/mitm/utils.go b/listener/mitm/utils.go index a84c75cf..d7c10a2a 100644 --- a/listener/mitm/utils.go +++ b/listener/mitm/utils.go @@ -20,6 +20,10 @@ var ( ErrInvalidURL = errors.New("invalid URL") ) +func isWebsocketRequest(req *http.Request) bool { + return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket" +} + func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { if body == nil { body = &bytes.Buffer{} diff --git a/listener/tun/ipstack/system/mars/nat/table.go b/listener/tun/ipstack/system/mars/nat/table.go index 9c1b32cd..38b7d6c6 100644 --- a/listener/tun/ipstack/system/mars/nat/table.go +++ b/listener/tun/ipstack/system/mars/nat/table.go @@ -2,7 +2,6 @@ package nat import ( "net/netip" - "sync" "github.com/Dreamacro/clash/common/generics/list" ) @@ -25,7 +24,6 @@ type binding struct { } type table struct { - mu sync.Mutex tuples map[tuple]*list.Element[*binding] ports [portLength]*list.Element[*binding] available *list.List[*binding] @@ -39,13 +37,13 @@ func (t *table) tupleOf(port uint16) tuple { elm := t.ports[offset] + t.available.MoveToFront(elm) + return elm.Value.tuple } func (t *table) portOf(tuple tuple) uint16 { - t.mu.Lock() elm := t.tuples[tuple] - t.mu.Unlock() if elm == nil { return 0 } @@ -59,11 +57,8 @@ func (t *table) newConn(tuple tuple) uint16 { elm := t.available.Back() b := elm.Value - t.mu.Lock() delete(t.tuples, b.tuple) t.tuples[tuple] = elm - t.mu.Unlock() - b.tuple = tuple t.available.MoveToFront(elm) @@ -71,19 +66,6 @@ func (t *table) newConn(tuple tuple) uint16 { return portBegin + b.offset } -func (t *table) delete(tup tuple) { - t.mu.Lock() - elm := t.tuples[tup] - if elm == nil { - t.mu.Unlock() - return - } - delete(t.tuples, tup) - t.mu.Unlock() - - t.available.MoveToBack(elm) -} - func newTable() *table { result := &table{ tuples: make(map[tuple]*list.Element[*binding], portLength), diff --git a/listener/tun/ipstack/system/mars/nat/tcp.go b/listener/tun/ipstack/system/mars/nat/tcp.go index 48ad3e43..cc0abe7d 100644 --- a/listener/tun/ipstack/system/mars/nat/tcp.go +++ b/listener/tun/ipstack/system/mars/nat/tcp.go @@ -16,8 +16,6 @@ type conn struct { net.Conn tuple tuple - - close func(tuple tuple) } func (t *TCP) Accept() (net.Conn, error) { @@ -39,9 +37,6 @@ func (t *TCP) Accept() (net.Conn, error) { return &conn{ Conn: c, tuple: tup, - close: func(tuple tuple) { - t.table.delete(tuple) - }, }, nil } @@ -57,11 +52,6 @@ func (t *TCP) SetDeadline(time time.Time) error { return t.listener.SetDeadline(time) } -func (c *conn) Close() error { - c.close(c.tuple) - return c.Conn.Close() -} - func (c *conn) LocalAddr() net.Addr { return &net.TCPAddr{ IP: c.tuple.SourceAddr.Addr().AsSlice(),