Chore: remove TODO

This commit is contained in:
yaling888
2022-04-19 17:05:12 +08:00
parent 42cf42fd8b
commit 33d23dad6c
5 changed files with 24 additions and 51 deletions

View File

@ -90,7 +90,7 @@ func (h *Http) shakeHand(metadata *C.Metadata, rw io.ReadWriter) error {
} }
if metadata.Type == C.MITM { if metadata.Type == C.MITM {
req.Header.Add("Origin-Request-Source-Address", metadata.SourceAddress()) req.Header.Set("Origin-Request-Source-Address", metadata.SourceAddress())
} }
if err := req.Write(rw); err != nil { if err := req.Write(rw); err != nil {

View File

@ -7,26 +7,23 @@ import (
"github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/trie" "github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"go.uber.org/atomic"
) )
var ( var (
errIgnored = errors.New("not match in mitm host lists") errIgnored = errors.New("not match in mitm host lists")
httpProxyClient = NewHttp(HttpOption{}) httpProxyClient = NewHttp(HttpOption{})
MiddlemanServerAddress = atomic.NewString("") MiddlemanRewriteHosts *trie.DomainTrie[bool]
MiddlemanRewriteHosts *trie.DomainTrie[bool]
) )
type Mitm struct { type Mitm struct {
*Base *Base
serverAddr string
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (d *Mitm) DialContext(ctx context.Context, metadata *C.Metadata, _ ...dialer.Option) (C.Conn, error) { func (m *Mitm) DialContext(ctx context.Context, metadata *C.Metadata, _ ...dialer.Option) (C.Conn, error) {
addr := MiddlemanServerAddress.Load() if MiddlemanRewriteHosts == nil {
if addr == "" || MiddlemanRewriteHosts == nil {
return nil, errIgnored return nil, errIgnored
} }
@ -41,7 +38,7 @@ func (d *Mitm) DialContext(ctx context.Context, metadata *C.Metadata, _ ...diale
metadata.DstIP = nil metadata.DstIP = nil
} }
c, err := dialer.DialContext(ctx, "tcp", addr, []dialer.Option{dialer.WithInterface(""), dialer.WithRoutingMark(0)}...) c, err := dialer.DialContext(ctx, "tcp", m.serverAddr, []dialer.Option{dialer.WithInterface(""), dialer.WithRoutingMark(0)}...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -55,14 +52,15 @@ func (d *Mitm) DialContext(ctx context.Context, metadata *C.Metadata, _ ...diale
return nil, err return nil, err
} }
return NewConn(c, d), nil return NewConn(c, m), nil
} }
func NewMitm() *Mitm { func NewMitm(serverAddr string) *Mitm {
return &Mitm{ return &Mitm{
Base: &Base{ Base: &Base{
name: "Mitm", name: "Mitm",
tp: C.Mitm, tp: C.Mitm,
}, },
serverAddr: serverAddr,
} }
} }

View File

@ -10,6 +10,7 @@ import (
"os" "os"
"strconv" "strconv"
"sync" "sync"
"time"
"github.com/Dreamacro/clash/adapter/inbound" "github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/adapter/outbound" "github.com/Dreamacro/clash/adapter/outbound"
@ -360,7 +361,6 @@ func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) {
if mitmListener.RawAddress() == addr { if mitmListener.RawAddress() == addr {
return return
} }
outbound.MiddlemanServerAddress.Store("")
tunnel.MitmOutbound = nil tunnel.MitmOutbound = nil
_ = mitmListener.Close() _ = mitmListener.Close()
mitmListener = nil mitmListener = nil
@ -401,7 +401,7 @@ func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) {
return return
} }
certOption.SetValidity(cert.TTL << 3) certOption.SetValidity(time.Hour * 24 * 90)
certOption.SetOrganization("Clash ManInTheMiddle Proxy Services") certOption.SetOrganization("Clash ManInTheMiddle Proxy Services")
opt := &mitm.Option{ opt := &mitm.Option{
@ -416,8 +416,7 @@ func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) {
return return
} }
outbound.MiddlemanServerAddress.Store(mitmListener.Address()) tunnel.MitmOutbound = outbound.NewMitm(mitmListener.Address())
tunnel.MitmOutbound = outbound.NewMitm()
log.Infoln("Mitm proxy listening at: %s", mitmListener.Address()) log.Infoln("Mitm proxy listening at: %s", mitmListener.Address())
} }

