Chore: merge branch 'with-tun' into plus-pro

This commit is contained in:
yaling888 2022-04-29 22:24:18 +08:00
commit d74dd69329
7 changed files with 181 additions and 151 deletions

View File

@ -62,7 +62,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
request.RequestURI = "" request.RequestURI = ""
if isUpgradeRequest(request) { if isUpgradeRequest(request) {
if resp = HandleUpgrade(conn, conn.RemoteAddr(), request, in); resp == nil { if resp = HandleUpgrade(conn, nil, request, in); resp == nil {
return // hijack connection return // hijack connection
} }
} }

View File

@ -1,8 +1,6 @@
package http package http
import ( import (
"context"
"crypto/tls"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -18,58 +16,43 @@ func isUpgradeRequest(req *http.Request) bool {
return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") return strings.EqualFold(req.Header.Get("Connection"), "Upgrade")
} }
func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { func HandleUpgrade(localConn net.Conn, serverConn *N.BufferedConn, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) {
removeProxyHeaders(request.Header) removeProxyHeaders(request.Header)
RemoveExtraHTTPHostPort(request) RemoveExtraHTTPHostPort(request)
address := request.Host if serverConn == nil {
if _, _, err := net.SplitHostPort(address); err != nil { address := request.Host
port := "80" if _, _, err := net.SplitHostPort(address); err != nil {
if request.TLS != nil { port := "80"
port = "443" if request.TLS != nil {
port = "443"
}
address = net.JoinHostPort(address, port)
} }
address = net.JoinHostPort(address, port)
}
dstAddr := socks5.ParseAddr(address) dstAddr := socks5.ParseAddr(address)
if dstAddr == nil { if dstAddr == nil {
return
}
left, right := net.Pipe()
in <- inbound.NewMitm(dstAddr, source, request.Header.Get("User-Agent"), right)
var remoteServer *N.BufferedConn
if request.TLS != nil {
tlsConn := tls.Client(left, &tls.Config{
ServerName: request.URL.Hostname(),
})
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
defer cancel()
if tlsConn.HandshakeContext(ctx) != nil {
_ = localConn.Close()
_ = left.Close()
return return
} }
remoteServer = N.NewBufferedConn(tlsConn) left, right := net.Pipe()
} else {
remoteServer = N.NewBufferedConn(left) in <- inbound.NewHTTP(dstAddr, localConn.RemoteAddr(), right)
serverConn = N.NewBufferedConn(left)
defer func() {
_ = serverConn.Close()
}()
} }
defer func() { err := request.Write(serverConn)
_ = remoteServer.Close()
}()
err := request.Write(remoteServer)
if err != nil { if err != nil {
_ = localConn.Close() _ = localConn.Close()
return return
} }
resp, err = http.ReadResponse(remoteServer.Reader(), request) resp, err = http.ReadResponse(serverConn.Reader(), request)
if err != nil { if err != nil {
_ = localConn.Close() _ = localConn.Close()
return return
@ -88,7 +71,7 @@ func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, i
return return
} }
N.Relay(remoteServer, localConn) // blocking here N.Relay(serverConn, localConn) // blocking here
_ = localConn.Close() _ = localConn.Close()
resp = nil resp = nil
} }

View File

@ -3,53 +3,53 @@ package mitm
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"net" "net"
"net/http" "net/http"
"time"
"github.com/Dreamacro/clash/adapter/inbound" "github.com/Dreamacro/clash/adapter/inbound"
N "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/socks5"
) )
func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http.Client { func getServerConn(serverConn *N.BufferedConn, request *http.Request, srcAddr net.Addr, in chan<- C.ConnContext) (*N.BufferedConn, error) {
return &http.Client{ if serverConn != nil {
Transport: &http.Transport{ return serverConn, nil
// excepted HTTP/2
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
// only needed 1 connection
MaxIdleConns: 1,
MaxIdleConnsPerHost: 1,
MaxConnsPerHost: 1,
IdleConnTimeout: 60 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{
GetClientCertificate: func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, e error) {
return nil, ErrCertUnsupported
},
},
DialContext: func(context context.Context, network, address string) (net.Conn, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, errors.New("unsupported network " + network)
}
dstAddr := socks5.ParseAddr(address)
if dstAddr == nil {
return nil, socks5.ErrAddressNotSupported
}
left, right := net.Pipe()
in <- inbound.NewMitm(dstAddr, source, userAgent, right)
return left, nil
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Timeout: 120 * time.Second,
} }
address := request.URL.Host
if _, _, err := net.SplitHostPort(address); err != nil {
port := "80"
if request.TLS != nil {
port = "443"
}
address = net.JoinHostPort(address, port)
}
dstAddr := socks5.ParseAddr(address)
if dstAddr == nil {
return nil, socks5.ErrAddressNotSupported
}
left, right := net.Pipe()
in <- inbound.NewMitm(dstAddr, srcAddr, request.Header.Get("User-Agent"), right)
if request.TLS != nil {
tlsConn := tls.Client(left, &tls.Config{
ServerName: request.TLS.ServerName,
})
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
defer cancel()
if err := tlsConn.HandshakeContext(ctx); err != nil {
return nil, err
}
serverConn = N.NewBufferedConn(tlsConn)
} else {
serverConn = N.NewBufferedConn(left)
}
return serverConn, nil
} }

