From da92601902b3e86bb1ba7d5febf86e9f9fabe6e4 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Thu, 28 Apr 2022 06:46:57 +0800 Subject: [PATCH] Fix: mitm proxy should handle none-http(s) protocol over tcp --- listener/mitm/client.go | 2 +- listener/mitm/hack.go | 9 +++++ listener/mitm/proxy.go | 84 +++++++++++++++++++++++++++++------------ 3 files changed, 69 insertions(+), 26 deletions(-) create mode 100644 listener/mitm/hack.go diff --git a/listener/mitm/client.go b/listener/mitm/client.go index a2e95ced..a6afa1a7 100644 --- a/listener/mitm/client.go +++ b/listener/mitm/client.go @@ -17,7 +17,7 @@ func getServerConn(serverConn *N.BufferedConn, request *http.Request, srcAddr ne return serverConn, nil } - address := request.Host + address := request.URL.Host if _, _, err := net.SplitHostPort(address); err != nil { port := "80" if request.TLS != nil { diff --git a/listener/mitm/hack.go b/listener/mitm/hack.go new file mode 100644 index 00000000..caff3b2c --- /dev/null +++ b/listener/mitm/hack.go @@ -0,0 +1,9 @@ +package mitm + +import ( + _ "net/http" + _ "unsafe" +) + +//go:linkname validMethod net/http.validMethod +func validMethod(method string) bool diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index 1fd5ec2e..63e81175 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -1,9 +1,11 @@ package mitm import ( + "bufio" "bytes" "crypto/tls" "encoding/pem" + "errors" "fmt" "io" "net" @@ -18,11 +20,12 @@ import ( H "github.com/Dreamacro/clash/listener/http" ) -func HandleConn(clientConn net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { +func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { var ( - clientIP = netip.MustParseAddrPort(clientConn.RemoteAddr().String()).Addr() + clientIP = netip.MustParseAddrPort(c.RemoteAddr().String()).Addr() sourceAddr net.Addr serverConn *N.BufferedConn + connState *tls.ConnectionState ) defer func() { @@ -31,17 +34,11 @@ func HandleConn(clientConn net.Conn, opt *Option, in chan<- C.ConnContext, cache } }() -startOver: - var conn *N.BufferedConn - if bufConn, ok := clientConn.(*N.BufferedConn); ok { - conn = bufConn - } else { - conn = N.NewBufferedConn(clientConn) - } + conn := N.NewBufferedConn(c) trusted := cache == nil // disable authenticate if cache is nil if !trusted { - trusted = clientIP.IsLoopback() + trusted = clientIP.IsLoopback() || clientIP.IsUnspecified() } readLoop: @@ -60,7 +57,7 @@ readLoop: session := newSession(conn, request, response) - sourceAddr = parseSourceAddress(session.request, clientConn.RemoteAddr(), sourceAddr) + sourceAddr = parseSourceAddress(session.request, conn.RemoteAddr(), sourceAddr) session.request.RemoteAddr = sourceAddr.String() if !trusted { @@ -81,9 +78,9 @@ readLoop: goto readLoop } - b, err := conn.Peek(1) - if err != nil { - handleError(opt, session, err) + b, err1 := conn.Peek(1) + if err1 != nil { + handleError(opt, session, err1) break // close connection } @@ -98,15 +95,49 @@ readLoop: break // close connection } - clientConn = tlsConn - } else { - clientConn = conn + cs := tlsConn.ConnectionState() + connState = &cs + + conn = N.NewBufferedConn(tlsConn) } - goto startOver + if strings.HasSuffix(session.request.URL.Host, ":443") { + goto readLoop + } + + var noErr bool + + buf, err2 := conn.Peek(7) + if err2 != nil { + if err2 == bufio.ErrBufferFull || errors.Is(err2, io.EOF) { + noErr = true + } else { + handleError(opt, session, err2) + break // close connection + } + } + + // others protocol over tcp + if noErr || !isHTTPTraffic(buf) { + if connState != nil { + session.request.TLS = connState + } + + serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in) + if err != nil { + break + } + + _ = conn.SetReadDeadline(time.Time{}) + + N.Relay(serverConn, conn) + return // hijack connection + } + + goto readLoop } - prepareRequest(clientConn, session.request) + prepareRequest(connState, session.request) // hijack api if session.request.URL.Hostname() == opt.ApiHost { @@ -250,7 +281,7 @@ func handleError(opt *Option, session *Session, err error) { opt.Handler.HandleError(session, err) } -func prepareRequest(conn net.Conn, request *http.Request) { +func prepareRequest(connState *tls.ConnectionState, request *http.Request) { host := request.Header.Get("Host") if host != "" { request.Host = host @@ -264,10 +295,8 @@ func prepareRequest(conn net.Conn, request *http.Request) { request.URL.Scheme = "http" } - if tlsConn, ok := conn.(*tls.Conn); ok { - cs := tlsConn.ConnectionState() - request.TLS = &cs - + if connState != nil { + request.TLS = connState request.URL.Scheme = "https" } @@ -300,5 +329,10 @@ func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr } func isWebsocketRequest(req *http.Request) bool { - return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket" + return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") +} + +func isHTTPTraffic(buf []byte) bool { + method, _, _ := strings.Cut(string(buf), " ") + return validMethod(method) }