Fix: mitm proxy should handle none-http(s) protocol over tcp

This commit is contained in:
yaling888 2022-04-28 06:46:57 +08:00
parent 22458ad0be
commit da92601902
3 changed files with 69 additions and 26 deletions

View File

@ -17,7 +17,7 @@ func getServerConn(serverConn *N.BufferedConn, request *http.Request, srcAddr ne
return serverConn, nil return serverConn, nil
} }
address := request.Host address := request.URL.Host
if _, _, err := net.SplitHostPort(address); err != nil { if _, _, err := net.SplitHostPort(address); err != nil {
port := "80" port := "80"
if request.TLS != nil { if request.TLS != nil {

9
listener/mitm/hack.go Normal file
View File

@ -0,0 +1,9 @@
package mitm
import (
_ "net/http"
_ "unsafe"
)
//go:linkname validMethod net/http.validMethod
func validMethod(method string) bool

View File

@ -1,9 +1,11 @@
package mitm package mitm
import ( import (
"bufio"
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -18,11 +20,12 @@ import (
H "github.com/Dreamacro/clash/listener/http" H "github.com/Dreamacro/clash/listener/http"
) )
func HandleConn(clientConn net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) {
var ( var (
clientIP = netip.MustParseAddrPort(clientConn.RemoteAddr().String()).Addr() clientIP = netip.MustParseAddrPort(c.RemoteAddr().String()).Addr()
sourceAddr net.Addr sourceAddr net.Addr
serverConn *N.BufferedConn serverConn *N.BufferedConn
connState *tls.ConnectionState
) )
defer func() { defer func() {
@ -31,17 +34,11 @@ func HandleConn(clientConn net.Conn, opt *Option, in chan<- C.ConnContext, cache
} }
}() }()
startOver: conn := N.NewBufferedConn(c)
var conn *N.BufferedConn
if bufConn, ok := clientConn.(*N.BufferedConn); ok {
conn = bufConn
} else {
conn = N.NewBufferedConn(clientConn)
}
trusted := cache == nil // disable authenticate if cache is nil trusted := cache == nil // disable authenticate if cache is nil
if !trusted { if !trusted {
trusted = clientIP.IsLoopback() trusted = clientIP.IsLoopback() || clientIP.IsUnspecified()
} }
readLoop: readLoop:
@ -60,7 +57,7 @@ readLoop:
session := newSession(conn, request, response) session := newSession(conn, request, response)
sourceAddr = parseSourceAddress(session.request, clientConn.RemoteAddr(), sourceAddr) sourceAddr = parseSourceAddress(session.request, conn.RemoteAddr(), sourceAddr)
session.request.RemoteAddr = sourceAddr.String() session.request.RemoteAddr = sourceAddr.String()
if !trusted { if !trusted {
@ -81,9 +78,9 @@ readLoop:
goto readLoop goto readLoop
} }
b, err := conn.Peek(1) b, err1 := conn.Peek(1)
if err != nil { if err1 != nil {
handleError(opt, session, err) handleError(opt, session, err1)
break // close connection break // close connection
} }
@ -98,15 +95,49 @@ readLoop:
break // close connection break // close connection
} }
clientConn = tlsConn cs := tlsConn.ConnectionState()
connState = &cs
conn = N.NewBufferedConn(tlsConn)
}
if strings.HasSuffix(session.request.URL.Host, ":443") {
goto readLoop
}
var noErr bool
buf, err2 := conn.Peek(7)
if err2 != nil {
if err2 == bufio.ErrBufferFull || errors.Is(err2, io.EOF) {
noErr = true
} else { } else {
clientConn = conn handleError(opt, session, err2)
break // close connection
}
} }
goto startOver // others protocol over tcp
if noErr || !isHTTPTraffic(buf) {
if connState != nil {
session.request.TLS = connState
} }
prepareRequest(clientConn, session.request) serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in)
if err != nil {
break
}
_ = conn.SetReadDeadline(time.Time{})
N.Relay(serverConn, conn)
return // hijack connection
}
goto readLoop
}
prepareRequest(connState, session.request)
// hijack api // hijack api
if session.request.URL.Hostname() == opt.ApiHost { if session.request.URL.Hostname() == opt.ApiHost {
@ -250,7 +281,7 @@ func handleError(opt *Option, session *Session, err error) {
opt.Handler.HandleError(session, err) opt.Handler.HandleError(session, err)
} }
func prepareRequest(conn net.Conn, request *http.Request) { func prepareRequest(connState *tls.ConnectionState, request *http.Request) {
host := request.Header.Get("Host") host := request.Header.Get("Host")
if host != "" { if host != "" {
request.Host = host request.Host = host
@ -264,10 +295,8 @@ func prepareRequest(conn net.Conn, request *http.Request) {
request.URL.Scheme = "http" request.URL.Scheme = "http"
} }
if tlsConn, ok := conn.(*tls.Conn); ok { if connState != nil {
cs := tlsConn.ConnectionState() request.TLS = connState
request.TLS = &cs
request.URL.Scheme = "https" request.URL.Scheme = "https"
} }
@ -300,5 +329,10 @@ func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr
} }
func isWebsocketRequest(req *http.Request) bool { func isWebsocketRequest(req *http.Request) bool {
return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket" return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket")
}
func isHTTPTraffic(buf []byte) bool {
method, _, _ := strings.Cut(string(buf), " ")
return validMethod(method)
} }