9
listener/mitm/hack.go Normal file
View File

@ -0,0 +1,9 @@
package mitm
import (
_ "net/http"
_ "unsafe"
)
//go:linkname validMethod net/http.validMethod
func validMethod(method string) bool

View File

@ -1,15 +1,17 @@
package mitm package mitm
import ( import (
"bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
"os"
"strings" "strings"
"time" "time"
@ -21,25 +23,24 @@ import (
func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) {
var ( var (
source net.Addr clientIP = netip.MustParseAddrPort(c.RemoteAddr().String()).Addr()
client *http.Client sourceAddr net.Addr
serverConn *N.BufferedConn
connState *tls.ConnectionState
) )
defer func() { defer func() {
if client != nil { if serverConn != nil {
client.CloseIdleConnections() _ = serverConn.Close()
} }
}() }()
startOver: conn := N.NewBufferedConn(c)
var conn *N.BufferedConn
if bufConn, ok := c.(*N.BufferedConn); ok {
conn = bufConn
} else {
conn = N.NewBufferedConn(c)
}
trusted := cache == nil // disable authenticate if cache is nil trusted := cache == nil // disable authenticate if cache is nil
if !trusted {
trusted = clientIP.IsLoopback() || clientIP.IsUnspecified()
}
readLoop: readLoop:
for { for {
@ -57,8 +58,8 @@ readLoop:
session := newSession(conn, request, response) session := newSession(conn, request, response)
source = parseSourceAddress(session.request, c.RemoteAddr(), source) sourceAddr = parseSourceAddress(session.request, conn.RemoteAddr(), sourceAddr)
session.request.RemoteAddr = source.String() session.request.RemoteAddr = sourceAddr.String()
if !trusted { if !trusted {
session.response = H.Authenticate(session.request, cache) session.response = H.Authenticate(session.request, cache)
@ -68,6 +69,11 @@ readLoop:
if trusted { if trusted {
if session.request.Method == http.MethodConnect { if session.request.Method == http.MethodConnect {
if session.request.ProtoMajor > 1 {
session.request.ProtoMajor = 1
session.request.ProtoMinor = 1
}
// Manual writing to support CONNECT for http 1.0 (workaround for uplay client) // Manual writing to support CONNECT for http 1.0 (workaround for uplay client)
if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil { if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
@ -78,9 +84,9 @@ readLoop:
goto readLoop goto readLoop
} }
b, err := conn.Peek(1) b, err1 := conn.Peek(1)
if err != nil { if err1 != nil {
handleError(opt, session, err) handleError(opt, session, err1)
break // close connection break // close connection
} }
@ -88,22 +94,61 @@ readLoop:
if b[0] == 0x16 { if b[0] == 0x16 {
tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Hostname())) tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Hostname()))
// Handshake with the local client ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
if err = tlsConn.Handshake(); err != nil { // handshake with the local client
if err = tlsConn.HandshakeContext(ctx); err != nil {
cancel()
session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err)) session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err))
_ = writeResponse(session, false) _ = writeResponse(session, false)
break // close connection break // close connection
} }
cancel()
c = tlsConn cs := tlsConn.ConnectionState()
} else { connState = &cs
c = conn
conn = N.NewBufferedConn(tlsConn)
} }
goto startOver if strings.HasSuffix(session.request.URL.Host, ":443") {
goto readLoop
}
if conn.SetReadDeadline(time.Now().Add(time.Second)) != nil {
break
}
buf, err2 := conn.Peek(7)
if err2 != nil {
if err2 != bufio.ErrBufferFull && !os.IsTimeout(err2) {
handleError(opt, session, err2)
break // close connection
}
}
// others protocol over tcp
if !isHTTPTraffic(buf) {
if connState != nil {
session.request.TLS = connState
}
serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in)
if err != nil {
break
}
if conn.SetReadDeadline(time.Time{}) != nil {
break
}
N.Relay(serverConn, conn)
return // hijack connection
}
goto readLoop
} }
prepareRequest(c, session.request) prepareRequest(connState, session.request)
// hijack api // hijack api
if session.request.URL.Hostname() == opt.ApiHost { if session.request.URL.Hostname() == opt.ApiHost {
@ -115,17 +160,22 @@ readLoop:
// forward websocket // forward websocket
if isWebsocketRequest(request) { if isWebsocketRequest(request) {
serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in)
if err != nil {
break
}
session.request.RequestURI = "" session.request.RequestURI = ""
if session.response = H.HandleUpgrade(conn, source, request, in); session.response == nil { if session.response = H.HandleUpgrade(conn, serverConn, request, in); session.response == nil {
return // hijack connection return // hijack connection
} }
} }
H.RemoveHopByHopHeaders(session.request.Header) if session.response == nil {
H.RemoveExtraHTTPHostPort(session.request) H.RemoveHopByHopHeaders(session.request.Header)
H.RemoveExtraHTTPHostPort(session.request)
// hijack custom request and write back custom response if necessary // hijack custom request and write back custom response if necessary
if opt.Handler != nil && session.response == nil {
newReq, newRes := opt.Handler.HandleRequest(session) newReq, newRes := opt.Handler.HandleRequest(session)
if newReq != nil { if newReq != nil {
session.request = newReq session.request = newReq
@ -139,26 +189,26 @@ readLoop:
} }
continue continue
} }
}
if session.response == nil {
session.request.RequestURI = "" session.request.RequestURI = ""
if session.request.URL.Host == "" { if session.request.URL.Host == "" {
session.response = session.NewErrorResponse(ErrInvalidURL) session.response = session.NewErrorResponse(ErrInvalidURL)
} else { } else {
client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) serverConn, err = getServerConn(serverConn, session.request, sourceAddr, in)
if err != nil {
break
}
// send the request to remote server // send the request to remote server
session.response, err = client.Do(session.request) err = session.request.Write(serverConn)
if err != nil { if err != nil {
handleError(opt, session, err) break
session.response = session.NewErrorResponse(fmt.Errorf("request failed: %w", err)) }
if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") {
_ = writeResponse(session, false) session.response, err = http.ReadResponse(serverConn.Reader(), request)
break if err != nil {
} break
} }
} }
} }
@ -174,11 +224,9 @@ readLoop:
} }
func writeResponseWithHandler(session *Session, opt *Option) error { func writeResponseWithHandler(session *Session, opt *Option) error {
if opt.Handler != nil { res := opt.Handler.HandleResponse(session)
res := opt.Handler.HandleResponse(session) if res != nil {
if res != nil { session.response = res
session.response = res
}
} }
return writeResponse(session, true) return writeResponse(session, true)
@ -220,10 +268,8 @@ func handleApiRequest(session *Session, opt *Option) error {
</body></html> </body></html>
` `
if opt.Handler != nil { if opt.Handler.HandleApiRequest(session) {
if opt.Handler.HandleApiRequest(session) { return nil
return nil
}
} }
b = fmt.Sprintf(b, session.request.URL.Path) b = fmt.Sprintf(b, session.request.URL.Path)
@ -243,12 +289,10 @@ func handleError(opt *Option, session *Session, err error) {
_ = session.response.Body.Close() _ = session.response.Body.Close()
}() }()
} }
if opt.Handler != nil { opt.Handler.HandleError(session, err)
opt.Handler.HandleError(session, err)
}
} }
func prepareRequest(conn net.Conn, request *http.Request) { func prepareRequest(connState *tls.ConnectionState, request *http.Request) {
host := request.Header.Get("Host") host := request.Header.Get("Host")
if host != "" { if host != "" {
request.Host = host request.Host = host
@ -262,10 +306,8 @@ func prepareRequest(conn net.Conn, request *http.Request) {
request.URL.Scheme = "http" request.URL.Scheme = "http"
} }
if tlsConn, ok := conn.(*tls.Conn); ok { if connState != nil {
cs := tlsConn.ConnectionState() request.TLS = connState
request.TLS = &cs
request.URL.Scheme = "https" request.URL.Scheme = "https"
} }
@ -297,10 +339,11 @@ func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr
} }
} }
func newClientBySourceAndUserAgentIfNil(cli *http.Client, req *http.Request, source net.Addr, in chan<- C.ConnContext) *http.Client { func isWebsocketRequest(req *http.Request) bool {
if cli != nil { return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket")
return cli }
}
func isHTTPTraffic(buf []byte) bool {
return newClient(source, req.Header.Get("User-Agent"), in) method, _, _ := strings.Cut(string(buf), " ")
return validMethod(method)
} }

