diff --git a/listener/http/proxy.go b/listener/http/proxy.go index 0ec43dc7..a762a617 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -62,7 +62,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, request.RequestURI = "" if isUpgradeRequest(request) { - if resp = HandleUpgrade(conn, conn.RemoteAddr(), request, in); resp == nil { + if resp = HandleUpgrade(conn, nil, request, in); resp == nil { return // hijack connection } } diff --git a/listener/http/upgrade.go b/listener/http/upgrade.go index 7e53eecf..4737db48 100644 --- a/listener/http/upgrade.go +++ b/listener/http/upgrade.go @@ -1,8 +1,6 @@ package http import ( - "context" - "crypto/tls" "net" "net/http" "strings" @@ -18,58 +16,43 @@ func isUpgradeRequest(req *http.Request) bool { return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") } -func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { +func HandleUpgrade(localConn net.Conn, serverConn *N.BufferedConn, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { removeProxyHeaders(request.Header) RemoveExtraHTTPHostPort(request) - address := request.Host - if _, _, err := net.SplitHostPort(address); err != nil { - port := "80" - if request.TLS != nil { - port = "443" + if serverConn == nil { + address := request.Host + if _, _, err := net.SplitHostPort(address); err != nil { + port := "80" + if request.TLS != nil { + port = "443" + } + address = net.JoinHostPort(address, port) } - address = net.JoinHostPort(address, port) - } - dstAddr := socks5.ParseAddr(address) - if dstAddr == nil { - return - } - - left, right := net.Pipe() - - in <- inbound.NewMitm(dstAddr, source, request.Header.Get("User-Agent"), right) - - var remoteServer *N.BufferedConn - if request.TLS != nil { - tlsConn := tls.Client(left, &tls.Config{ - ServerName: request.URL.Hostname(), - }) - - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() - if tlsConn.HandshakeContext(ctx) != nil { - _ = localConn.Close() - _ = left.Close() + dstAddr := socks5.ParseAddr(address) + if dstAddr == nil { return } - remoteServer = N.NewBufferedConn(tlsConn) - } else { - remoteServer = N.NewBufferedConn(left) + left, right := net.Pipe() + + in <- inbound.NewHTTP(dstAddr, localConn.RemoteAddr(), right) + + serverConn = N.NewBufferedConn(left) + + defer func() { + _ = serverConn.Close() + }() } - defer func() { - _ = remoteServer.Close() - }() - - err := request.Write(remoteServer) + err := request.Write(serverConn) if err != nil { _ = localConn.Close() return } - resp, err = http.ReadResponse(remoteServer.Reader(), request) + resp, err = http.ReadResponse(serverConn.Reader(), request) if err != nil { _ = localConn.Close() return @@ -88,7 +71,7 @@ func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, i return } - N.Relay(remoteServer, localConn) // blocking here + N.Relay(serverConn, localConn) // blocking here _ = localConn.Close() resp = nil } diff --git a/listener/mitm/client.go b/listener/mitm/client.go index a01c65d8..a2e95ced 100644 --- a/listener/mitm/client.go +++ b/listener/mitm/client.go @@ -3,53 +3,53 @@ package mitm import ( "context" "crypto/tls" - "errors" "net" "net/http" - "time" "github.com/Dreamacro/clash/adapter/inbound" + N "github.com/Dreamacro/clash/common/net" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/socks5" ) -func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http.Client { - return &http.Client{ - Transport: &http.Transport{ - // excepted HTTP/2 - TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), - // only needed 1 connection - MaxIdleConns: 1, - MaxIdleConnsPerHost: 1, - MaxConnsPerHost: 1, - IdleConnTimeout: 60 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{ - GetClientCertificate: func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, e error) { - return nil, ErrCertUnsupported - }, - }, - DialContext: func(context context.Context, network, address string) (net.Conn, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, errors.New("unsupported network " + network) - } - - dstAddr := socks5.ParseAddr(address) - if dstAddr == nil { - return nil, socks5.ErrAddressNotSupported - } - - left, right := net.Pipe() - - in <- inbound.NewMitm(dstAddr, source, userAgent, right) - - return left, nil - }, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - Timeout: 120 * time.Second, +func getServerConn(serverConn *N.BufferedConn, request *http.Request, srcAddr net.Addr, in chan<- C.ConnContext) (*N.BufferedConn, error) { + if serverConn != nil { + return serverConn, nil } + + address := request.Host + if _, _, err := net.SplitHostPort(address); err != nil { + port := "80" + if request.TLS != nil { + port = "443" + } + address = net.JoinHostPort(address, port) + } + + dstAddr := socks5.ParseAddr(address) + if dstAddr == nil { + return nil, socks5.ErrAddressNotSupported + } + + left, right := net.Pipe() + + in <- inbound.NewMitm(dstAddr, srcAddr, request.Header.Get("User-Agent"), right) + + if request.TLS != nil { + tlsConn := tls.Client(left, &tls.Config{ + ServerName: request.TLS.ServerName, + }) + + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + defer cancel() + if err := tlsConn.HandshakeContext(ctx); err != nil { + return nil, err + } + + serverConn = N.NewBufferedConn(tlsConn) + } else { + serverConn = N.NewBufferedConn(left) + } + + return serverConn, nil } diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index 09aecbb7..1fd5ec2e 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/tls" "encoding/pem" - "errors" "fmt" "io" "net" @@ -19,27 +18,31 @@ import ( H "github.com/Dreamacro/clash/listener/http" ) -func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { +func HandleConn(clientConn net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { var ( - source net.Addr - client *http.Client + clientIP = netip.MustParseAddrPort(clientConn.RemoteAddr().String()).Addr() + sourceAddr net.Addr + serverConn *N.BufferedConn ) defer func() { - if client != nil { - client.CloseIdleConnections() + if serverConn != nil { + _ = serverConn.Close() } }() startOver: var conn *N.BufferedConn - if bufConn, ok := c.(*N.BufferedConn); ok { + if bufConn, ok := clientConn.(*N.BufferedConn); ok { conn = bufConn } else { - conn = N.NewBufferedConn(c) + conn = N.NewBufferedConn(clientConn) } trusted := cache == nil // disable authenticate if cache is nil + if !trusted { + trusted = clientIP.IsLoopback() + } readLoop: for { @@ -57,8 +60,8 @@ readLoop: session := newSession(conn, request, response) - source = parseSourceAddress(session.request, c.RemoteAddr(), source) - session.request.RemoteAddr = source.String() + sourceAddr = parseSourceAddress(session.request, clientConn.RemoteAddr(), sourceAddr) + session.request.RemoteAddr = sourceAddr.String() if !trusted { session.response = H.Authenticate(session.request, cache) @@ -95,15 +98,15 @@ readLoop: break // close connection } - c = tlsConn + clientConn = tlsConn } else { - c = conn + clientConn = conn } goto startOver } - prepareRequest(c, session.request) + prepareRequest(clientConn, session.request) // hijack api if session.request.URL.Hostname() == opt.ApiHost { @@ -115,17 +118,22 @@ readLoop: // forward websocket if isWebsocketRequest(request) { + serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in) + if err != nil { + break + } + session.request.RequestURI = "" - if session.response = H.HandleUpgrade(conn, source, request, in); session.response == nil { + if session.response = H.HandleUpgrade(conn, serverConn, request, in); session.response == nil { return // hijack connection } } - H.RemoveHopByHopHeaders(session.request.Header) - H.RemoveExtraHTTPHostPort(session.request) + if session.response == nil { + H.RemoveHopByHopHeaders(session.request.Header) + H.RemoveExtraHTTPHostPort(session.request) - // hijack custom request and write back custom response if necessary - if opt.Handler != nil && session.response == nil { + // hijack custom request and write back custom response if necessary newReq, newRes := opt.Handler.HandleRequest(session) if newReq != nil { session.request = newReq @@ -139,26 +147,26 @@ readLoop: } continue } - } - if session.response == nil { session.request.RequestURI = "" if session.request.URL.Host == "" { session.response = session.NewErrorResponse(ErrInvalidURL) } else { - client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) + serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in) + if err != nil { + break + } // send the request to remote server - session.response, err = client.Do(session.request) - + err = session.request.Write(serverConn) if err != nil { - handleError(opt, session, err) - session.response = session.NewErrorResponse(fmt.Errorf("request failed: %w", err)) - if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { - _ = writeResponse(session, false) - break - } + break + } + + session.response, err = http.ReadResponse(serverConn.Reader(), request) + if err != nil { + break } } } @@ -174,11 +182,9 @@ readLoop: } func writeResponseWithHandler(session *Session, opt *Option) error { - if opt.Handler != nil { - res := opt.Handler.HandleResponse(session) - if res != nil { - session.response = res - } + res := opt.Handler.HandleResponse(session) + if res != nil { + session.response = res } return writeResponse(session, true) @@ -220,10 +226,8 @@ func handleApiRequest(session *Session, opt *Option) error { ` - if opt.Handler != nil { - if opt.Handler.HandleApiRequest(session) { - return nil - } + if opt.Handler.HandleApiRequest(session) { + return nil } b = fmt.Sprintf(b, session.request.URL.Path) @@ -243,9 +247,7 @@ func handleError(opt *Option, session *Session, err error) { _ = session.response.Body.Close() }() } - if opt.Handler != nil { - opt.Handler.HandleError(session, err) - } + opt.Handler.HandleError(session, err) } func prepareRequest(conn net.Conn, request *http.Request) { @@ -297,10 +299,6 @@ func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr } } -func newClientBySourceAndUserAgentIfNil(cli *http.Client, req *http.Request, source net.Addr, in chan<- C.ConnContext) *http.Client { - if cli != nil { - return cli - } - - return newClient(source, req.Header.Get("User-Agent"), in) +func isWebsocketRequest(req *http.Request) bool { + return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket" } diff --git a/listener/mitm/server.go b/listener/mitm/server.go index d7699b81..143199b3 100644 --- a/listener/mitm/server.go +++ b/listener/mitm/server.go @@ -54,7 +54,7 @@ func (l *Listener) Close() error { // New the MITM proxy actually is a type of HTTP proxy func New(option *Option, in chan<- C.ConnContext) (*Listener, error) { - return NewWithAuthenticate(option, in, false) + return NewWithAuthenticate(option, in, true) } func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate bool) (*Listener, error) { @@ -65,7 +65,7 @@ func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate b var c *cache.Cache[string, bool] if authenticate { - c = cache.New[string, bool](time.Second * 30) + c = cache.New[string, bool](time.Second * 90) } hl := &Listener{ diff --git a/listener/mitm/utils.go b/listener/mitm/utils.go index d7c10a2a..8c60c7a3 100644 --- a/listener/mitm/utils.go +++ b/listener/mitm/utils.go @@ -15,15 +15,10 @@ import ( ) var ( - ErrCertUnsupported = errors.New("tls: client cert unsupported") ErrInvalidResponse = errors.New("invalid response") ErrInvalidURL = errors.New("invalid URL") ) -func isWebsocketRequest(req *http.Request) bool { - return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket" -} - func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { if body == nil { body = &bytes.Buffer{}