diff --git a/common/net/relay.go b/common/net/relay.go index e7157639..6035a412 100644 --- a/common/net/relay.go +++ b/common/net/relay.go @@ -12,19 +12,28 @@ import ( func Relay(leftConn, rightConn net.Conn) { ch := make(chan error) + tcpKeepAlive(leftConn) + tcpKeepAlive(rightConn) + go func() { buf := pool.Get(pool.RelayBufferSize) // Wrapping to avoid using *net.TCPConn.(ReadFrom) // See also https://github.com/Dreamacro/clash/pull/1209 _, err := io.CopyBuffer(WriteOnlyWriter{Writer: leftConn}, ReadOnlyReader{Reader: rightConn}, buf) - pool.Put(buf) - leftConn.SetReadDeadline(time.Now()) + _ = pool.Put(buf) + _ = leftConn.SetReadDeadline(time.Now()) ch <- err }() buf := pool.Get(pool.RelayBufferSize) - io.CopyBuffer(WriteOnlyWriter{Writer: rightConn}, ReadOnlyReader{Reader: leftConn}, buf) - pool.Put(buf) - rightConn.SetReadDeadline(time.Now()) + _, _ = io.CopyBuffer(WriteOnlyWriter{Writer: rightConn}, ReadOnlyReader{Reader: leftConn}, buf) + _ = pool.Put(buf) + _ = rightConn.SetReadDeadline(time.Now()) <-ch } + +func tcpKeepAlive(c net.Conn) { + if tcp, ok := c.(*net.TCPConn); ok { + _ = tcp.SetKeepAlive(true) + } +} diff --git a/listener/http/proxy.go b/listener/http/proxy.go index 680d646c..3bc92e13 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -62,20 +62,22 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, request.RequestURI = "" if isUpgradeRequest(request) { - handleUpgrade(conn, request, in) - - return // hijack connection + if resp = handleUpgrade(conn, request, in); resp == nil { + return // hijack connection + } } - RemoveHopByHopHeaders(request.Header) - RemoveExtraHTTPHostPort(request) + if resp == nil { + RemoveHopByHopHeaders(request.Header) + RemoveExtraHTTPHostPort(request) - if request.URL.Scheme == "" || request.URL.Host == "" { - resp = responseWith(request, http.StatusBadRequest) - } else { - resp, err = client.Do(request) - if err != nil { - resp = responseWith(request, http.StatusBadGateway) + if request.URL.Scheme == "" || request.URL.Host == "" { + resp = responseWith(request, http.StatusBadRequest) + } else { + resp, err = client.Do(request) + if err != nil { + resp = responseWith(request, http.StatusBadGateway) + } } } @@ -96,7 +98,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, } } - conn.Close() + _ = conn.Close() } func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response { diff --git a/listener/http/upgrade.go b/listener/http/upgrade.go index df700dfb..f770eb25 100644 --- a/listener/http/upgrade.go +++ b/listener/http/upgrade.go @@ -4,6 +4,7 @@ import ( "net" "net/http" "strings" + "time" "github.com/Dreamacro/clash/adapter/inbound" N "github.com/Dreamacro/clash/common/net" @@ -15,9 +16,7 @@ func isUpgradeRequest(req *http.Request) bool { return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") } -func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext) { - defer conn.Close() - +func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { removeProxyHeaders(request.Header) RemoveExtraHTTPHostPort(request) @@ -36,26 +35,36 @@ func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext in <- inbound.NewHTTP(dstAddr, conn.RemoteAddr(), right) bufferedLeft := N.NewBufferedConn(left) - defer bufferedLeft.Close() + defer func() { + _ = bufferedLeft.Close() + }() err := request.Write(bufferedLeft) if err != nil { return } - resp, err := http.ReadResponse(bufferedLeft.Reader(), request) - if err != nil { - return - } - - removeProxyHeaders(resp.Header) - - err = resp.Write(conn) + resp, err = http.ReadResponse(bufferedLeft.Reader(), request) if err != nil { return } if resp.StatusCode == http.StatusSwitchingProtocols { + removeProxyHeaders(resp.Header) + + err = conn.SetReadDeadline(time.Time{}) + if err != nil { + return + } + + err = resp.Write(conn) + if err != nil { + return + } + N.Relay(bufferedLeft, conn) + _ = conn.Close() + resp = nil } + return } diff --git a/listener/http/utils.go b/listener/http/utils.go index ca704645..e9994acc 100644 --- a/listener/http/utils.go +++ b/listener/http/utils.go @@ -8,7 +8,7 @@ import ( "strings" ) -// removeHopByHopHeaders remove Proxy-* headers +// removeProxyHeaders remove Proxy-* headers func removeProxyHeaders(header http.Header) { header.Del("Proxy-Connection") header.Del("Proxy-Authenticate") diff --git a/tunnel/connection.go b/tunnel/connection.go index 57283aab..0384e805 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -62,13 +62,5 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, fAddr n } func handleSocket(ctx C.ConnContext, outbound net.Conn) { - tcpKeepAlive(ctx.Conn()) - tcpKeepAlive(outbound) N.Relay(ctx.Conn(), outbound) } - -func tcpKeepAlive(c net.Conn) { - if tcp, ok := c.(*net.TCPConn); ok { - tcp.SetKeepAlive(true) - } -}