View File

@ -54,7 +54,7 @@ func (l *Listener) Close() error {
// New the MITM proxy actually is a type of HTTP proxy // New the MITM proxy actually is a type of HTTP proxy
func New(option *Option, in chan<- C.ConnContext) (*Listener, error) { func New(option *Option, in chan<- C.ConnContext) (*Listener, error) {
return NewWithAuthenticate(option, in, false) return NewWithAuthenticate(option, in, true)
} }
func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate bool) (*Listener, error) { func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate bool) (*Listener, error) {
@ -65,7 +65,7 @@ func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate b
var c *cache.Cache[string, bool] var c *cache.Cache[string, bool]
if authenticate { if authenticate {
c = cache.New[string, bool](time.Second * 30) c = cache.New[string, bool](time.Second * 90)
} }
hl := &Listener{ hl := &Listener{

View File

@ -15,15 +15,10 @@ import (
) )
var ( var (
ErrCertUnsupported = errors.New("tls: client cert unsupported")
ErrInvalidResponse = errors.New("invalid response") ErrInvalidResponse = errors.New("invalid response")
ErrInvalidURL = errors.New("invalid URL") ErrInvalidURL = errors.New("invalid URL")
) )
func isWebsocketRequest(req *http.Request) bool {
return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket"
}
func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { func NewResponse(code int, body io.Reader, req *http.Request) *http.Response {
if body == nil { if body == nil {
body = &bytes.Buffer{} body = &bytes.Buffer{}