View File

@ -32,8 +32,8 @@ func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.C
}() }()
startOver: startOver:
if tc, ok := c.(*net.TCPConn); ok { if tcpConn, ok := c.(*net.TCPConn); ok {
_ = tc.SetKeepAlive(true) _ = tcpConn.SetKeepAlive(true)
} }
var conn *N.BufferedConn var conn *N.BufferedConn
@ -47,14 +47,13 @@ startOver:
readLoop: readLoop:
for { for {
err := conn.SetDeadline(time.Now().Add(30 * time.Second)) // use SetDeadline instead of Proxy-Connection keep-alive // use SetDeadline instead of Proxy-Connection keep-alive
if err != nil { if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
break readLoop break readLoop
} }
request, err := H.ReadRequest(conn.Reader()) request, err := H.ReadRequest(conn.Reader())
if err != nil { if err != nil {
handleError(opt, nil, err)
break readLoop break readLoop
} }
@ -83,27 +82,15 @@ readLoop:
goto readLoop goto readLoop
} }
b := make([]byte, 1) b, err := conn.Peek(1)
if _, err = session.conn.Read(b); err != nil { if err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop // close connection break readLoop // close connection
} }
buff := make([]byte, session.conn.(*N.BufferedConn).Buffered())
if _, err = session.conn.Read(buff); err != nil {
handleError(opt, session, err)
break readLoop // close connection
}
mrConn := &multiReaderConn{
Conn: session.conn,
reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buff), session.conn),
}
// TLS handshake. // TLS handshake.
if b[0] == 0x16 { if b[0] == 0x16 {
// TODO serve by generic host name maybe better? tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host))
tlsConn := tls.Server(mrConn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host))
// Handshake with the local client // Handshake with the local client
if err = tlsConn.Handshake(); err != nil { if err = tlsConn.Handshake(); err != nil {
@ -114,7 +101,7 @@ readLoop:
c = tlsConn c = tlsConn
} else { } else {
c = mrConn c = conn
} }
goto startOver goto startOver
@ -122,8 +109,11 @@ readLoop:
prepareRequest(c, session.request) prepareRequest(c, session.request)
H.RemoveHopByHopHeaders(session.request.Header)
H.RemoveExtraHTTPHostPort(session.request)
// hijack api // hijack api
if session.request.URL.Host == opt.ApiHost { if session.request.URL.Hostname() == opt.ApiHost {
if err = handleApiRequest(session, opt); err != nil { if err = handleApiRequest(session, opt); err != nil {
handleError(opt, session, err) handleError(opt, session, err)
break readLoop break readLoop
@ -162,7 +152,6 @@ readLoop:
handleError(opt, session, err) handleError(opt, session, err)
session.response = session.NewErrorResponse(err) session.response = session.NewErrorResponse(err)
if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") {
// TODO block unsupported host?
_ = writeResponse(session, false) _ = writeResponse(session, false)
break readLoop break readLoop
} }
@ -287,9 +276,6 @@ func prepareRequest(conn net.Conn, request *http.Request) {
if request.Header.Get("Accept-Encoding") != "" { if request.Header.Get("Accept-Encoding") != "" {
request.Header.Set("Accept-Encoding", "gzip") request.Header.Set("Accept-Encoding", "gzip")
} }
H.RemoveHopByHopHeaders(request.Header)
H.RemoveExtraHTTPHostPort(request)
} }
func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr { func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr {

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"time" "time"
@ -21,15 +20,6 @@ var (
ErrInvalidURL = errors.New("invalid URL") ErrInvalidURL = errors.New("invalid URL")
) )
type multiReaderConn struct {
net.Conn
reader io.Reader
}
func (c *multiReaderConn) Read(buf []byte) (int, error) {
return c.reader.Read(buf)
}
func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { func NewResponse(code int, body io.Reader, req *http.Request) *http.Response {
if body == nil { if body == nil {
body = &bytes.Buffer{} body = &bytes.Buffer{}