From 22458ad0bed825881e742f5e4a4b867970b602bf Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Thu, 28 Apr 2022 00:46:47 +0800 Subject: [PATCH 1/3] Chore: mitm proxy with authenticate --- listener/http/proxy.go | 2 +- listener/http/upgrade.go | 63 ++++++++++----------------- listener/mitm/client.go | 80 +++++++++++++++++----------------- listener/mitm/proxy.go | 92 ++++++++++++++++++++-------------------- listener/mitm/server.go | 4 +- listener/mitm/utils.go | 5 --- 6 files changed, 111 insertions(+), 135 deletions(-) diff --git a/listener/http/proxy.go b/listener/http/proxy.go index 0ec43dc7..a762a617 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -62,7 +62,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, request.RequestURI = "" if isUpgradeRequest(request) { - if resp = HandleUpgrade(conn, conn.RemoteAddr(), request, in); resp == nil { + if resp = HandleUpgrade(conn, nil, request, in); resp == nil { return // hijack connection } } diff --git a/listener/http/upgrade.go b/listener/http/upgrade.go index 7e53eecf..4737db48 100644 --- a/listener/http/upgrade.go +++ b/listener/http/upgrade.go @@ -1,8 +1,6 @@ package http import ( - "context" - "crypto/tls" "net" "net/http" "strings" @@ -18,58 +16,43 @@ func isUpgradeRequest(req *http.Request) bool { return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") } -func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { +func HandleUpgrade(localConn net.Conn, serverConn *N.BufferedConn, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { removeProxyHeaders(request.Header) RemoveExtraHTTPHostPort(request) - address := request.Host - if _, _, err := net.SplitHostPort(address); err != nil { - port := "80" - if request.TLS != nil { - port = "443" + if serverConn == nil { + address := request.Host + if _, _, err := net.SplitHostPort(address); err != nil { + port := "80" + if request.TLS != nil { + port = "443" + } + address = net.JoinHostPort(address, port) } - address = net.JoinHostPort(address, port) - } - dstAddr := socks5.ParseAddr(address) - if dstAddr == nil { - return - } - - left, right := net.Pipe() - - in <- inbound.NewMitm(dstAddr, source, request.Header.Get("User-Agent"), right) - - var remoteServer *N.BufferedConn - if request.TLS != nil { - tlsConn := tls.Client(left, &tls.Config{ - ServerName: request.URL.Hostname(), - }) - - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() - if tlsConn.HandshakeContext(ctx) != nil { - _ = localConn.Close() - _ = left.Close() + dstAddr := socks5.ParseAddr(address) + if dstAddr == nil { return } - remoteServer = N.NewBufferedConn(tlsConn) - } else { - remoteServer = N.NewBufferedConn(left) + left, right := net.Pipe() + + in <- inbound.NewHTTP(dstAddr, localConn.RemoteAddr(), right) + + serverConn = N.NewBufferedConn(left) + + defer func() { + _ = serverConn.Close() + }() } - defer func() { - _ = remoteServer.Close() - }() - - err := request.Write(remoteServer) + err := request.Write(serverConn) if err != nil { _ = localConn.Close() return } - resp, err = http.ReadResponse(remoteServer.Reader(), request) + resp, err = http.ReadResponse(serverConn.Reader(), request) if err != nil { _ = localConn.Close() return @@ -88,7 +71,7 @@ func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, i return } - N.Relay(remoteServer, localConn) // blocking here + N.Relay(serverConn, localConn) // blocking here _ = localConn.Close() resp = nil } diff --git a/listener/mitm/client.go b/listener/mitm/client.go index a01c65d8..a2e95ced 100644 --- a/listener/mitm/client.go +++ b/listener/mitm/client.go @@ -3,53 +3,53 @@ package mitm import ( "context" "crypto/tls" - "errors" "net" "net/http" - "time" "github.com/Dreamacro/clash/adapter/inbound" + N "github.com/Dreamacro/clash/common/net" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/socks5" ) -func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http.Client { - return &http.Client{ - Transport: &http.Transport{ - // excepted HTTP/2 - TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), - // only needed 1 connection - MaxIdleConns: 1, - MaxIdleConnsPerHost: 1, - MaxConnsPerHost: 1, - IdleConnTimeout: 60 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{ - GetClientCertificate: func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, e error) { - return nil, ErrCertUnsupported - }, - }, - DialContext: func(context context.Context, network, address string) (net.Conn, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, errors.New("unsupported network " + network) - } - - dstAddr := socks5.ParseAddr(address) - if dstAddr == nil { - return nil, socks5.ErrAddressNotSupported - } - - left, right := net.Pipe() - - in <- inbound.NewMitm(dstAddr, source, userAgent, right) - - return left, nil - }, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - Timeout: 120 * time.Second, +func getServerConn(serverConn *N.BufferedConn, request *http.Request, srcAddr net.Addr, in chan<- C.ConnContext) (*N.BufferedConn, error) { + if serverConn != nil { + return serverConn, nil } + + address := request.Host + if _, _, err := net.SplitHostPort(address); err != nil { + port := "80" + if request.TLS != nil { + port = "443" + } + address = net.JoinHostPort(address, port) + } + + dstAddr := socks5.ParseAddr(address) + if dstAddr == nil { + return nil, socks5.ErrAddressNotSupported + } + + left, right := net.Pipe() + + in <- inbound.NewMitm(dstAddr, srcAddr, request.Header.Get("User-Agent"), right) + + if request.TLS != nil { + tlsConn := tls.Client(left, &tls.Config{ + ServerName: request.TLS.ServerName, + }) + + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + defer cancel() + if err := tlsConn.HandshakeContext(ctx); err != nil { + return nil, err + } + + serverConn = N.NewBufferedConn(tlsConn) + } else { + serverConn = N.NewBufferedConn(left) + } + + return serverConn, nil } diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index 09aecbb7..1fd5ec2e 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/tls" "encoding/pem" - "errors" "fmt" "io" "net" @@ -19,27 +18,31 @@ import ( H "github.com/Dreamacro/clash/listener/http" ) -func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { +func HandleConn(clientConn net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { var ( - source net.Addr - client *http.Client + clientIP = netip.MustParseAddrPort(clientConn.RemoteAddr().String()).Addr() + sourceAddr net.Addr + serverConn *N.BufferedConn ) defer func() { - if client != nil { - client.CloseIdleConnections() + if serverConn != nil { + _ = serverConn.Close() } }() startOver: var conn *N.BufferedConn - if bufConn, ok := c.(*N.BufferedConn); ok { + if bufConn, ok := clientConn.(*N.BufferedConn); ok { conn = bufConn } else { - conn = N.NewBufferedConn(c) + conn = N.NewBufferedConn(clientConn) } trusted := cache == nil // disable authenticate if cache is nil + if !trusted { + trusted = clientIP.IsLoopback() + } readLoop: for { @@ -57,8 +60,8 @@ readLoop: session := newSession(conn, request, response) - source = parseSourceAddress(session.request, c.RemoteAddr(), source) - session.request.RemoteAddr = source.String() + sourceAddr = parseSourceAddress(session.request, clientConn.RemoteAddr(), sourceAddr) + session.request.RemoteAddr = sourceAddr.String() if !trusted { session.response = H.Authenticate(session.request, cache) @@ -95,15 +98,15 @@ readLoop: break // close connection } - c = tlsConn + clientConn = tlsConn } else { - c = conn + clientConn = conn } goto startOver } - prepareRequest(c, session.request) + prepareRequest(clientConn, session.request) // hijack api if session.request.URL.Hostname() == opt.ApiHost { @@ -115,17 +118,22 @@ readLoop: // forward websocket if isWebsocketRequest(request) { + serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in) + if err != nil { + break + } + session.request.RequestURI = "" - if session.response = H.HandleUpgrade(conn, source, request, in); session.response == nil { + if session.response = H.HandleUpgrade(conn, serverConn, request, in); session.response == nil { return // hijack connection } } - H.RemoveHopByHopHeaders(session.request.Header) - H.RemoveExtraHTTPHostPort(session.request) + if session.response == nil { + H.RemoveHopByHopHeaders(session.request.Header) + H.RemoveExtraHTTPHostPort(session.request) - // hijack custom request and write back custom response if necessary - if opt.Handler != nil && session.response == nil { + // hijack custom request and write back custom response if necessary newReq, newRes := opt.Handler.HandleRequest(session) if newReq != nil { session.request = newReq @@ -139,26 +147,26 @@ readLoop: } continue } - } - if session.response == nil { session.request.RequestURI = "" if session.request.URL.Host == "" { session.response = session.NewErrorResponse(ErrInvalidURL) } else { - client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) + serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in) + if err != nil { + break + } // send the request to remote server - session.response, err = client.Do(session.request) - + err = session.request.Write(serverConn) if err != nil { - handleError(opt, session, err) - session.response = session.NewErrorResponse(fmt.Errorf("request failed: %w", err)) - if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { - _ = writeResponse(session, false) - break - } + break + } + + session.response, err = http.ReadResponse(serverConn.Reader(), request) + if err != nil { + break } } } @@ -174,11 +182,9 @@ readLoop: } func writeResponseWithHandler(session *Session, opt *Option) error { - if opt.Handler != nil { - res := opt.Handler.HandleResponse(session) - if res != nil { - session.response = res - } + res := opt.Handler.HandleResponse(session) + if res != nil { + session.response = res } return writeResponse(session, true) @@ -220,10 +226,8 @@ func handleApiRequest(session *Session, opt *Option) error { ` - if opt.Handler != nil { - if opt.Handler.HandleApiRequest(session) { - return nil - } + if opt.Handler.HandleApiRequest(session) { + return nil } b = fmt.Sprintf(b, session.request.URL.Path) @@ -243,9 +247,7 @@ func handleError(opt *Option, session *Session, err error) { _ = session.response.Body.Close() }() } - if opt.Handler != nil { - opt.Handler.HandleError(session, err) - } + opt.Handler.HandleError(session, err) } func prepareRequest(conn net.Conn, request *http.Request) { @@ -297,10 +299,6 @@ func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr } } -func newClientBySourceAndUserAgentIfNil(cli *http.Client, req *http.Request, source net.Addr, in chan<- C.ConnContext) *http.Client { - if cli != nil { - return cli - } - - return newClient(source, req.Header.Get("User-Agent"), in) +func isWebsocketRequest(req *http.Request) bool { + return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket" } diff --git a/listener/mitm/server.go b/listener/mitm/server.go index d7699b81..143199b3 100644 --- a/listener/mitm/server.go +++ b/listener/mitm/server.go @@ -54,7 +54,7 @@ func (l *Listener) Close() error { // New the MITM proxy actually is a type of HTTP proxy func New(option *Option, in chan<- C.ConnContext) (*Listener, error) { - return NewWithAuthenticate(option, in, false) + return NewWithAuthenticate(option, in, true) } func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate bool) (*Listener, error) { @@ -65,7 +65,7 @@ func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate b var c *cache.Cache[string, bool] if authenticate { - c = cache.New[string, bool](time.Second * 30) + c = cache.New[string, bool](time.Second * 90) } hl := &Listener{ diff --git a/listener/mitm/utils.go b/listener/mitm/utils.go index d7c10a2a..8c60c7a3 100644 --- a/listener/mitm/utils.go +++ b/listener/mitm/utils.go @@ -15,15 +15,10 @@ import ( ) var ( - ErrCertUnsupported = errors.New("tls: client cert unsupported") ErrInvalidResponse = errors.New("invalid response") ErrInvalidURL = errors.New("invalid URL") ) -func isWebsocketRequest(req *http.Request) bool { - return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket" -} - func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { if body == nil { body = &bytes.Buffer{} From da92601902b3e86bb1ba7d5febf86e9f9fabe6e4 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Thu, 28 Apr 2022 06:46:57 +0800 Subject: [PATCH 2/3] Fix: mitm proxy should handle none-http(s) protocol over tcp --- listener/mitm/client.go | 2 +- listener/mitm/hack.go | 9 +++++ listener/mitm/proxy.go | 84 +++++++++++++++++++++++++++++------------ 3 files changed, 69 insertions(+), 26 deletions(-) create mode 100644 listener/mitm/hack.go diff --git a/listener/mitm/client.go b/listener/mitm/client.go index a2e95ced..a6afa1a7 100644 --- a/listener/mitm/client.go +++ b/listener/mitm/client.go @@ -17,7 +17,7 @@ func getServerConn(serverConn *N.BufferedConn, request *http.Request, srcAddr ne return serverConn, nil } - address := request.Host + address := request.URL.Host if _, _, err := net.SplitHostPort(address); err != nil { port := "80" if request.TLS != nil { diff --git a/listener/mitm/hack.go b/listener/mitm/hack.go new file mode 100644 index 00000000..caff3b2c --- /dev/null +++ b/listener/mitm/hack.go @@ -0,0 +1,9 @@ +package mitm + +import ( + _ "net/http" + _ "unsafe" +) + +//go:linkname validMethod net/http.validMethod +func validMethod(method string) bool diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index 1fd5ec2e..63e81175 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -1,9 +1,11 @@ package mitm import ( + "bufio" "bytes" "crypto/tls" "encoding/pem" + "errors" "fmt" "io" "net" @@ -18,11 +20,12 @@ import ( H "github.com/Dreamacro/clash/listener/http" ) -func HandleConn(clientConn net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { +func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { var ( - clientIP = netip.MustParseAddrPort(clientConn.RemoteAddr().String()).Addr() + clientIP = netip.MustParseAddrPort(c.RemoteAddr().String()).Addr() sourceAddr net.Addr serverConn *N.BufferedConn + connState *tls.ConnectionState ) defer func() { @@ -31,17 +34,11 @@ func HandleConn(clientConn net.Conn, opt *Option, in chan<- C.ConnContext, cache } }() -startOver: - var conn *N.BufferedConn - if bufConn, ok := clientConn.(*N.BufferedConn); ok { - conn = bufConn - } else { - conn = N.NewBufferedConn(clientConn) - } + conn := N.NewBufferedConn(c) trusted := cache == nil // disable authenticate if cache is nil if !trusted { - trusted = clientIP.IsLoopback() + trusted = clientIP.IsLoopback() || clientIP.IsUnspecified() } readLoop: @@ -60,7 +57,7 @@ readLoop: session := newSession(conn, request, response) - sourceAddr = parseSourceAddress(session.request, clientConn.RemoteAddr(), sourceAddr) + sourceAddr = parseSourceAddress(session.request, conn.RemoteAddr(), sourceAddr) session.request.RemoteAddr = sourceAddr.String() if !trusted { @@ -81,9 +78,9 @@ readLoop: goto readLoop } - b, err := conn.Peek(1) - if err != nil { - handleError(opt, session, err) + b, err1 := conn.Peek(1) + if err1 != nil { + handleError(opt, session, err1) break // close connection } @@ -98,15 +95,49 @@ readLoop: break // close connection } - clientConn = tlsConn - } else { - clientConn = conn + cs := tlsConn.ConnectionState() + connState = &cs + + conn = N.NewBufferedConn(tlsConn) } - goto startOver + if strings.HasSuffix(session.request.URL.Host, ":443") { + goto readLoop + } + + var noErr bool + + buf, err2 := conn.Peek(7) + if err2 != nil { + if err2 == bufio.ErrBufferFull || errors.Is(err2, io.EOF) { + noErr = true + } else { + handleError(opt, session, err2) + break // close connection + } + } + + // others protocol over tcp + if noErr || !isHTTPTraffic(buf) { + if connState != nil { + session.request.TLS = connState + } + + serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in) + if err != nil { + break + } + + _ = conn.SetReadDeadline(time.Time{}) + + N.Relay(serverConn, conn) + return // hijack connection + } + + goto readLoop } - prepareRequest(clientConn, session.request) + prepareRequest(connState, session.request) // hijack api if session.request.URL.Hostname() == opt.ApiHost { @@ -250,7 +281,7 @@ func handleError(opt *Option, session *Session, err error) { opt.Handler.HandleError(session, err) } -func prepareRequest(conn net.Conn, request *http.Request) { +func prepareRequest(connState *tls.ConnectionState, request *http.Request) { host := request.Header.Get("Host") if host != "" { request.Host = host @@ -264,10 +295,8 @@ func prepareRequest(conn net.Conn, request *http.Request) { request.URL.Scheme = "http" } - if tlsConn, ok := conn.(*tls.Conn); ok { - cs := tlsConn.ConnectionState() - request.TLS = &cs - + if connState != nil { + request.TLS = connState request.URL.Scheme = "https" } @@ -300,5 +329,10 @@ func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr } func isWebsocketRequest(req *http.Request) bool { - return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket" + return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") +} + +func isHTTPTraffic(buf []byte) bool { + method, _, _ := strings.Cut(string(buf), " ") + return validMethod(method) } From 7e85d5a954581e45466cb3438ba2805c5ab18028 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Fri, 29 Apr 2022 05:15:32 +0800 Subject: [PATCH 3/3] Fix: tls handshake with timeout --- listener/mitm/proxy.go | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index 63e81175..6e73eb03 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -3,14 +3,15 @@ package mitm import ( "bufio" "bytes" + "context" "crypto/tls" "encoding/pem" - "errors" "fmt" "io" "net" "net/http" "net/netip" + "os" "strings" "time" @@ -68,6 +69,11 @@ readLoop: if trusted { if session.request.Method == http.MethodConnect { + if session.request.ProtoMajor > 1 { + session.request.ProtoMajor = 1 + session.request.ProtoMinor = 1 + } + // Manual writing to support CONNECT for http 1.0 (workaround for uplay client) if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil { handleError(opt, session, err) @@ -88,12 +94,15 @@ readLoop: if b[0] == 0x16 { tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Hostname())) - // Handshake with the local client - if err = tlsConn.Handshake(); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + // handshake with the local client + if err = tlsConn.HandshakeContext(ctx); err != nil { + cancel() session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err)) _ = writeResponse(session, false) break // close connection } + cancel() cs := tlsConn.ConnectionState() connState = &cs @@ -105,20 +114,20 @@ readLoop: goto readLoop } - var noErr bool + if conn.SetReadDeadline(time.Now().Add(time.Second)) != nil { + break + } buf, err2 := conn.Peek(7) if err2 != nil { - if err2 == bufio.ErrBufferFull || errors.Is(err2, io.EOF) { - noErr = true - } else { + if err2 != bufio.ErrBufferFull && !os.IsTimeout(err2) { handleError(opt, session, err2) break // close connection } } // others protocol over tcp - if noErr || !isHTTPTraffic(buf) { + if !isHTTPTraffic(buf) { if connState != nil { session.request.TLS = connState } @@ -128,7 +137,9 @@ readLoop: break } - _ = conn.SetReadDeadline(time.Time{}) + if conn.SetReadDeadline(time.Time{}) != nil { + break + } N.Relay(serverConn, conn) return // hijack connection