Fix: mitm proxy should forward websocket

This commit is contained in:
yaling888 2022-04-27 05:14:03 +08:00
parent 7c50c068f5
commit 30025c0241
7 changed files with 82 additions and 72 deletions

View File

@ -215,11 +215,10 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica
BasicConstraintsValid: true, BasicConstraintsValid: true,
NotBefore: time.Now().Add(-c.validity), NotBefore: time.Now().Add(-c.validity),
NotAfter: time.Now().Add(c.validity), NotAfter: time.Now().Add(c.validity),
DNSNames: dnsNames,
IPAddresses: ips,
} }
tmpl.DNSNames = dnsNames
tmpl.IPAddresses = ips
raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.privateKey.Public(), c.caPrivateKey) raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.privateKey.Public(), c.caPrivateKey)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -62,7 +62,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string,
request.RequestURI = "" request.RequestURI = ""
if isUpgradeRequest(request) { if isUpgradeRequest(request) {
if resp = handleUpgrade(conn, request, in); resp == nil { if resp = HandleUpgrade(conn, conn.RemoteAddr(), request, in); resp == nil {
return // hijack connection return // hijack connection
} }
} }

View File

@ -1,6 +1,8 @@
package http package http
import ( import (
"context"
"crypto/tls"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -16,13 +18,17 @@ func isUpgradeRequest(req *http.Request) bool {
return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") return strings.EqualFold(req.Header.Get("Connection"), "Upgrade")
} }
func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) { func HandleUpgrade(localConn net.Conn, source net.Addr, request *http.Request, in chan<- C.ConnContext) (resp *http.Response) {
removeProxyHeaders(request.Header) removeProxyHeaders(request.Header)
RemoveExtraHTTPHostPort(request) RemoveExtraHTTPHostPort(request)
address := request.Host address := request.Host
if _, _, err := net.SplitHostPort(address); err != nil { if _, _, err := net.SplitHostPort(address); err != nil {
address = net.JoinHostPort(address, "80") port := "80"
if request.TLS != nil {
port = "443"
}
address = net.JoinHostPort(address, port)
} }
dstAddr := socks5.ParseAddr(address) dstAddr := socks5.ParseAddr(address)
@ -32,38 +38,58 @@ func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext
left, right := net.Pipe() left, right := net.Pipe()
in <- inbound.NewHTTP(dstAddr, conn.RemoteAddr(), right) in <- inbound.NewMitm(dstAddr, source, request.Header.Get("User-Agent"), right)
var remoteServer *N.BufferedConn
if request.TLS != nil {
tlsConn := tls.Client(left, &tls.Config{
ServerName: request.URL.Hostname(),
})
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
defer cancel()
if tlsConn.HandshakeContext(ctx) != nil {
_ = localConn.Close()
_ = left.Close()
return
}
remoteServer = N.NewBufferedConn(tlsConn)
} else {
remoteServer = N.NewBufferedConn(left)
}
bufferedLeft := N.NewBufferedConn(left)
defer func() { defer func() {
_ = bufferedLeft.Close() _ = remoteServer.Close()
}() }()
err := request.Write(bufferedLeft) err := request.Write(remoteServer)
if err != nil { if err != nil {
_ = localConn.Close()
return return
} }
resp, err = http.ReadResponse(bufferedLeft.Reader(), request) resp, err = http.ReadResponse(remoteServer.Reader(), request)
if err != nil { if err != nil {
_ = localConn.Close()
return return
} }
if resp.StatusCode == http.StatusSwitchingProtocols { if resp.StatusCode == http.StatusSwitchingProtocols {
removeProxyHeaders(resp.Header) removeProxyHeaders(resp.Header)
err = conn.SetReadDeadline(time.Time{}) err = localConn.SetReadDeadline(time.Time{}) // set to not time out
if err != nil { if err != nil {
return return
} }
err = resp.Write(conn) err = resp.Write(localConn)
if err != nil { if err != nil {
return return
} }
N.Relay(bufferedLeft, conn) N.Relay(remoteServer, localConn) // blocking here
_ = conn.Close() _ = localConn.Close()
resp = nil resp = nil
} }
return return

View File

@ -45,12 +45,12 @@ readLoop:
for { for {
// use SetReadDeadline instead of Proxy-Connection keep-alive // use SetReadDeadline instead of Proxy-Connection keep-alive
if err := conn.SetReadDeadline(time.Now().Add(65 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(65 * time.Second)); err != nil {
break readLoop break
} }
request, err := H.ReadRequest(conn.Reader()) request, err := H.ReadRequest(conn.Reader())
if err != nil { if err != nil {
break readLoop break
} }
var response *http.Response var response *http.Response
@ -71,7 +71,7 @@ readLoop:
// Manual writing to support CONNECT for http 1.0 (workaround for uplay client) // Manual writing to support CONNECT for http 1.0 (workaround for uplay client)
if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil { if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop // close connection break // close connection
} }
if strings.HasSuffix(session.request.URL.Host, ":80") { if strings.HasSuffix(session.request.URL.Host, ":80") {
@ -81,7 +81,7 @@ readLoop:
b, err := conn.Peek(1) b, err := conn.Peek(1)
if err != nil { if err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop // close connection break // close connection
} }
// TLS handshake. // TLS handshake.
@ -92,7 +92,7 @@ readLoop:
if err = tlsConn.Handshake(); err != nil { if err = tlsConn.Handshake(); err != nil {
session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err)) session.response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err))
_ = writeResponse(session, false) _ = writeResponse(session, false)
break readLoop // close connection break // close connection
} }
c = tlsConn c = tlsConn
@ -105,20 +105,27 @@ readLoop:
prepareRequest(c, session.request) prepareRequest(c, session.request)
H.RemoveHopByHopHeaders(session.request.Header)
H.RemoveExtraHTTPHostPort(session.request)
// hijack api // hijack api
if session.request.URL.Hostname() == opt.ApiHost { if session.request.URL.Hostname() == opt.ApiHost {
if err = handleApiRequest(session, opt); err != nil { if err = handleApiRequest(session, opt); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop
} }
return break
} }
// forward websocket
if isWebsocketRequest(request) {
session.request.RequestURI = ""
if session.response = H.HandleUpgrade(conn, source, request, in); session.response == nil {
return // hijack connection
}
}
H.RemoveHopByHopHeaders(session.request.Header)
H.RemoveExtraHTTPHostPort(session.request)
// hijack custom request and write back custom response if necessary // hijack custom request and write back custom response if necessary
if opt.Handler != nil { if opt.Handler != nil && session.response == nil {
newReq, newRes := opt.Handler.HandleRequest(session) newReq, newRes := opt.Handler.HandleRequest(session)
if newReq != nil { if newReq != nil {
session.request = newReq session.request = newReq
@ -128,28 +135,30 @@ readLoop:
if err = writeResponse(session, false); err != nil { if err = writeResponse(session, false); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop break
} }
return continue
} }
} }
session.request.RequestURI = "" if session.response == nil {
session.request.RequestURI = ""
if session.request.URL.Host == "" { if session.request.URL.Host == "" {
session.response = session.NewErrorResponse(ErrInvalidURL) session.response = session.NewErrorResponse(ErrInvalidURL)
} else { } else {
client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in)
// send the request to remote server // send the request to remote server
session.response, err = client.Do(session.request) session.response, err = client.Do(session.request)
if err != nil { if err != nil {
handleError(opt, session, err) handleError(opt, session, err)
session.response = session.NewErrorResponse(err) session.response = session.NewErrorResponse(fmt.Errorf("request failed: %w", err))
if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") {
_ = writeResponse(session, false) _ = writeResponse(session, false)
break readLoop break
}
} }
} }
} }
@ -157,7 +166,7 @@ readLoop:
if err = writeResponseWithHandler(session, opt); err != nil { if err = writeResponseWithHandler(session, opt); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop // close connection break // close connection
} }
} }

