diff --git a/adapter/outbound/http.go b/adapter/outbound/http.go index ff87af6f..91d4f381 100644 --- a/adapter/outbound/http.go +++ b/adapter/outbound/http.go @@ -90,7 +90,7 @@ func (h *Http) shakeHand(metadata *C.Metadata, rw io.ReadWriter) error { } if metadata.Type == C.MITM { - req.Header.Add("Origin-Request-Source-Address", metadata.SourceAddress()) + req.Header.Set("Origin-Request-Source-Address", metadata.SourceAddress()) } if err := req.Write(rw); err != nil { diff --git a/adapter/outbound/mitm.go b/adapter/outbound/mitm.go index faa04a4c..a8fd6cfb 100644 --- a/adapter/outbound/mitm.go +++ b/adapter/outbound/mitm.go @@ -7,26 +7,23 @@ import ( "github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/trie" C "github.com/Dreamacro/clash/constant" - - "go.uber.org/atomic" ) var ( errIgnored = errors.New("not match in mitm host lists") httpProxyClient = NewHttp(HttpOption{}) - MiddlemanServerAddress = atomic.NewString("") - MiddlemanRewriteHosts *trie.DomainTrie[bool] + MiddlemanRewriteHosts *trie.DomainTrie[bool] ) type Mitm struct { *Base + serverAddr string } // DialContext implements C.ProxyAdapter -func (d *Mitm) DialContext(ctx context.Context, metadata *C.Metadata, _ ...dialer.Option) (C.Conn, error) { - addr := MiddlemanServerAddress.Load() - if addr == "" || MiddlemanRewriteHosts == nil { +func (m *Mitm) DialContext(ctx context.Context, metadata *C.Metadata, _ ...dialer.Option) (C.Conn, error) { + if MiddlemanRewriteHosts == nil { return nil, errIgnored } @@ -41,7 +38,7 @@ func (d *Mitm) DialContext(ctx context.Context, metadata *C.Metadata, _ ...diale metadata.DstIP = nil } - c, err := dialer.DialContext(ctx, "tcp", addr, []dialer.Option{dialer.WithInterface(""), dialer.WithRoutingMark(0)}...) + c, err := dialer.DialContext(ctx, "tcp", m.serverAddr, []dialer.Option{dialer.WithInterface(""), dialer.WithRoutingMark(0)}...) if err != nil { return nil, err } @@ -55,14 +52,15 @@ func (d *Mitm) DialContext(ctx context.Context, metadata *C.Metadata, _ ...diale return nil, err } - return NewConn(c, d), nil + return NewConn(c, m), nil } -func NewMitm() *Mitm { +func NewMitm(serverAddr string) *Mitm { return &Mitm{ Base: &Base{ name: "Mitm", tp: C.Mitm, }, + serverAddr: serverAddr, } } diff --git a/listener/listener.go b/listener/listener.go index a8137849..f75a282f 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -10,6 +10,7 @@ import ( "os" "strconv" "sync" + "time" "github.com/Dreamacro/clash/adapter/inbound" "github.com/Dreamacro/clash/adapter/outbound" @@ -360,7 +361,6 @@ func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) { if mitmListener.RawAddress() == addr { return } - outbound.MiddlemanServerAddress.Store("") tunnel.MitmOutbound = nil _ = mitmListener.Close() mitmListener = nil @@ -401,7 +401,7 @@ func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) { return } - certOption.SetValidity(cert.TTL << 3) + certOption.SetValidity(time.Hour * 24 * 90) certOption.SetOrganization("Clash ManInTheMiddle Proxy Services") opt := &mitm.Option{ @@ -416,8 +416,7 @@ func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) { return } - outbound.MiddlemanServerAddress.Store(mitmListener.Address()) - tunnel.MitmOutbound = outbound.NewMitm() + tunnel.MitmOutbound = outbound.NewMitm(mitmListener.Address()) log.Infoln("Mitm proxy listening at: %s", mitmListener.Address()) } diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index ce5e8f2e..0ab138fc 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -32,8 +32,8 @@ func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.C }() startOver: - if tc, ok := c.(*net.TCPConn); ok { - _ = tc.SetKeepAlive(true) + if tcpConn, ok := c.(*net.TCPConn); ok { + _ = tcpConn.SetKeepAlive(true) } var conn *N.BufferedConn @@ -47,14 +47,13 @@ startOver: readLoop: for { - err := conn.SetDeadline(time.Now().Add(30 * time.Second)) // use SetDeadline instead of Proxy-Connection keep-alive - if err != nil { + // use SetDeadline instead of Proxy-Connection keep-alive + if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil { break readLoop } request, err := H.ReadRequest(conn.Reader()) if err != nil { - handleError(opt, nil, err) break readLoop } @@ -83,27 +82,15 @@ readLoop: goto readLoop } - b := make([]byte, 1) - if _, err = session.conn.Read(b); err != nil { + b, err := conn.Peek(1) + if err != nil { handleError(opt, session, err) break readLoop // close connection } - buff := make([]byte, session.conn.(*N.BufferedConn).Buffered()) - if _, err = session.conn.Read(buff); err != nil { - handleError(opt, session, err) - break readLoop // close connection - } - - mrConn := &multiReaderConn{ - Conn: session.conn, - reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn), - } - // TLS handshake. if b[0] == 0x16 { - // TODO serve by generic host name maybe better? - tlsConn := tls.Server(mrConn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) + tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) // Handshake with the local client if err = tlsConn.Handshake(); err != nil { @@ -114,7 +101,7 @@ readLoop: c = tlsConn } else { - c = mrConn + c = conn } goto startOver @@ -122,8 +109,11 @@ readLoop: prepareRequest(c, session.request) + H.RemoveHopByHopHeaders(session.request.Header) + H.RemoveExtraHTTPHostPort(session.request) + // hijack api - if session.request.URL.Host == opt.ApiHost { + if session.request.URL.Hostname() == opt.ApiHost { if err = handleApiRequest(session, opt); err != nil { handleError(opt, session, err) break readLoop @@ -162,7 +152,6 @@ readLoop: handleError(opt, session, err) session.response = session.NewErrorResponse(err) if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { - // TODO block unsupported host? _ = writeResponse(session, false) break readLoop } @@ -287,9 +276,6 @@ func prepareRequest(conn net.Conn, request *http.Request) { if request.Header.Get("Accept-Encoding") != "" { request.Header.Set("Accept-Encoding", "gzip") } - - H.RemoveHopByHopHeaders(request.Header) - H.RemoveExtraHTTPHostPort(request) } func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr { diff --git a/listener/mitm/utils.go b/listener/mitm/utils.go index 5c3b15bd..a84c75cf 100644 --- a/listener/mitm/utils.go +++ b/listener/mitm/utils.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "io/ioutil" - "net" "net/http" "time" @@ -21,15 +20,6 @@ var ( ErrInvalidURL = errors.New("invalid URL") ) -type multiReaderConn struct { - net.Conn - reader io.Reader -} - -func (c *multiReaderConn) Read(buf []byte) (int, error) { - return c.reader.Read(buf) -} - func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { if body == nil { body = &bytes.Buffer{}