View File

@ -20,6 +20,10 @@ var (
ErrInvalidURL = errors.New("invalid URL") ErrInvalidURL = errors.New("invalid URL")
) )
func isWebsocketRequest(req *http.Request) bool {
return req.Header.Get("Connection") == "Upgrade" && req.Header.Get("Upgrade") == "websocket"
}
func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { func NewResponse(code int, body io.Reader, req *http.Request) *http.Response {
if body == nil { if body == nil {
body = &bytes.Buffer{} body = &bytes.Buffer{}

View File

@ -2,7 +2,6 @@ package nat
import ( import (
"net/netip" "net/netip"
"sync"
"github.com/Dreamacro/clash/common/generics/list" "github.com/Dreamacro/clash/common/generics/list"
) )
@ -25,7 +24,6 @@ type binding struct {
} }
type table struct { type table struct {
mu sync.Mutex
tuples map[tuple]*list.Element[*binding] tuples map[tuple]*list.Element[*binding]
ports [portLength]*list.Element[*binding] ports [portLength]*list.Element[*binding]
available *list.List[*binding] available *list.List[*binding]
@ -39,13 +37,13 @@ func (t *table) tupleOf(port uint16) tuple {
elm := t.ports[offset] elm := t.ports[offset]
t.available.MoveToFront(elm)
return elm.Value.tuple return elm.Value.tuple
} }
func (t *table) portOf(tuple tuple) uint16 { func (t *table) portOf(tuple tuple) uint16 {
t.mu.Lock()
elm := t.tuples[tuple] elm := t.tuples[tuple]
t.mu.Unlock()
if elm == nil { if elm == nil {
return 0 return 0
} }
@ -59,11 +57,8 @@ func (t *table) newConn(tuple tuple) uint16 {
elm := t.available.Back() elm := t.available.Back()
b := elm.Value b := elm.Value
t.mu.Lock()
delete(t.tuples, b.tuple) delete(t.tuples, b.tuple)
t.tuples[tuple] = elm t.tuples[tuple] = elm
t.mu.Unlock()
b.tuple = tuple b.tuple = tuple
t.available.MoveToFront(elm) t.available.MoveToFront(elm)
@ -71,19 +66,6 @@ func (t *table) newConn(tuple tuple) uint16 {
return portBegin + b.offset return portBegin + b.offset
} }
func (t *table) delete(tup tuple) {
t.mu.Lock()
elm := t.tuples[tup]
if elm == nil {
t.mu.Unlock()
return
}
delete(t.tuples, tup)
t.mu.Unlock()
t.available.MoveToBack(elm)
}
func newTable() *table { func newTable() *table {
result := &table{ result := &table{
tuples: make(map[tuple]*list.Element[*binding], portLength), tuples: make(map[tuple]*list.Element[*binding], portLength),

View File

@ -16,8 +16,6 @@ type conn struct {
net.Conn net.Conn
tuple tuple tuple tuple
close func(tuple tuple)
} }
func (t *TCP) Accept() (net.Conn, error) { func (t *TCP) Accept() (net.Conn, error) {
@ -39,9 +37,6 @@ func (t *TCP) Accept() (net.Conn, error) {
return &conn{ return &conn{
Conn: c, Conn: c,
tuple: tup, tuple: tup,
close: func(tuple tuple) {
t.table.delete(tuple)
},
}, nil }, nil
} }
@ -57,11 +52,6 @@ func (t *TCP) SetDeadline(time time.Time) error {
return t.listener.SetDeadline(time) return t.listener.SetDeadline(time)
} }
func (c *conn) Close() error {
c.close(c.tuple)
return c.Conn.Close()
}
func (c *conn) LocalAddr() net.Addr { func (c *conn) LocalAddr() net.Addr {
return &net.TCPAddr{ return &net.TCPAddr{
IP: c.tuple.SourceAddr.Addr().AsSlice(), IP: c.tuple.SourceAddr.Addr().AsSlice(),