From 24ce6622a2960cdc6f84e4dde2ae0a74570415a4 Mon Sep 17 00:00:00 2001 From: fishg <1423545+fishg@users.noreply.github.com> Date: Wed, 30 Mar 2022 23:54:52 +0800 Subject: [PATCH 01/14] Feature: add tls SNI sniffing (#68) --- common/snifer/tls/sniff.go | 148 +++++++++++++++++++++++++++++ common/snifer/tls/sniff_test.go | 159 ++++++++++++++++++++++++++++++++ component/resolver/enhancer.go | 7 ++ dns/enhancer.go | 6 ++ tunnel/statistic/tracker.go | 17 ++++ 5 files changed, 337 insertions(+) create mode 100644 common/snifer/tls/sniff.go create mode 100644 common/snifer/tls/sniff_test.go diff --git a/common/snifer/tls/sniff.go b/common/snifer/tls/sniff.go new file mode 100644 index 00000000..1471fc68 --- /dev/null +++ b/common/snifer/tls/sniff.go @@ -0,0 +1,148 @@ +package tls + +import ( + "encoding/binary" + "errors" + "strings" +) + +var ErrNoClue = errors.New("not enough information for making a decision") + +type SniffHeader struct { + domain string +} + +func (h *SniffHeader) Protocol() string { + return "tls" +} + +func (h *SniffHeader) Domain() string { + return h.domain +} + +var ( + errNotTLS = errors.New("not TLS header") + errNotClientHello = errors.New("not client hello") +) + +func IsValidTLSVersion(major, minor byte) bool { + return major == 3 +} + +// ReadClientHello returns server name (if any) from TLS client hello message. +// https://github.com/golang/go/blob/master/src/crypto/tls/handshake_messages.go#L300 +func ReadClientHello(data []byte, h *SniffHeader) error { + if len(data) < 42 { + return ErrNoClue + } + sessionIDLen := int(data[38]) + if sessionIDLen > 32 || len(data) < 39+sessionIDLen { + return ErrNoClue + } + data = data[39+sessionIDLen:] + if len(data) < 2 { + return ErrNoClue + } + // cipherSuiteLen is the number of bytes of cipher suite numbers. Since + // they are uint16s, the number must be even. + cipherSuiteLen := int(data[0])<<8 | int(data[1]) + if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { + return errNotClientHello + } + data = data[2+cipherSuiteLen:] + if len(data) < 1 { + return ErrNoClue + } + compressionMethodsLen := int(data[0]) + if len(data) < 1+compressionMethodsLen { + return ErrNoClue + } + data = data[1+compressionMethodsLen:] + + if len(data) == 0 { + return errNotClientHello + } + if len(data) < 2 { + return errNotClientHello + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if extensionsLength != len(data) { + return errNotClientHello + } + + for len(data) != 0 { + if len(data) < 4 { + return errNotClientHello + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return errNotClientHello + } + + if extension == 0x00 { /* extensionServerName */ + d := data[:length] + if len(d) < 2 { + return errNotClientHello + } + namesLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != namesLen { + return errNotClientHello + } + for len(d) > 0 { + if len(d) < 3 { + return errNotClientHello + } + nameType := d[0] + nameLen := int(d[1])<<8 | int(d[2]) + d = d[3:] + if len(d) < nameLen { + return errNotClientHello + } + if nameType == 0 { + serverName := string(d[:nameLen]) + // An SNI value may not include a + // trailing dot. See + // https://tools.ietf.org/html/rfc6066#section-3. + if strings.HasSuffix(serverName, ".") { + return errNotClientHello + } + h.domain = serverName + return nil + } + d = d[nameLen:] + } + } + data = data[length:] + } + + return errNotTLS +} + +func SniffTLS(b []byte) (*SniffHeader, error) { + if len(b) < 5 { + return nil, ErrNoClue + } + + if b[0] != 0x16 /* TLS Handshake */ { + return nil, errNotTLS + } + if !IsValidTLSVersion(b[1], b[2]) { + return nil, errNotTLS + } + headerLen := int(binary.BigEndian.Uint16(b[3:5])) + if 5+headerLen > len(b) { + return nil, ErrNoClue + } + + h := &SniffHeader{} + err := ReadClientHello(b[5:5+headerLen], h) + if err == nil { + return h, nil + } + return nil, err +} diff --git a/common/snifer/tls/sniff_test.go b/common/snifer/tls/sniff_test.go new file mode 100644 index 00000000..26f5f1ee --- /dev/null +++ b/common/snifer/tls/sniff_test.go @@ -0,0 +1,159 @@ +package tls + +import ( + "testing" +) + +func TestTLSHeaders(t *testing.T) { + cases := []struct { + input []byte + domain string + err bool + }{ + { + input: []byte{ + 0x16, 0x03, 0x01, 0x00, 0xc8, 0x01, 0x00, 0x00, + 0xc4, 0x03, 0x03, 0x1a, 0xac, 0xb2, 0xa8, 0xfe, + 0xb4, 0x96, 0x04, 0x5b, 0xca, 0xf7, 0xc1, 0xf4, + 0x2e, 0x53, 0x24, 0x6e, 0x34, 0x0c, 0x58, 0x36, + 0x71, 0x97, 0x59, 0xe9, 0x41, 0x66, 0xe2, 0x43, + 0xa0, 0x13, 0xb6, 0x00, 0x00, 0x20, 0x1a, 0x1a, + 0xc0, 0x2b, 0xc0, 0x2f, 0xc0, 0x2c, 0xc0, 0x30, + 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0x14, 0xcc, 0x13, + 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x9c, 0x00, 0x9d, + 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, 0x01, 0x00, + 0x00, 0x7b, 0xba, 0xba, 0x00, 0x00, 0xff, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x16, 0x00, + 0x14, 0x00, 0x00, 0x11, 0x63, 0x2e, 0x73, 0x2d, + 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, + 0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x00, 0x17, 0x00, + 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x0d, 0x00, + 0x14, 0x00, 0x12, 0x04, 0x03, 0x08, 0x04, 0x04, + 0x01, 0x05, 0x03, 0x08, 0x05, 0x05, 0x01, 0x08, + 0x06, 0x06, 0x01, 0x02, 0x01, 0x00, 0x05, 0x00, + 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, + 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, + 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, + 0x2f, 0x31, 0x2e, 0x31, 0x00, 0x0b, 0x00, 0x02, + 0x01, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, + 0xaa, 0xaa, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, + 0xaa, 0xaa, 0x00, 0x01, 0x00, + }, + domain: "c.s-microsoft.com", + err: false, + }, + { + input: []byte{ + 0x16, 0x03, 0x01, 0x00, 0xee, 0x01, 0x00, 0x00, + 0xea, 0x03, 0x03, 0xe7, 0x91, 0x9e, 0x93, 0xca, + 0x78, 0x1b, 0x3c, 0xe0, 0x65, 0x25, 0x58, 0xb5, + 0x93, 0xe1, 0x0f, 0x85, 0xec, 0x9a, 0x66, 0x8e, + 0x61, 0x82, 0x88, 0xc8, 0xfc, 0xae, 0x1e, 0xca, + 0xd7, 0xa5, 0x63, 0x20, 0xbd, 0x1c, 0x00, 0x00, + 0x8b, 0xee, 0x09, 0xe3, 0x47, 0x6a, 0x0e, 0x74, + 0xb0, 0xbc, 0xa3, 0x02, 0xa7, 0x35, 0xe8, 0x85, + 0x70, 0x7c, 0x7a, 0xf0, 0x00, 0xdf, 0x4a, 0xea, + 0x87, 0x01, 0x14, 0x91, 0x00, 0x20, 0xea, 0xea, + 0xc0, 0x2b, 0xc0, 0x2f, 0xc0, 0x2c, 0xc0, 0x30, + 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0x14, 0xcc, 0x13, + 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x9c, 0x00, 0x9d, + 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, 0x01, 0x00, + 0x00, 0x81, 0x9a, 0x9a, 0x00, 0x00, 0xff, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, + 0x16, 0x00, 0x00, 0x13, 0x77, 0x77, 0x77, 0x30, + 0x37, 0x2e, 0x63, 0x6c, 0x69, 0x63, 0x6b, 0x74, + 0x61, 0x6c, 0x65, 0x2e, 0x6e, 0x65, 0x74, 0x00, + 0x17, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x14, 0x00, 0x12, 0x04, 0x03, 0x08, + 0x04, 0x04, 0x01, 0x05, 0x03, 0x08, 0x05, 0x05, + 0x01, 0x08, 0x06, 0x06, 0x01, 0x02, 0x01, 0x00, + 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x12, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, + 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, + 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31, 0x75, 0x50, + 0x00, 0x00, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, + 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x9a, 0x9a, + 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x8a, 0x8a, + 0x00, 0x01, 0x00, + }, + domain: "www07.clicktale.net", + err: false, + }, + { + input: []byte{ + 0x16, 0x03, 0x01, 0x00, 0xe6, 0x01, 0x00, 0x00, 0xe2, 0x03, 0x03, 0x81, 0x47, 0xc1, + 0x66, 0xd5, 0x1b, 0xfa, 0x4b, 0xb5, 0xe0, 0x2a, 0xe1, 0xa7, 0x87, 0x13, 0x1d, 0x11, 0xaa, 0xc6, + 0xce, 0xfc, 0x7f, 0xab, 0x94, 0xc8, 0x62, 0xad, 0xc8, 0xab, 0x0c, 0xdd, 0xcb, 0x20, 0x6f, 0x9d, + 0x07, 0xf1, 0x95, 0x3e, 0x99, 0xd8, 0xf3, 0x6d, 0x97, 0xee, 0x19, 0x0b, 0x06, 0x1b, 0xf4, 0x84, + 0x0b, 0xb6, 0x8f, 0xcc, 0xde, 0xe2, 0xd0, 0x2d, 0x6b, 0x0c, 0x1f, 0x52, 0x53, 0x13, 0x00, 0x08, + 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, 0x00, 0xff, 0x01, 0x00, 0x00, 0x91, 0x00, 0x00, 0x00, 0x0c, + 0x00, 0x0a, 0x00, 0x00, 0x07, 0x64, 0x6f, 0x67, 0x66, 0x69, 0x73, 0x68, 0x00, 0x0b, 0x00, 0x04, + 0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x0a, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, + 0x00, 0x19, 0x00, 0x18, 0x00, 0x23, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, + 0x00, 0x0d, 0x00, 0x1e, 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, + 0x08, 0x09, 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, + 0x06, 0x01, 0x00, 0x2b, 0x00, 0x07, 0x06, 0x7f, 0x1c, 0x7f, 0x1b, 0x7f, 0x1a, 0x00, 0x2d, 0x00, + 0x02, 0x01, 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x2f, 0x35, 0x0c, + 0xb6, 0x90, 0x0a, 0xb7, 0xd5, 0xc4, 0x1b, 0x2f, 0x60, 0xaa, 0x56, 0x7b, 0x3f, 0x71, 0xc8, 0x01, + 0x7e, 0x86, 0xd3, 0xb7, 0x0c, 0x29, 0x1a, 0x9e, 0x5b, 0x38, 0x3f, 0x01, 0x72, + }, + domain: "dogfish", + err: false, + }, + { + input: []byte{ + 0x16, 0x03, 0x01, 0x01, 0x03, 0x01, 0x00, 0x00, + 0xff, 0x03, 0x03, 0x3d, 0x89, 0x52, 0x9e, 0xee, + 0xbe, 0x17, 0x63, 0x75, 0xef, 0x29, 0xbd, 0x14, + 0x6a, 0x49, 0xe0, 0x2c, 0x37, 0x57, 0x71, 0x62, + 0x82, 0x44, 0x94, 0x8f, 0x6e, 0x94, 0x08, 0x45, + 0x7f, 0xdb, 0xc1, 0x00, 0x00, 0x3e, 0xc0, 0x2c, + 0xc0, 0x30, 0x00, 0x9f, 0xcc, 0xa9, 0xcc, 0xa8, + 0xcc, 0xaa, 0xc0, 0x2b, 0xc0, 0x2f, 0x00, 0x9e, + 0xc0, 0x24, 0xc0, 0x28, 0x00, 0x6b, 0xc0, 0x23, + 0xc0, 0x27, 0x00, 0x67, 0xc0, 0x0a, 0xc0, 0x14, + 0x00, 0x39, 0xc0, 0x09, 0xc0, 0x13, 0x00, 0x33, + 0x00, 0x9d, 0x00, 0x9c, 0x13, 0x02, 0x13, 0x03, + 0x13, 0x01, 0x00, 0x3d, 0x00, 0x3c, 0x00, 0x35, + 0x00, 0x2f, 0x00, 0xff, 0x01, 0x00, 0x00, 0x98, + 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, + 0x0b, 0x31, 0x30, 0x2e, 0x34, 0x32, 0x2e, 0x30, + 0x2e, 0x32, 0x34, 0x33, 0x00, 0x0b, 0x00, 0x04, + 0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x0a, + 0x00, 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x19, + 0x00, 0x18, 0x00, 0x23, 0x00, 0x00, 0x00, 0x0d, + 0x00, 0x20, 0x00, 0x1e, 0x04, 0x03, 0x05, 0x03, + 0x06, 0x03, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, + 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x03, + 0x02, 0x01, 0x02, 0x02, 0x04, 0x02, 0x05, 0x02, + 0x06, 0x02, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, + 0x00, 0x00, 0x00, 0x2b, 0x00, 0x09, 0x08, 0x7f, + 0x14, 0x03, 0x03, 0x03, 0x02, 0x03, 0x01, 0x00, + 0x2d, 0x00, 0x03, 0x02, 0x01, 0x00, 0x00, 0x28, + 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, + 0x13, 0x7c, 0x6e, 0x97, 0xc4, 0xfd, 0x09, 0x2e, + 0x70, 0x2f, 0x73, 0x5a, 0x9b, 0x57, 0x4d, 0x5f, + 0x2b, 0x73, 0x2c, 0xa5, 0x4a, 0x98, 0x40, 0x3d, + 0x75, 0x6e, 0xb4, 0x76, 0xf9, 0x48, 0x8f, 0x36, + }, + domain: "10.42.0.243", + err: false, + }, + } + + for _, test := range cases { + header, err := SniffTLS(test.input) + if test.err { + if err == nil { + t.Errorf("Exepct error but nil in test %v", test) + } + } else { + if err != nil { + t.Errorf("Expect no error but actually %s in test %v", err.Error(), test) + } + if header.Domain() != test.domain { + t.Error("expect domain ", test.domain, " but got ", header.Domain()) + } + } + } +} diff --git a/component/resolver/enhancer.go b/component/resolver/enhancer.go index 9df3f54b..77f18374 100644 --- a/component/resolver/enhancer.go +++ b/component/resolver/enhancer.go @@ -14,6 +14,7 @@ type Enhancer interface { IsExistFakeIP(net.IP) bool FindHostByIP(net.IP) (string, bool) FlushFakeIP() error + InsertHostByIP(net.IP, string) } func FakeIPEnabled() bool { @@ -56,6 +57,12 @@ func IsExistFakeIP(ip net.IP) bool { return false } +func InsertHostByIP(ip net.IP, host string) { + if mapper := DefaultHostMapper; mapper != nil { + mapper.InsertHostByIP(ip, host) + } +} + func FindHostByIP(ip net.IP) (string, bool) { if mapper := DefaultHostMapper; mapper != nil { return mapper.FindHostByIP(ip) diff --git a/dns/enhancer.go b/dns/enhancer.go index 9bf568c7..016ff02a 100644 --- a/dns/enhancer.go +++ b/dns/enhancer.go @@ -74,6 +74,12 @@ func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) { return "", false } +func (h *ResolverEnhancer) InsertHostByIP(ip net.IP, host string) { + if mapping := h.mapping; mapping != nil { + h.mapping.Set(ip.String(), host) + } +} + func (h *ResolverEnhancer) PatchFrom(o *ResolverEnhancer) { if h.mapping != nil && o.mapping != nil { o.mapping.CloneTo(h.mapping) diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 1f5f1f9c..f213ca61 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -1,10 +1,14 @@ package statistic import ( + "errors" "net" "time" + "github.com/Dreamacro/clash/common/snifer/tls" + "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" "github.com/gofrs/uuid" "go.uber.org/atomic" @@ -48,7 +52,20 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { n, err := tt.Conn.Write(b) upload := int64(n) tt.manager.PushUploaded(upload) + if tt.UploadTotal.Load() < 128 && tt.Metadata.Host == "" && (tt.Metadata.DstPort == "443" || tt.Metadata.DstPort == "8443") { + header, err := tls.SniffTLS(b) + if err != nil { + // log.Errorln("Expect no error but actually %s %s:%s:%s", err.Error(), tt.Metadata.Host, tt.Metadata.DstIP.String(), tt.Metadata.DstPort) + } else { + resolver.InsertHostByIP(tt.Metadata.DstIP, header.Domain()) + log.Warnln("use sni update host: %s ip: %s", header.Domain(), tt.Metadata.DstIP.String()) + tt.manager.Leave(tt) + tt.Conn.Close() + return n, errors.New("sni update, break current link to avoid leaks") + } + } tt.UploadTotal.Add(upload) + return n, err } From e877b68179a735e5743a732b04e0de1dba238267 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Thu, 31 Mar 2022 21:20:46 +0800 Subject: [PATCH 02/14] Chore: revert "Feature: add tls SNI sniffing (#68)" This reverts commit 24ce6622a2960cdc6f84e4dde2ae0a74570415a4. --- common/snifer/tls/sniff.go | 148 ----------------------------- common/snifer/tls/sniff_test.go | 159 -------------------------------- component/resolver/enhancer.go | 7 -- dns/enhancer.go | 6 -- tunnel/statistic/tracker.go | 17 ---- 5 files changed, 337 deletions(-) delete mode 100644 common/snifer/tls/sniff.go delete mode 100644 common/snifer/tls/sniff_test.go diff --git a/common/snifer/tls/sniff.go b/common/snifer/tls/sniff.go deleted file mode 100644 index 1471fc68..00000000 --- a/common/snifer/tls/sniff.go +++ /dev/null @@ -1,148 +0,0 @@ -package tls - -import ( - "encoding/binary" - "errors" - "strings" -) - -var ErrNoClue = errors.New("not enough information for making a decision") - -type SniffHeader struct { - domain string -} - -func (h *SniffHeader) Protocol() string { - return "tls" -} - -func (h *SniffHeader) Domain() string { - return h.domain -} - -var ( - errNotTLS = errors.New("not TLS header") - errNotClientHello = errors.New("not client hello") -) - -func IsValidTLSVersion(major, minor byte) bool { - return major == 3 -} - -// ReadClientHello returns server name (if any) from TLS client hello message. -// https://github.com/golang/go/blob/master/src/crypto/tls/handshake_messages.go#L300 -func ReadClientHello(data []byte, h *SniffHeader) error { - if len(data) < 42 { - return ErrNoClue - } - sessionIDLen := int(data[38]) - if sessionIDLen > 32 || len(data) < 39+sessionIDLen { - return ErrNoClue - } - data = data[39+sessionIDLen:] - if len(data) < 2 { - return ErrNoClue - } - // cipherSuiteLen is the number of bytes of cipher suite numbers. Since - // they are uint16s, the number must be even. - cipherSuiteLen := int(data[0])<<8 | int(data[1]) - if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { - return errNotClientHello - } - data = data[2+cipherSuiteLen:] - if len(data) < 1 { - return ErrNoClue - } - compressionMethodsLen := int(data[0]) - if len(data) < 1+compressionMethodsLen { - return ErrNoClue - } - data = data[1+compressionMethodsLen:] - - if len(data) == 0 { - return errNotClientHello - } - if len(data) < 2 { - return errNotClientHello - } - - extensionsLength := int(data[0])<<8 | int(data[1]) - data = data[2:] - if extensionsLength != len(data) { - return errNotClientHello - } - - for len(data) != 0 { - if len(data) < 4 { - return errNotClientHello - } - extension := uint16(data[0])<<8 | uint16(data[1]) - length := int(data[2])<<8 | int(data[3]) - data = data[4:] - if len(data) < length { - return errNotClientHello - } - - if extension == 0x00 { /* extensionServerName */ - d := data[:length] - if len(d) < 2 { - return errNotClientHello - } - namesLen := int(d[0])<<8 | int(d[1]) - d = d[2:] - if len(d) != namesLen { - return errNotClientHello - } - for len(d) > 0 { - if len(d) < 3 { - return errNotClientHello - } - nameType := d[0] - nameLen := int(d[1])<<8 | int(d[2]) - d = d[3:] - if len(d) < nameLen { - return errNotClientHello - } - if nameType == 0 { - serverName := string(d[:nameLen]) - // An SNI value may not include a - // trailing dot. See - // https://tools.ietf.org/html/rfc6066#section-3. - if strings.HasSuffix(serverName, ".") { - return errNotClientHello - } - h.domain = serverName - return nil - } - d = d[nameLen:] - } - } - data = data[length:] - } - - return errNotTLS -} - -func SniffTLS(b []byte) (*SniffHeader, error) { - if len(b) < 5 { - return nil, ErrNoClue - } - - if b[0] != 0x16 /* TLS Handshake */ { - return nil, errNotTLS - } - if !IsValidTLSVersion(b[1], b[2]) { - return nil, errNotTLS - } - headerLen := int(binary.BigEndian.Uint16(b[3:5])) - if 5+headerLen > len(b) { - return nil, ErrNoClue - } - - h := &SniffHeader{} - err := ReadClientHello(b[5:5+headerLen], h) - if err == nil { - return h, nil - } - return nil, err -} diff --git a/common/snifer/tls/sniff_test.go b/common/snifer/tls/sniff_test.go deleted file mode 100644 index 26f5f1ee..00000000 --- a/common/snifer/tls/sniff_test.go +++ /dev/null @@ -1,159 +0,0 @@ -package tls - -import ( - "testing" -) - -func TestTLSHeaders(t *testing.T) { - cases := []struct { - input []byte - domain string - err bool - }{ - { - input: []byte{ - 0x16, 0x03, 0x01, 0x00, 0xc8, 0x01, 0x00, 0x00, - 0xc4, 0x03, 0x03, 0x1a, 0xac, 0xb2, 0xa8, 0xfe, - 0xb4, 0x96, 0x04, 0x5b, 0xca, 0xf7, 0xc1, 0xf4, - 0x2e, 0x53, 0x24, 0x6e, 0x34, 0x0c, 0x58, 0x36, - 0x71, 0x97, 0x59, 0xe9, 0x41, 0x66, 0xe2, 0x43, - 0xa0, 0x13, 0xb6, 0x00, 0x00, 0x20, 0x1a, 0x1a, - 0xc0, 0x2b, 0xc0, 0x2f, 0xc0, 0x2c, 0xc0, 0x30, - 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0x14, 0xcc, 0x13, - 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x9c, 0x00, 0x9d, - 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, 0x01, 0x00, - 0x00, 0x7b, 0xba, 0xba, 0x00, 0x00, 0xff, 0x01, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x16, 0x00, - 0x14, 0x00, 0x00, 0x11, 0x63, 0x2e, 0x73, 0x2d, - 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, - 0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x00, 0x17, 0x00, - 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x0d, 0x00, - 0x14, 0x00, 0x12, 0x04, 0x03, 0x08, 0x04, 0x04, - 0x01, 0x05, 0x03, 0x08, 0x05, 0x05, 0x01, 0x08, - 0x06, 0x06, 0x01, 0x02, 0x01, 0x00, 0x05, 0x00, - 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, - 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, - 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, - 0x2f, 0x31, 0x2e, 0x31, 0x00, 0x0b, 0x00, 0x02, - 0x01, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, - 0xaa, 0xaa, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, - 0xaa, 0xaa, 0x00, 0x01, 0x00, - }, - domain: "c.s-microsoft.com", - err: false, - }, - { - input: []byte{ - 0x16, 0x03, 0x01, 0x00, 0xee, 0x01, 0x00, 0x00, - 0xea, 0x03, 0x03, 0xe7, 0x91, 0x9e, 0x93, 0xca, - 0x78, 0x1b, 0x3c, 0xe0, 0x65, 0x25, 0x58, 0xb5, - 0x93, 0xe1, 0x0f, 0x85, 0xec, 0x9a, 0x66, 0x8e, - 0x61, 0x82, 0x88, 0xc8, 0xfc, 0xae, 0x1e, 0xca, - 0xd7, 0xa5, 0x63, 0x20, 0xbd, 0x1c, 0x00, 0x00, - 0x8b, 0xee, 0x09, 0xe3, 0x47, 0x6a, 0x0e, 0x74, - 0xb0, 0xbc, 0xa3, 0x02, 0xa7, 0x35, 0xe8, 0x85, - 0x70, 0x7c, 0x7a, 0xf0, 0x00, 0xdf, 0x4a, 0xea, - 0x87, 0x01, 0x14, 0x91, 0x00, 0x20, 0xea, 0xea, - 0xc0, 0x2b, 0xc0, 0x2f, 0xc0, 0x2c, 0xc0, 0x30, - 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0x14, 0xcc, 0x13, - 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x9c, 0x00, 0x9d, - 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, 0x01, 0x00, - 0x00, 0x81, 0x9a, 0x9a, 0x00, 0x00, 0xff, 0x01, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, - 0x16, 0x00, 0x00, 0x13, 0x77, 0x77, 0x77, 0x30, - 0x37, 0x2e, 0x63, 0x6c, 0x69, 0x63, 0x6b, 0x74, - 0x61, 0x6c, 0x65, 0x2e, 0x6e, 0x65, 0x74, 0x00, - 0x17, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, - 0x0d, 0x00, 0x14, 0x00, 0x12, 0x04, 0x03, 0x08, - 0x04, 0x04, 0x01, 0x05, 0x03, 0x08, 0x05, 0x05, - 0x01, 0x08, 0x06, 0x06, 0x01, 0x02, 0x01, 0x00, - 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x12, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, - 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, - 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31, 0x75, 0x50, - 0x00, 0x00, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, - 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x9a, 0x9a, - 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x8a, 0x8a, - 0x00, 0x01, 0x00, - }, - domain: "www07.clicktale.net", - err: false, - }, - { - input: []byte{ - 0x16, 0x03, 0x01, 0x00, 0xe6, 0x01, 0x00, 0x00, 0xe2, 0x03, 0x03, 0x81, 0x47, 0xc1, - 0x66, 0xd5, 0x1b, 0xfa, 0x4b, 0xb5, 0xe0, 0x2a, 0xe1, 0xa7, 0x87, 0x13, 0x1d, 0x11, 0xaa, 0xc6, - 0xce, 0xfc, 0x7f, 0xab, 0x94, 0xc8, 0x62, 0xad, 0xc8, 0xab, 0x0c, 0xdd, 0xcb, 0x20, 0x6f, 0x9d, - 0x07, 0xf1, 0x95, 0x3e, 0x99, 0xd8, 0xf3, 0x6d, 0x97, 0xee, 0x19, 0x0b, 0x06, 0x1b, 0xf4, 0x84, - 0x0b, 0xb6, 0x8f, 0xcc, 0xde, 0xe2, 0xd0, 0x2d, 0x6b, 0x0c, 0x1f, 0x52, 0x53, 0x13, 0x00, 0x08, - 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, 0x00, 0xff, 0x01, 0x00, 0x00, 0x91, 0x00, 0x00, 0x00, 0x0c, - 0x00, 0x0a, 0x00, 0x00, 0x07, 0x64, 0x6f, 0x67, 0x66, 0x69, 0x73, 0x68, 0x00, 0x0b, 0x00, 0x04, - 0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x0a, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, - 0x00, 0x19, 0x00, 0x18, 0x00, 0x23, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, - 0x00, 0x0d, 0x00, 0x1e, 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, - 0x08, 0x09, 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, - 0x06, 0x01, 0x00, 0x2b, 0x00, 0x07, 0x06, 0x7f, 0x1c, 0x7f, 0x1b, 0x7f, 0x1a, 0x00, 0x2d, 0x00, - 0x02, 0x01, 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x2f, 0x35, 0x0c, - 0xb6, 0x90, 0x0a, 0xb7, 0xd5, 0xc4, 0x1b, 0x2f, 0x60, 0xaa, 0x56, 0x7b, 0x3f, 0x71, 0xc8, 0x01, - 0x7e, 0x86, 0xd3, 0xb7, 0x0c, 0x29, 0x1a, 0x9e, 0x5b, 0x38, 0x3f, 0x01, 0x72, - }, - domain: "dogfish", - err: false, - }, - { - input: []byte{ - 0x16, 0x03, 0x01, 0x01, 0x03, 0x01, 0x00, 0x00, - 0xff, 0x03, 0x03, 0x3d, 0x89, 0x52, 0x9e, 0xee, - 0xbe, 0x17, 0x63, 0x75, 0xef, 0x29, 0xbd, 0x14, - 0x6a, 0x49, 0xe0, 0x2c, 0x37, 0x57, 0x71, 0x62, - 0x82, 0x44, 0x94, 0x8f, 0x6e, 0x94, 0x08, 0x45, - 0x7f, 0xdb, 0xc1, 0x00, 0x00, 0x3e, 0xc0, 0x2c, - 0xc0, 0x30, 0x00, 0x9f, 0xcc, 0xa9, 0xcc, 0xa8, - 0xcc, 0xaa, 0xc0, 0x2b, 0xc0, 0x2f, 0x00, 0x9e, - 0xc0, 0x24, 0xc0, 0x28, 0x00, 0x6b, 0xc0, 0x23, - 0xc0, 0x27, 0x00, 0x67, 0xc0, 0x0a, 0xc0, 0x14, - 0x00, 0x39, 0xc0, 0x09, 0xc0, 0x13, 0x00, 0x33, - 0x00, 0x9d, 0x00, 0x9c, 0x13, 0x02, 0x13, 0x03, - 0x13, 0x01, 0x00, 0x3d, 0x00, 0x3c, 0x00, 0x35, - 0x00, 0x2f, 0x00, 0xff, 0x01, 0x00, 0x00, 0x98, - 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, - 0x0b, 0x31, 0x30, 0x2e, 0x34, 0x32, 0x2e, 0x30, - 0x2e, 0x32, 0x34, 0x33, 0x00, 0x0b, 0x00, 0x04, - 0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x0a, - 0x00, 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x19, - 0x00, 0x18, 0x00, 0x23, 0x00, 0x00, 0x00, 0x0d, - 0x00, 0x20, 0x00, 0x1e, 0x04, 0x03, 0x05, 0x03, - 0x06, 0x03, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, - 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x03, - 0x02, 0x01, 0x02, 0x02, 0x04, 0x02, 0x05, 0x02, - 0x06, 0x02, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, - 0x00, 0x00, 0x00, 0x2b, 0x00, 0x09, 0x08, 0x7f, - 0x14, 0x03, 0x03, 0x03, 0x02, 0x03, 0x01, 0x00, - 0x2d, 0x00, 0x03, 0x02, 0x01, 0x00, 0x00, 0x28, - 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, - 0x13, 0x7c, 0x6e, 0x97, 0xc4, 0xfd, 0x09, 0x2e, - 0x70, 0x2f, 0x73, 0x5a, 0x9b, 0x57, 0x4d, 0x5f, - 0x2b, 0x73, 0x2c, 0xa5, 0x4a, 0x98, 0x40, 0x3d, - 0x75, 0x6e, 0xb4, 0x76, 0xf9, 0x48, 0x8f, 0x36, - }, - domain: "10.42.0.243", - err: false, - }, - } - - for _, test := range cases { - header, err := SniffTLS(test.input) - if test.err { - if err == nil { - t.Errorf("Exepct error but nil in test %v", test) - } - } else { - if err != nil { - t.Errorf("Expect no error but actually %s in test %v", err.Error(), test) - } - if header.Domain() != test.domain { - t.Error("expect domain ", test.domain, " but got ", header.Domain()) - } - } - } -} diff --git a/component/resolver/enhancer.go b/component/resolver/enhancer.go index 77f18374..9df3f54b 100644 --- a/component/resolver/enhancer.go +++ b/component/resolver/enhancer.go @@ -14,7 +14,6 @@ type Enhancer interface { IsExistFakeIP(net.IP) bool FindHostByIP(net.IP) (string, bool) FlushFakeIP() error - InsertHostByIP(net.IP, string) } func FakeIPEnabled() bool { @@ -57,12 +56,6 @@ func IsExistFakeIP(ip net.IP) bool { return false } -func InsertHostByIP(ip net.IP, host string) { - if mapper := DefaultHostMapper; mapper != nil { - mapper.InsertHostByIP(ip, host) - } -} - func FindHostByIP(ip net.IP) (string, bool) { if mapper := DefaultHostMapper; mapper != nil { return mapper.FindHostByIP(ip) diff --git a/dns/enhancer.go b/dns/enhancer.go index 016ff02a..9bf568c7 100644 --- a/dns/enhancer.go +++ b/dns/enhancer.go @@ -74,12 +74,6 @@ func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) { return "", false } -func (h *ResolverEnhancer) InsertHostByIP(ip net.IP, host string) { - if mapping := h.mapping; mapping != nil { - h.mapping.Set(ip.String(), host) - } -} - func (h *ResolverEnhancer) PatchFrom(o *ResolverEnhancer) { if h.mapping != nil && o.mapping != nil { o.mapping.CloneTo(h.mapping) diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index f213ca61..1f5f1f9c 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -1,14 +1,10 @@ package statistic import ( - "errors" "net" "time" - "github.com/Dreamacro/clash/common/snifer/tls" - "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" - "github.com/Dreamacro/clash/log" "github.com/gofrs/uuid" "go.uber.org/atomic" @@ -52,20 +48,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { n, err := tt.Conn.Write(b) upload := int64(n) tt.manager.PushUploaded(upload) - if tt.UploadTotal.Load() < 128 && tt.Metadata.Host == "" && (tt.Metadata.DstPort == "443" || tt.Metadata.DstPort == "8443") { - header, err := tls.SniffTLS(b) - if err != nil { - // log.Errorln("Expect no error but actually %s %s:%s:%s", err.Error(), tt.Metadata.Host, tt.Metadata.DstIP.String(), tt.Metadata.DstPort) - } else { - resolver.InsertHostByIP(tt.Metadata.DstIP, header.Domain()) - log.Warnln("use sni update host: %s ip: %s", header.Domain(), tt.Metadata.DstIP.String()) - tt.manager.Leave(tt) - tt.Conn.Close() - return n, errors.New("sni update, break current link to avoid leaks") - } - } tt.UploadTotal.Add(upload) - return n, err } From c495d314d4d94db15153cf1f558db35b0d5cc9a6 Mon Sep 17 00:00:00 2001 From: fishg <1423545+fishg@users.noreply.github.com> Date: Wed, 30 Mar 2022 23:54:52 +0800 Subject: [PATCH 03/14] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0tls=20sni=20?= =?UTF-8?q?=E5=97=85=E6=8E=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Conflicts: # tunnel/statistic/tracker.go # tunnel/tunnel.go --- common/snifer/tls/sniff.go | 148 +++++++++++++++++++++++++++++ common/snifer/tls/sniff_test.go | 159 ++++++++++++++++++++++++++++++++ component/resolver/enhancer.go | 7 ++ dns/enhancer.go | 6 ++ tunnel/statistic/tracker.go | 18 ++++ 5 files changed, 338 insertions(+) create mode 100644 common/snifer/tls/sniff.go create mode 100644 common/snifer/tls/sniff_test.go diff --git a/common/snifer/tls/sniff.go b/common/snifer/tls/sniff.go new file mode 100644 index 00000000..1471fc68 --- /dev/null +++ b/common/snifer/tls/sniff.go @@ -0,0 +1,148 @@ +package tls + +import ( + "encoding/binary" + "errors" + "strings" +) + +var ErrNoClue = errors.New("not enough information for making a decision") + +type SniffHeader struct { + domain string +} + +func (h *SniffHeader) Protocol() string { + return "tls" +} + +func (h *SniffHeader) Domain() string { + return h.domain +} + +var ( + errNotTLS = errors.New("not TLS header") + errNotClientHello = errors.New("not client hello") +) + +func IsValidTLSVersion(major, minor byte) bool { + return major == 3 +} + +// ReadClientHello returns server name (if any) from TLS client hello message. +// https://github.com/golang/go/blob/master/src/crypto/tls/handshake_messages.go#L300 +func ReadClientHello(data []byte, h *SniffHeader) error { + if len(data) < 42 { + return ErrNoClue + } + sessionIDLen := int(data[38]) + if sessionIDLen > 32 || len(data) < 39+sessionIDLen { + return ErrNoClue + } + data = data[39+sessionIDLen:] + if len(data) < 2 { + return ErrNoClue + } + // cipherSuiteLen is the number of bytes of cipher suite numbers. Since + // they are uint16s, the number must be even. + cipherSuiteLen := int(data[0])<<8 | int(data[1]) + if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { + return errNotClientHello + } + data = data[2+cipherSuiteLen:] + if len(data) < 1 { + return ErrNoClue + } + compressionMethodsLen := int(data[0]) + if len(data) < 1+compressionMethodsLen { + return ErrNoClue + } + data = data[1+compressionMethodsLen:] + + if len(data) == 0 { + return errNotClientHello + } + if len(data) < 2 { + return errNotClientHello + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if extensionsLength != len(data) { + return errNotClientHello + } + + for len(data) != 0 { + if len(data) < 4 { + return errNotClientHello + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return errNotClientHello + } + + if extension == 0x00 { /* extensionServerName */ + d := data[:length] + if len(d) < 2 { + return errNotClientHello + } + namesLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != namesLen { + return errNotClientHello + } + for len(d) > 0 { + if len(d) < 3 { + return errNotClientHello + } + nameType := d[0] + nameLen := int(d[1])<<8 | int(d[2]) + d = d[3:] + if len(d) < nameLen { + return errNotClientHello + } + if nameType == 0 { + serverName := string(d[:nameLen]) + // An SNI value may not include a + // trailing dot. See + // https://tools.ietf.org/html/rfc6066#section-3. + if strings.HasSuffix(serverName, ".") { + return errNotClientHello + } + h.domain = serverName + return nil + } + d = d[nameLen:] + } + } + data = data[length:] + } + + return errNotTLS +} + +func SniffTLS(b []byte) (*SniffHeader, error) { + if len(b) < 5 { + return nil, ErrNoClue + } + + if b[0] != 0x16 /* TLS Handshake */ { + return nil, errNotTLS + } + if !IsValidTLSVersion(b[1], b[2]) { + return nil, errNotTLS + } + headerLen := int(binary.BigEndian.Uint16(b[3:5])) + if 5+headerLen > len(b) { + return nil, ErrNoClue + } + + h := &SniffHeader{} + err := ReadClientHello(b[5:5+headerLen], h) + if err == nil { + return h, nil + } + return nil, err +} diff --git a/common/snifer/tls/sniff_test.go b/common/snifer/tls/sniff_test.go new file mode 100644 index 00000000..26f5f1ee --- /dev/null +++ b/common/snifer/tls/sniff_test.go @@ -0,0 +1,159 @@ +package tls + +import ( + "testing" +) + +func TestTLSHeaders(t *testing.T) { + cases := []struct { + input []byte + domain string + err bool + }{ + { + input: []byte{ + 0x16, 0x03, 0x01, 0x00, 0xc8, 0x01, 0x00, 0x00, + 0xc4, 0x03, 0x03, 0x1a, 0xac, 0xb2, 0xa8, 0xfe, + 0xb4, 0x96, 0x04, 0x5b, 0xca, 0xf7, 0xc1, 0xf4, + 0x2e, 0x53, 0x24, 0x6e, 0x34, 0x0c, 0x58, 0x36, + 0x71, 0x97, 0x59, 0xe9, 0x41, 0x66, 0xe2, 0x43, + 0xa0, 0x13, 0xb6, 0x00, 0x00, 0x20, 0x1a, 0x1a, + 0xc0, 0x2b, 0xc0, 0x2f, 0xc0, 0x2c, 0xc0, 0x30, + 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0x14, 0xcc, 0x13, + 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x9c, 0x00, 0x9d, + 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, 0x01, 0x00, + 0x00, 0x7b, 0xba, 0xba, 0x00, 0x00, 0xff, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x16, 0x00, + 0x14, 0x00, 0x00, 0x11, 0x63, 0x2e, 0x73, 0x2d, + 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, + 0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x00, 0x17, 0x00, + 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x0d, 0x00, + 0x14, 0x00, 0x12, 0x04, 0x03, 0x08, 0x04, 0x04, + 0x01, 0x05, 0x03, 0x08, 0x05, 0x05, 0x01, 0x08, + 0x06, 0x06, 0x01, 0x02, 0x01, 0x00, 0x05, 0x00, + 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, + 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, + 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, + 0x2f, 0x31, 0x2e, 0x31, 0x00, 0x0b, 0x00, 0x02, + 0x01, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, + 0xaa, 0xaa, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, + 0xaa, 0xaa, 0x00, 0x01, 0x00, + }, + domain: "c.s-microsoft.com", + err: false, + }, + { + input: []byte{ + 0x16, 0x03, 0x01, 0x00, 0xee, 0x01, 0x00, 0x00, + 0xea, 0x03, 0x03, 0xe7, 0x91, 0x9e, 0x93, 0xca, + 0x78, 0x1b, 0x3c, 0xe0, 0x65, 0x25, 0x58, 0xb5, + 0x93, 0xe1, 0x0f, 0x85, 0xec, 0x9a, 0x66, 0x8e, + 0x61, 0x82, 0x88, 0xc8, 0xfc, 0xae, 0x1e, 0xca, + 0xd7, 0xa5, 0x63, 0x20, 0xbd, 0x1c, 0x00, 0x00, + 0x8b, 0xee, 0x09, 0xe3, 0x47, 0x6a, 0x0e, 0x74, + 0xb0, 0xbc, 0xa3, 0x02, 0xa7, 0x35, 0xe8, 0x85, + 0x70, 0x7c, 0x7a, 0xf0, 0x00, 0xdf, 0x4a, 0xea, + 0x87, 0x01, 0x14, 0x91, 0x00, 0x20, 0xea, 0xea, + 0xc0, 0x2b, 0xc0, 0x2f, 0xc0, 0x2c, 0xc0, 0x30, + 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0x14, 0xcc, 0x13, + 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x9c, 0x00, 0x9d, + 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, 0x01, 0x00, + 0x00, 0x81, 0x9a, 0x9a, 0x00, 0x00, 0xff, 0x01, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, + 0x16, 0x00, 0x00, 0x13, 0x77, 0x77, 0x77, 0x30, + 0x37, 0x2e, 0x63, 0x6c, 0x69, 0x63, 0x6b, 0x74, + 0x61, 0x6c, 0x65, 0x2e, 0x6e, 0x65, 0x74, 0x00, + 0x17, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x14, 0x00, 0x12, 0x04, 0x03, 0x08, + 0x04, 0x04, 0x01, 0x05, 0x03, 0x08, 0x05, 0x05, + 0x01, 0x08, 0x06, 0x06, 0x01, 0x02, 0x01, 0x00, + 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x12, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, + 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, + 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31, 0x75, 0x50, + 0x00, 0x00, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, + 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x9a, 0x9a, + 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x8a, 0x8a, + 0x00, 0x01, 0x00, + }, + domain: "www07.clicktale.net", + err: false, + }, + { + input: []byte{ + 0x16, 0x03, 0x01, 0x00, 0xe6, 0x01, 0x00, 0x00, 0xe2, 0x03, 0x03, 0x81, 0x47, 0xc1, + 0x66, 0xd5, 0x1b, 0xfa, 0x4b, 0xb5, 0xe0, 0x2a, 0xe1, 0xa7, 0x87, 0x13, 0x1d, 0x11, 0xaa, 0xc6, + 0xce, 0xfc, 0x7f, 0xab, 0x94, 0xc8, 0x62, 0xad, 0xc8, 0xab, 0x0c, 0xdd, 0xcb, 0x20, 0x6f, 0x9d, + 0x07, 0xf1, 0x95, 0x3e, 0x99, 0xd8, 0xf3, 0x6d, 0x97, 0xee, 0x19, 0x0b, 0x06, 0x1b, 0xf4, 0x84, + 0x0b, 0xb6, 0x8f, 0xcc, 0xde, 0xe2, 0xd0, 0x2d, 0x6b, 0x0c, 0x1f, 0x52, 0x53, 0x13, 0x00, 0x08, + 0x13, 0x02, 0x13, 0x03, 0x13, 0x01, 0x00, 0xff, 0x01, 0x00, 0x00, 0x91, 0x00, 0x00, 0x00, 0x0c, + 0x00, 0x0a, 0x00, 0x00, 0x07, 0x64, 0x6f, 0x67, 0x66, 0x69, 0x73, 0x68, 0x00, 0x0b, 0x00, 0x04, + 0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x0a, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, + 0x00, 0x19, 0x00, 0x18, 0x00, 0x23, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, + 0x00, 0x0d, 0x00, 0x1e, 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, + 0x08, 0x09, 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, + 0x06, 0x01, 0x00, 0x2b, 0x00, 0x07, 0x06, 0x7f, 0x1c, 0x7f, 0x1b, 0x7f, 0x1a, 0x00, 0x2d, 0x00, + 0x02, 0x01, 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x2f, 0x35, 0x0c, + 0xb6, 0x90, 0x0a, 0xb7, 0xd5, 0xc4, 0x1b, 0x2f, 0x60, 0xaa, 0x56, 0x7b, 0x3f, 0x71, 0xc8, 0x01, + 0x7e, 0x86, 0xd3, 0xb7, 0x0c, 0x29, 0x1a, 0x9e, 0x5b, 0x38, 0x3f, 0x01, 0x72, + }, + domain: "dogfish", + err: false, + }, + { + input: []byte{ + 0x16, 0x03, 0x01, 0x01, 0x03, 0x01, 0x00, 0x00, + 0xff, 0x03, 0x03, 0x3d, 0x89, 0x52, 0x9e, 0xee, + 0xbe, 0x17, 0x63, 0x75, 0xef, 0x29, 0xbd, 0x14, + 0x6a, 0x49, 0xe0, 0x2c, 0x37, 0x57, 0x71, 0x62, + 0x82, 0x44, 0x94, 0x8f, 0x6e, 0x94, 0x08, 0x45, + 0x7f, 0xdb, 0xc1, 0x00, 0x00, 0x3e, 0xc0, 0x2c, + 0xc0, 0x30, 0x00, 0x9f, 0xcc, 0xa9, 0xcc, 0xa8, + 0xcc, 0xaa, 0xc0, 0x2b, 0xc0, 0x2f, 0x00, 0x9e, + 0xc0, 0x24, 0xc0, 0x28, 0x00, 0x6b, 0xc0, 0x23, + 0xc0, 0x27, 0x00, 0x67, 0xc0, 0x0a, 0xc0, 0x14, + 0x00, 0x39, 0xc0, 0x09, 0xc0, 0x13, 0x00, 0x33, + 0x00, 0x9d, 0x00, 0x9c, 0x13, 0x02, 0x13, 0x03, + 0x13, 0x01, 0x00, 0x3d, 0x00, 0x3c, 0x00, 0x35, + 0x00, 0x2f, 0x00, 0xff, 0x01, 0x00, 0x00, 0x98, + 0x00, 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x00, + 0x0b, 0x31, 0x30, 0x2e, 0x34, 0x32, 0x2e, 0x30, + 0x2e, 0x32, 0x34, 0x33, 0x00, 0x0b, 0x00, 0x04, + 0x03, 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x0a, + 0x00, 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x19, + 0x00, 0x18, 0x00, 0x23, 0x00, 0x00, 0x00, 0x0d, + 0x00, 0x20, 0x00, 0x1e, 0x04, 0x03, 0x05, 0x03, + 0x06, 0x03, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, + 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x03, + 0x02, 0x01, 0x02, 0x02, 0x04, 0x02, 0x05, 0x02, + 0x06, 0x02, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, + 0x00, 0x00, 0x00, 0x2b, 0x00, 0x09, 0x08, 0x7f, + 0x14, 0x03, 0x03, 0x03, 0x02, 0x03, 0x01, 0x00, + 0x2d, 0x00, 0x03, 0x02, 0x01, 0x00, 0x00, 0x28, + 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, + 0x13, 0x7c, 0x6e, 0x97, 0xc4, 0xfd, 0x09, 0x2e, + 0x70, 0x2f, 0x73, 0x5a, 0x9b, 0x57, 0x4d, 0x5f, + 0x2b, 0x73, 0x2c, 0xa5, 0x4a, 0x98, 0x40, 0x3d, + 0x75, 0x6e, 0xb4, 0x76, 0xf9, 0x48, 0x8f, 0x36, + }, + domain: "10.42.0.243", + err: false, + }, + } + + for _, test := range cases { + header, err := SniffTLS(test.input) + if test.err { + if err == nil { + t.Errorf("Exepct error but nil in test %v", test) + } + } else { + if err != nil { + t.Errorf("Expect no error but actually %s in test %v", err.Error(), test) + } + if header.Domain() != test.domain { + t.Error("expect domain ", test.domain, " but got ", header.Domain()) + } + } + } +} diff --git a/component/resolver/enhancer.go b/component/resolver/enhancer.go index 9df3f54b..77f18374 100644 --- a/component/resolver/enhancer.go +++ b/component/resolver/enhancer.go @@ -14,6 +14,7 @@ type Enhancer interface { IsExistFakeIP(net.IP) bool FindHostByIP(net.IP) (string, bool) FlushFakeIP() error + InsertHostByIP(net.IP, string) } func FakeIPEnabled() bool { @@ -56,6 +57,12 @@ func IsExistFakeIP(ip net.IP) bool { return false } +func InsertHostByIP(ip net.IP, host string) { + if mapper := DefaultHostMapper; mapper != nil { + mapper.InsertHostByIP(ip, host) + } +} + func FindHostByIP(ip net.IP) (string, bool) { if mapper := DefaultHostMapper; mapper != nil { return mapper.FindHostByIP(ip) diff --git a/dns/enhancer.go b/dns/enhancer.go index 9bf568c7..016ff02a 100644 --- a/dns/enhancer.go +++ b/dns/enhancer.go @@ -74,6 +74,12 @@ func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) { return "", false } +func (h *ResolverEnhancer) InsertHostByIP(ip net.IP, host string) { + if mapping := h.mapping; mapping != nil { + h.mapping.Set(ip.String(), host) + } +} + func (h *ResolverEnhancer) PatchFrom(o *ResolverEnhancer) { if h.mapping != nil && o.mapping != nil { o.mapping.CloneTo(h.mapping) diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 1f5f1f9c..77713da2 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -1,11 +1,15 @@ package statistic import ( + "errors" "net" "time" + "github.com/Dreamacro/clash/common/snifer/tls" + "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" "github.com/gofrs/uuid" "go.uber.org/atomic" ) @@ -48,7 +52,21 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { n, err := tt.Conn.Write(b) upload := int64(n) tt.manager.PushUploaded(upload) + if tt.UploadTotal.Load() < 128 && tt.Metadata.Host == "" && (tt.Metadata.DstPort == "443" || tt.Metadata.DstPort == "8443") { + header, err := tls.SniffTLS(b) + if err != nil { + // log.Errorln("Expect no error but actually %s %s:%s:%s", err.Error(), tt.Metadata.Host, tt.Metadata.DstIP.String(), tt.Metadata.DstPort) + } else { + tt.Metadata.Host = header.Domain() + resolver.InsertHostByIP(tt.Metadata.DstIP, tt.Metadata.Host) + log.Errorln("sni %s %s", tt.Metadata.Host, tt.Metadata.DstIP.String()) + tt.manager.Leave(tt) + tt.Conn.Close() + return n, errors.New("sni update") + } + } tt.UploadTotal.Add(upload) + return n, err } From afdcb6cfc7b8be71f2ebb7ba38629becbaab1e86 Mon Sep 17 00:00:00 2001 From: fishg <1423545+fishg@users.noreply.github.com> Date: Thu, 31 Mar 2022 11:41:40 +0800 Subject: [PATCH 04/14] fix: log level ajust and lint fix --- tunnel/statistic/tracker.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 77713da2..f213ca61 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -8,8 +8,8 @@ import ( "github.com/Dreamacro/clash/common/snifer/tls" "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" - "github.com/Dreamacro/clash/log" + "github.com/gofrs/uuid" "go.uber.org/atomic" ) @@ -57,12 +57,11 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { if err != nil { // log.Errorln("Expect no error but actually %s %s:%s:%s", err.Error(), tt.Metadata.Host, tt.Metadata.DstIP.String(), tt.Metadata.DstPort) } else { - tt.Metadata.Host = header.Domain() - resolver.InsertHostByIP(tt.Metadata.DstIP, tt.Metadata.Host) - log.Errorln("sni %s %s", tt.Metadata.Host, tt.Metadata.DstIP.String()) + resolver.InsertHostByIP(tt.Metadata.DstIP, header.Domain()) + log.Warnln("use sni update host: %s ip: %s", header.Domain(), tt.Metadata.DstIP.String()) tt.manager.Leave(tt) tt.Conn.Close() - return n, errors.New("sni update") + return n, errors.New("sni update, break current link to avoid leaks") } } tt.UploadTotal.Add(upload) From 13012a9f897adf8a7acc49c4b1f31e3fe629bc0b Mon Sep 17 00:00:00 2001 From: fishg <1423545+fishg@users.noreply.github.com> Date: Sat, 2 Apr 2022 16:03:53 +0800 Subject: [PATCH 05/14] fix: dns over proxy may due to cancel request, but proxy live status is fine --- adapter/adapter.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/adapter/adapter.go b/adapter/adapter.go index 23dc304a..f4087241 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/url" + "strings" "time" "github.com/Dreamacro/clash/common/queue" @@ -37,7 +38,11 @@ func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) { // DialContext implements C.ProxyAdapter func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { conn, err := p.ProxyAdapter.DialContext(ctx, metadata, opts...) - p.alive.Store(err == nil) + wasCancel := false + if err != nil { + wasCancel = strings.Contains(err.Error(), "operation was canceled") + } + p.alive.Store(err == nil || wasCancel) return conn, err } From 93d2cfa0918e27b28a00496c57d38575a29a2f1a Mon Sep 17 00:00:00 2001 From: fishg <1423545+fishg@users.noreply.github.com> Date: Mon, 4 Apr 2022 10:39:26 +0800 Subject: [PATCH 06/14] fix: when ssh connect to a ip, if this ip map to a domain in clash, change ip to host may redirect to a diffrent ip --- constant/metadata.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/constant/metadata.go b/constant/metadata.go index f63816b1..3da67201 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -83,7 +83,11 @@ type Metadata struct { } func (m *Metadata) RemoteAddress() string { - return net.JoinHostPort(m.String(), m.DstPort) + if m.DstIP != nil { + return net.JoinHostPort(m.DstIP.String(), m.DstPort) + } else { + return net.JoinHostPort(m.String(), m.DstPort) + } } func (m *Metadata) SourceAddress() string { From 7a8af90b8654ea5b1a6af1abaae3dbede333d700 Mon Sep 17 00:00:00 2001 From: fishg <1423545+fishg@users.noreply.github.com> Date: Mon, 4 Apr 2022 10:43:25 +0800 Subject: [PATCH 07/14] feat: add SMTPS/POP3S/IMAPS port to sni detect --- tunnel/statistic/tracker.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index f213ca61..db018c05 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -52,7 +52,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { n, err := tt.Conn.Write(b) upload := int64(n) tt.manager.PushUploaded(upload) - if tt.UploadTotal.Load() < 128 && tt.Metadata.Host == "" && (tt.Metadata.DstPort == "443" || tt.Metadata.DstPort == "8443") { + if tt.UploadTotal.Load() < 128 && tt.Metadata.Host == "" && (tt.Metadata.DstPort == "443" || tt.Metadata.DstPort == "8443" || tt.Metadata.DstPort == "993" || tt.Metadata.DstPort == "465" || tt.Metadata.DstPort == "995") { header, err := tls.SniffTLS(b) if err != nil { // log.Errorln("Expect no error but actually %s %s:%s:%s", err.Error(), tt.Metadata.Host, tt.Metadata.DstIP.String(), tt.Metadata.DstPort) From b6653dd9b510fc6a7ffd05682820daeb2a1a6d61 Mon Sep 17 00:00:00 2001 From: fishg <1423545+fishg@users.noreply.github.com> Date: Sat, 9 Apr 2022 21:29:19 +0800 Subject: [PATCH 08/14] fix: trojan fail may panic --- transport/trojan/trojan.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/trojan/trojan.go b/transport/trojan/trojan.go index a0e289f1..207d7b3a 100644 --- a/transport/trojan/trojan.go +++ b/transport/trojan/trojan.go @@ -148,7 +148,7 @@ func (t *Trojan) PresetXTLSConn(conn net.Conn) (net.Conn, error) { xtlsConn.DirectMode = true } } else { - return nil, fmt.Errorf("failed to use %s, maybe \"security\" is not \"xtls\"", t.option.Flow) + return conn, fmt.Errorf("failed to use %s, maybe \"security\" is not \"xtls\"", t.option.Flow) } } From 92d9d03f992f152e337742131d66bded1a9cd023 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Sun, 10 Apr 2022 00:05:59 +0800 Subject: [PATCH 09/14] Chore: move sniffing logic into a single file & code style --- adapter/adapter.go | 12 ++++---- config/initial.go | 35 ----------------------- tunnel/statistic/sniffing.go | 54 ++++++++++++++++++++++++++++++++++++ tunnel/statistic/tracker.go | 22 ++------------- 4 files changed, 63 insertions(+), 60 deletions(-) create mode 100644 tunnel/statistic/sniffing.go diff --git a/adapter/adapter.go b/adapter/adapter.go index f4087241..1548a5f6 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -3,12 +3,13 @@ package adapter import ( "context" "encoding/json" + "errors" "fmt" "net" "net/http" "net/url" - "strings" "time" + _ "unsafe" "github.com/Dreamacro/clash/common/queue" "github.com/Dreamacro/clash/component/dialer" @@ -17,6 +18,9 @@ import ( "go.uber.org/atomic" ) +//go:linkname errCanceled net.errCanceled +var errCanceled error + type Proxy struct { C.ProxyAdapter history *queue.Queue @@ -38,11 +42,7 @@ func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) { // DialContext implements C.ProxyAdapter func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { conn, err := p.ProxyAdapter.DialContext(ctx, metadata, opts...) - wasCancel := false - if err != nil { - wasCancel = strings.Contains(err.Error(), "operation was canceled") - } - p.alive.Store(err == nil || wasCancel) + p.alive.Store(err == nil || errors.Is(err, errCanceled)) return conn, err } diff --git a/config/initial.go b/config/initial.go index 9d8288ba..a365153b 100644 --- a/config/initial.go +++ b/config/initial.go @@ -50,23 +50,6 @@ func initMMDB() error { return nil } -//func downloadGeoIP(path string) (err error) { -// resp, err := http.Get("https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geoip.dat") -// if err != nil { -// return -// } -// defer resp.Body.Close() -// -// f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644) -// if err != nil { -// return err -// } -// defer f.Close() -// _, err = io.Copy(f, resp.Body) -// -// return err -//} - func downloadGeoSite(path string) (err error) { resp, err := http.Get("https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geosite.dat") if err != nil { @@ -84,19 +67,6 @@ func downloadGeoSite(path string) (err error) { return err } -// -//func initGeoIP() error { -// if _, err := os.Stat(C.Path.GeoIP()); os.IsNotExist(err) { -// log.Infoln("Can't find GeoIP.dat, start download") -// if err := downloadGeoIP(C.Path.GeoIP()); err != nil { -// return fmt.Errorf("can't download GeoIP.dat: %s", err.Error()) -// } -// log.Infoln("Download GeoIP.dat finish") -// } -// -// return nil -//} - func initGeoSite() error { if _, err := os.Stat(C.Path.GeoSite()); os.IsNotExist(err) { log.Infoln("Can't find GeoSite.dat, start download") @@ -129,11 +99,6 @@ func Init(dir string) error { f.Close() } - //// initial GeoIP - //if err := initGeoIP(); err != nil { - // return fmt.Errorf("can't initial GeoIP: %w", err) - //} - // initial mmdb if err := initMMDB(); err != nil { return fmt.Errorf("can't initial MMDB: %w", err) diff --git a/tunnel/statistic/sniffing.go b/tunnel/statistic/sniffing.go new file mode 100644 index 00000000..2d1f1bfd --- /dev/null +++ b/tunnel/statistic/sniffing.go @@ -0,0 +1,54 @@ +package statistic + +import ( + "errors" + + "github.com/Dreamacro/clash/common/snifer/tls" + "github.com/Dreamacro/clash/component/resolver" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" + + "go.uber.org/atomic" +) + +type sniffing struct { + C.Conn + + metadata *C.Metadata + totalWrite *atomic.Uint64 +} + +func (r *sniffing) Read(b []byte) (int, error) { + return r.Conn.Read(b) +} + +func (r *sniffing) Write(b []byte) (int, error) { + if r.totalWrite.Load() < 128 && r.metadata.Host == "" && (r.metadata.DstPort == "443" || r.metadata.DstPort == "8443" || r.metadata.DstPort == "993" || r.metadata.DstPort == "465" || r.metadata.DstPort == "995") { + header, err := tls.SniffTLS(b) + if err != nil { + // log.Errorln("Expect no error but actually %s %s:%s:%s", err.Error(), tt.Metadata.Host, tt.Metadata.DstIP.String(), tt.Metadata.DstPort) + } else { + resolver.InsertHostByIP(r.metadata.DstIP, header.Domain()) + log.Warnln("use sni update host: %s ip: %s", header.Domain(), r.metadata.DstIP.String()) + r.Conn.Close() + return 0, errors.New("sni update, break current link to avoid leaks") + } + } + + n, err := r.Conn.Write(b) + r.totalWrite.Add(uint64(n)) + + return n, err +} + +func (r *sniffing) Close() error { + return r.Conn.Close() +} + +func NewSniffing(conn C.Conn, metadata *C.Metadata) C.Conn { + return &sniffing{ + Conn: conn, + metadata: metadata, + totalWrite: atomic.NewUint64(0), + } +} diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index db018c05..6fd8b3e7 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -1,14 +1,10 @@ package statistic import ( - "errors" "net" "time" - "github.com/Dreamacro/clash/common/snifer/tls" - "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" - "github.com/Dreamacro/clash/log" "github.com/gofrs/uuid" "go.uber.org/atomic" @@ -52,20 +48,7 @@ func (tt *tcpTracker) Write(b []byte) (int, error) { n, err := tt.Conn.Write(b) upload := int64(n) tt.manager.PushUploaded(upload) - if tt.UploadTotal.Load() < 128 && tt.Metadata.Host == "" && (tt.Metadata.DstPort == "443" || tt.Metadata.DstPort == "8443" || tt.Metadata.DstPort == "993" || tt.Metadata.DstPort == "465" || tt.Metadata.DstPort == "995") { - header, err := tls.SniffTLS(b) - if err != nil { - // log.Errorln("Expect no error but actually %s %s:%s:%s", err.Error(), tt.Metadata.Host, tt.Metadata.DstIP.String(), tt.Metadata.DstPort) - } else { - resolver.InsertHostByIP(tt.Metadata.DstIP, header.Domain()) - log.Warnln("use sni update host: %s ip: %s", header.Domain(), tt.Metadata.DstIP.String()) - tt.manager.Leave(tt) - tt.Conn.Close() - return n, errors.New("sni update, break current link to avoid leaks") - } - } tt.UploadTotal.Add(upload) - return n, err } @@ -74,7 +57,7 @@ func (tt *tcpTracker) Close() error { return tt.Conn.Close() } -func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker { +func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) C.Conn { uuid, _ := uuid.NewV4() t := &tcpTracker{ @@ -97,7 +80,8 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R } manager.Join(t) - return t + conn = NewSniffing(t, metadata) + return conn } type udpTracker struct { From 0582c608b3240c88a6833212b10c2496b125ec78 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Tue, 5 Apr 2022 20:23:16 +0800 Subject: [PATCH 10/14] Refactor: lrucache use generics --- common/cache/lrucache.go | 97 ++++++++++++++++++----------------- common/cache/lrucache_test.go | 41 +++++++-------- component/fakeip/memory.go | 40 +++++++++------ component/fakeip/pool.go | 5 +- dns/enhancer.go | 8 +-- dns/middleware.go | 2 +- dns/resolver.go | 8 +-- dns/util.go | 2 +- 8 files changed, 106 insertions(+), 97 deletions(-) diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go index 0bea06f6..82eca7f4 100644 --- a/common/cache/lrucache.go +++ b/common/cache/lrucache.go @@ -9,43 +9,43 @@ import ( ) // Option is part of Functional Options Pattern -type Option func(*LruCache) +type Option[K comparable, V any] func(*LruCache[K, V]) // EvictCallback is used to get a callback when a cache entry is evicted type EvictCallback = func(key any, value any) // WithEvict set the evict callback -func WithEvict(cb EvictCallback) Option { - return func(l *LruCache) { +func WithEvict[K comparable, V any](cb EvictCallback) Option[K, V] { + return func(l *LruCache[K, V]) { l.onEvict = cb } } // WithUpdateAgeOnGet update expires when Get element -func WithUpdateAgeOnGet() Option { - return func(l *LruCache) { +func WithUpdateAgeOnGet[K comparable, V any]() Option[K, V] { + return func(l *LruCache[K, V]) { l.updateAgeOnGet = true } } // WithAge defined element max age (second) -func WithAge(maxAge int64) Option { - return func(l *LruCache) { +func WithAge[K comparable, V any](maxAge int64) Option[K, V] { + return func(l *LruCache[K, V]) { l.maxAge = maxAge } } // WithSize defined max length of LruCache -func WithSize(maxSize int) Option { - return func(l *LruCache) { +func WithSize[K comparable, V any](maxSize int) Option[K, V] { + return func(l *LruCache[K, V]) { l.maxSize = maxSize } } // WithStale decide whether Stale return is enabled. // If this feature is enabled, element will not get Evicted according to `WithAge`. -func WithStale(stale bool) Option { - return func(l *LruCache) { +func WithStale[K comparable, V any](stale bool) Option[K, V] { + return func(l *LruCache[K, V]) { l.staleReturn = stale } } @@ -53,7 +53,7 @@ func WithStale(stale bool) Option { // LruCache is a thread-safe, in-memory lru-cache that evicts the // least recently used entries from memory when (if set) the entries are // older than maxAge (in seconds). Use the New constructor to create one. -type LruCache struct { +type LruCache[K comparable, V any] struct { maxAge int64 maxSize int mu sync.Mutex @@ -65,8 +65,8 @@ type LruCache struct { } // NewLRUCache creates an LruCache -func NewLRUCache(options ...Option) *LruCache { - lc := &LruCache{ +func NewLRUCache[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] { + lc := &LruCache[K, V]{ lru: list.New(), cache: make(map[any]*list.Element), } @@ -80,12 +80,12 @@ func NewLRUCache(options ...Option) *LruCache { // Get returns the any representation of a cached response and a bool // set to true if the key was found. -func (c *LruCache) Get(key any) (any, bool) { - entry := c.get(key) - if entry == nil { - return nil, false +func (c *LruCache[K, V]) Get(key K) (V, bool) { + el := c.get(key) + if el == nil { + return getZero[V](), false } - value := entry.value + value := el.value return value, true } @@ -94,17 +94,17 @@ func (c *LruCache) Get(key any) (any, bool) { // a time.Time Give expected expires, // and a bool set to true if the key was found. // This method will NOT check the maxAge of element and will NOT update the expires. -func (c *LruCache) GetWithExpire(key any) (any, time.Time, bool) { - entry := c.get(key) - if entry == nil { - return nil, time.Time{}, false +func (c *LruCache[K, V]) GetWithExpire(key K) (V, time.Time, bool) { + el := c.get(key) + if el == nil { + return getZero[V](), time.Time{}, false } - return entry.value, time.Unix(entry.expires, 0), true + return el.value, time.Unix(el.expires, 0), true } // Exist returns if key exist in cache but not put item to the head of linked list -func (c *LruCache) Exist(key any) bool { +func (c *LruCache[K, V]) Exist(key K) bool { c.mu.Lock() defer c.mu.Unlock() @@ -113,7 +113,7 @@ func (c *LruCache) Exist(key any) bool { } // Set stores the any representation of a response for a given key. -func (c *LruCache) Set(key any, value any) { +func (c *LruCache[K, V]) Set(key K, value V) { expires := int64(0) if c.maxAge > 0 { expires = time.Now().Unix() + c.maxAge @@ -123,21 +123,21 @@ func (c *LruCache) Set(key any, value any) { // SetWithExpire stores the any representation of a response for a given key and given expires. // The expires time will round to second. -func (c *LruCache) SetWithExpire(key any, value any, expires time.Time) { +func (c *LruCache[K, V]) SetWithExpire(key K, value V, expires time.Time) { c.mu.Lock() defer c.mu.Unlock() if le, ok := c.cache[key]; ok { c.lru.MoveToBack(le) - e := le.Value.(*entry) + e := le.Value.(*entry[K, V]) e.value = value e.expires = expires.Unix() } else { - e := &entry{key: key, value: value, expires: expires.Unix()} + e := &entry[K, V]{key: key, value: value, expires: expires.Unix()} c.cache[key] = c.lru.PushBack(e) if c.maxSize > 0 { - if len := c.lru.Len(); len > c.maxSize { + if elLen := c.lru.Len(); elLen > c.maxSize { c.deleteElement(c.lru.Front()) } } @@ -147,7 +147,7 @@ func (c *LruCache) SetWithExpire(key any, value any, expires time.Time) { } // CloneTo clone and overwrite elements to another LruCache -func (c *LruCache) CloneTo(n *LruCache) { +func (c *LruCache[K, V]) CloneTo(n *LruCache[K, V]) { c.mu.Lock() defer c.mu.Unlock() @@ -158,12 +158,12 @@ func (c *LruCache) CloneTo(n *LruCache) { n.cache = make(map[any]*list.Element) for e := c.lru.Front(); e != nil; e = e.Next() { - elm := e.Value.(*entry) + elm := e.Value.(*entry[K, V]) n.cache[elm.key] = n.lru.PushBack(elm) } } -func (c *LruCache) get(key any) *entry { +func (c *LruCache[K, V]) get(key K) *entry[K, V] { c.mu.Lock() defer c.mu.Unlock() @@ -172,7 +172,7 @@ func (c *LruCache) get(key any) *entry { return nil } - if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() { + if !c.staleReturn && c.maxAge > 0 && le.Value.(*entry[K, V]).expires <= time.Now().Unix() { c.deleteElement(le) c.maybeDeleteOldest() @@ -180,15 +180,15 @@ func (c *LruCache) get(key any) *entry { } c.lru.MoveToBack(le) - entry := le.Value.(*entry) + el := le.Value.(*entry[K, V]) if c.maxAge > 0 && c.updateAgeOnGet { - entry.expires = time.Now().Unix() + c.maxAge + el.expires = time.Now().Unix() + c.maxAge } - return entry + return el } // Delete removes the value associated with a key. -func (c *LruCache) Delete(key any) { +func (c *LruCache[K, V]) Delete(key K) { c.mu.Lock() if le, ok := c.cache[key]; ok { @@ -198,25 +198,25 @@ func (c *LruCache) Delete(key any) { c.mu.Unlock() } -func (c *LruCache) maybeDeleteOldest() { +func (c *LruCache[K, V]) maybeDeleteOldest() { if !c.staleReturn && c.maxAge > 0 { now := time.Now().Unix() - for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() { + for le := c.lru.Front(); le != nil && le.Value.(*entry[K, V]).expires <= now; le = c.lru.Front() { c.deleteElement(le) } } } -func (c *LruCache) deleteElement(le *list.Element) { +func (c *LruCache[K, V]) deleteElement(le *list.Element) { c.lru.Remove(le) - e := le.Value.(*entry) + e := le.Value.(*entry[K, V]) delete(c.cache, e.key) if c.onEvict != nil { c.onEvict(e.key, e.value) } } -func (c *LruCache) Clear() error { +func (c *LruCache[K, V]) Clear() error { c.mu.Lock() c.cache = make(map[any]*list.Element) @@ -225,8 +225,13 @@ func (c *LruCache) Clear() error { return nil } -type entry struct { - key any - value any +type entry[K comparable, V any] struct { + key K + value V expires int64 } + +func getZero[T any]() T { + var result T + return result +} diff --git a/common/cache/lrucache_test.go b/common/cache/lrucache_test.go index 1a910b4a..487c184e 100644 --- a/common/cache/lrucache_test.go +++ b/common/cache/lrucache_test.go @@ -19,7 +19,7 @@ var entries = []struct { } func TestLRUCache(t *testing.T) { - c := NewLRUCache() + c := NewLRUCache[string, string]() for _, e := range entries { c.Set(e.key, e.value) @@ -32,7 +32,7 @@ func TestLRUCache(t *testing.T) { for _, e := range entries { value, ok := c.Get(e.key) if assert.True(t, ok) { - assert.Equal(t, e.value, value.(string)) + assert.Equal(t, e.value, value) } } @@ -45,25 +45,25 @@ func TestLRUCache(t *testing.T) { } func TestLRUMaxAge(t *testing.T) { - c := NewLRUCache(WithAge(86400)) + c := NewLRUCache[string, string](WithAge[string, string](86400)) now := time.Now().Unix() expected := now + 86400 // Add one expired entry c.Set("foo", "bar") - c.lru.Back().Value.(*entry).expires = now + c.lru.Back().Value.(*entry[string, string]).expires = now // Reset c.Set("foo", "bar") - e := c.lru.Back().Value.(*entry) + e := c.lru.Back().Value.(*entry[string, string]) assert.True(t, e.expires >= now) - c.lru.Back().Value.(*entry).expires = now + c.lru.Back().Value.(*entry[string, string]).expires = now // Set a few and verify expiration times for _, s := range entries { c.Set(s.key, s.value) - e := c.lru.Back().Value.(*entry) + e := c.lru.Back().Value.(*entry[string, string]) assert.True(t, e.expires >= expected && e.expires <= expected+10) } @@ -77,7 +77,7 @@ func TestLRUMaxAge(t *testing.T) { for _, s := range entries { le, ok := c.cache[s.key] if assert.True(t, ok) { - le.Value.(*entry).expires = now + le.Value.(*entry[string, string]).expires = now } } @@ -88,22 +88,22 @@ func TestLRUMaxAge(t *testing.T) { } func TestLRUpdateOnGet(t *testing.T) { - c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet()) + c := NewLRUCache[string, string](WithAge[string, string](86400), WithUpdateAgeOnGet[string, string]()) now := time.Now().Unix() expires := now + 86400/2 // Add one expired entry c.Set("foo", "bar") - c.lru.Back().Value.(*entry).expires = expires + c.lru.Back().Value.(*entry[string, string]).expires = expires _, ok := c.Get("foo") assert.True(t, ok) - assert.True(t, c.lru.Back().Value.(*entry).expires > expires) + assert.True(t, c.lru.Back().Value.(*entry[string, string]).expires > expires) } func TestMaxSize(t *testing.T) { - c := NewLRUCache(WithSize(2)) + c := NewLRUCache[string, string](WithSize[string, string](2)) // Add one expired entry c.Set("foo", "bar") _, ok := c.Get("foo") @@ -117,7 +117,7 @@ func TestMaxSize(t *testing.T) { } func TestExist(t *testing.T) { - c := NewLRUCache(WithSize(1)) + c := NewLRUCache[int, int](WithSize[int, int](1)) c.Set(1, 2) assert.True(t, c.Exist(1)) c.Set(2, 3) @@ -130,7 +130,7 @@ func TestEvict(t *testing.T) { temp = key.(int) + value.(int) } - c := NewLRUCache(WithEvict(evict), WithSize(1)) + c := NewLRUCache[int, int](WithEvict[int, int](evict), WithSize[int, int](1)) c.Set(1, 2) c.Set(2, 3) @@ -138,21 +138,22 @@ func TestEvict(t *testing.T) { } func TestSetWithExpire(t *testing.T) { - c := NewLRUCache(WithAge(1)) + c := NewLRUCache[int, *struct{}](WithAge[int, *struct{}](1)) now := time.Now().Unix() tenSecBefore := time.Unix(now-10, 0) - c.SetWithExpire(1, 2, tenSecBefore) + c.SetWithExpire(1, &struct{}{}, tenSecBefore) // res is expected not to exist, and expires should be empty time.Time res, expires, exist := c.GetWithExpire(1) - assert.Equal(t, nil, res) + + assert.True(t, nil == res) assert.Equal(t, time.Time{}, expires) assert.Equal(t, false, exist) } func TestStale(t *testing.T) { - c := NewLRUCache(WithAge(1), WithStale(true)) + c := NewLRUCache[int, int](WithAge[int, int](1), WithStale[int, int](true)) now := time.Now().Unix() tenSecBefore := time.Unix(now-10, 0) @@ -165,11 +166,11 @@ func TestStale(t *testing.T) { } func TestCloneTo(t *testing.T) { - o := NewLRUCache(WithSize(10)) + o := NewLRUCache[string, int](WithSize[string, int](10)) o.Set("1", 1) o.Set("2", 2) - n := NewLRUCache(WithSize(2)) + n := NewLRUCache[string, int](WithSize[string, int](2)) n.Set("3", 3) n.Set("4", 4) diff --git a/component/fakeip/memory.go b/component/fakeip/memory.go index a7ff3708..2568b1d9 100644 --- a/component/fakeip/memory.go +++ b/component/fakeip/memory.go @@ -7,16 +7,15 @@ import ( ) type memoryStore struct { - cache *cache.LruCache + cacheIP *cache.LruCache[string, net.IP] + cacheHost *cache.LruCache[uint32, string] } // GetByHost implements store.GetByHost func (m *memoryStore) GetByHost(host string) (net.IP, bool) { - if elm, exist := m.cache.Get(host); exist { - ip := elm.(net.IP) - + if ip, exist := m.cacheIP.Get(host); exist { // ensure ip --> host on head of linked list - m.cache.Get(ipToUint(ip.To4())) + m.cacheHost.Get(ipToUint(ip.To4())) return ip, true } @@ -25,16 +24,14 @@ func (m *memoryStore) GetByHost(host string) (net.IP, bool) { // PutByHost implements store.PutByHost func (m *memoryStore) PutByHost(host string, ip net.IP) { - m.cache.Set(host, ip) + m.cacheIP.Set(host, ip) } // GetByIP implements store.GetByIP func (m *memoryStore) GetByIP(ip net.IP) (string, bool) { - if elm, exist := m.cache.Get(ipToUint(ip.To4())); exist { - host := elm.(string) - + if host, exist := m.cacheHost.Get(ipToUint(ip.To4())); exist { // ensure host --> ip on head of linked list - m.cache.Get(host) + m.cacheIP.Get(host) return host, true } @@ -43,32 +40,41 @@ func (m *memoryStore) GetByIP(ip net.IP) (string, bool) { // PutByIP implements store.PutByIP func (m *memoryStore) PutByIP(ip net.IP, host string) { - m.cache.Set(ipToUint(ip.To4()), host) + m.cacheHost.Set(ipToUint(ip.To4()), host) } // DelByIP implements store.DelByIP func (m *memoryStore) DelByIP(ip net.IP) { ipNum := ipToUint(ip.To4()) - if elm, exist := m.cache.Get(ipNum); exist { - m.cache.Delete(elm.(string)) + if host, exist := m.cacheHost.Get(ipNum); exist { + m.cacheIP.Delete(host) } - m.cache.Delete(ipNum) + m.cacheHost.Delete(ipNum) } // Exist implements store.Exist func (m *memoryStore) Exist(ip net.IP) bool { - return m.cache.Exist(ipToUint(ip.To4())) + return m.cacheHost.Exist(ipToUint(ip.To4())) } // CloneTo implements store.CloneTo // only for memoryStore to memoryStore func (m *memoryStore) CloneTo(store store) { if ms, ok := store.(*memoryStore); ok { - m.cache.CloneTo(ms.cache) + m.cacheIP.CloneTo(ms.cacheIP) + m.cacheHost.CloneTo(ms.cacheHost) } } // FlushFakeIP implements store.FlushFakeIP func (m *memoryStore) FlushFakeIP() error { - return m.cache.Clear() + _ = m.cacheIP.Clear() + return m.cacheHost.Clear() +} + +func newMemoryStore(size int) *memoryStore { + return &memoryStore{ + cacheIP: cache.NewLRUCache[string, net.IP](cache.WithSize[string, net.IP](size)), + cacheHost: cache.NewLRUCache[uint32, string](cache.WithSize[uint32, string](size)), + } } diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index e93873c9..a55e5463 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -5,7 +5,6 @@ import ( "net" "sync" - "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/component/profile/cachefile" "github.com/Dreamacro/clash/component/trie" ) @@ -175,9 +174,7 @@ func New(options Options) (*Pool, error) { cache: cachefile.Cache(), } } else { - pool.store = &memoryStore{ - cache: cache.NewLRUCache(cache.WithSize(options.Size * 2)), - } + pool.store = newMemoryStore(options.Size) } return pool, nil diff --git a/dns/enhancer.go b/dns/enhancer.go index 016ff02a..9d708caa 100644 --- a/dns/enhancer.go +++ b/dns/enhancer.go @@ -11,7 +11,7 @@ import ( type ResolverEnhancer struct { mode C.DNSMode fakePool *fakeip.Pool - mapping *cache.LruCache + mapping *cache.LruCache[string, string] } func (h *ResolverEnhancer) FakeIPEnabled() bool { @@ -67,7 +67,7 @@ func (h *ResolverEnhancer) FindHostByIP(ip net.IP) (string, bool) { if mapping := h.mapping; mapping != nil { if host, existed := h.mapping.Get(ip.String()); existed { - return host.(string), true + return host, true } } @@ -99,11 +99,11 @@ func (h *ResolverEnhancer) FlushFakeIP() error { func NewEnhancer(cfg Config) *ResolverEnhancer { var fakePool *fakeip.Pool - var mapping *cache.LruCache + var mapping *cache.LruCache[string, string] if cfg.EnhancedMode != C.DNSNormal { fakePool = cfg.Pool - mapping = cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)) + mapping = cache.NewLRUCache[string, string](cache.WithSize[string, string](4096), cache.WithStale[string, string](true)) } return &ResolverEnhancer{ diff --git a/dns/middleware.go b/dns/middleware.go index dc7cbe33..5958fe93 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -63,7 +63,7 @@ func withHosts(hosts *trie.DomainTrie) middleware { } } -func withMapping(mapping *cache.LruCache) middleware { +func withMapping(mapping *cache.LruCache[string, string]) middleware { return func(next handler) handler { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] diff --git a/dns/resolver.go b/dns/resolver.go index 4ff12ee8..9968eaf0 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -39,7 +39,7 @@ type Resolver struct { fallbackDomainFilters []fallbackDomainFilter fallbackIPFilters []fallbackIPFilter group singleflight.Group - lruCache *cache.LruCache + lruCache *cache.LruCache[string, *D.Msg] policy *trie.DomainTrie proxyServer []dnsClient } @@ -103,7 +103,7 @@ func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, e cache, expireTime, hit := r.lruCache.GetWithExpire(q.String()) if hit { now := time.Now() - msg = cache.(*D.Msg).Copy() + msg = cache.Copy() if expireTime.Before(now) { setMsgTTL(msg, uint32(1)) // Continue fetch go r.exchangeWithoutCache(ctx, m) @@ -337,13 +337,13 @@ type Config struct { func NewResolver(config Config) *Resolver { defaultResolver := &Resolver{ main: transform(config.Default, nil), - lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), + lruCache: cache.NewLRUCache[string, *D.Msg](cache.WithSize[string, *D.Msg](4096), cache.WithStale[string, *D.Msg](true)), } r := &Resolver{ ipv6: config.IPv6, main: transform(config.Main, defaultResolver), - lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), + lruCache: cache.NewLRUCache[string, *D.Msg](cache.WithSize[string, *D.Msg](4096), cache.WithStale[string, *D.Msg](true)), hosts: config.Hosts, } diff --git a/dns/util.go b/dns/util.go index 59eef1de..759b2209 100644 --- a/dns/util.go +++ b/dns/util.go @@ -16,7 +16,7 @@ import ( D "github.com/miekg/dns" ) -func putMsgToCache(c *cache.LruCache, key string, msg *D.Msg) { +func putMsgToCache(c *cache.LruCache[string, *D.Msg], key string, msg *D.Msg) { var ttl uint32 switch { case len(msg.Answer) != 0: From 400be9a905892ca5e014e6e01bcd494b7abbe8cb Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Tue, 5 Apr 2022 23:29:52 +0800 Subject: [PATCH 11/14] Refactor: cache use generics --- common/cache/cache.go | 48 +++++++++++++++++++------------------- common/cache/cache_test.go | 28 +++++++++++----------- listener/http/proxy.go | 10 ++++---- listener/http/server.go | 4 ++-- listener/mixed/mixed.go | 6 ++--- 5 files changed, 49 insertions(+), 47 deletions(-) diff --git a/common/cache/cache.go b/common/cache/cache.go index e587d77b..b87392b4 100644 --- a/common/cache/cache.go +++ b/common/cache/cache.go @@ -7,50 +7,50 @@ import ( ) // Cache store element with a expired time -type Cache struct { - *cache +type Cache[K comparable, V any] struct { + *cache[K, V] } -type cache struct { +type cache[K comparable, V any] struct { mapping sync.Map - janitor *janitor + janitor *janitor[K, V] } -type element struct { +type element[V any] struct { Expired time.Time - Payload any + Payload V } // Put element in Cache with its ttl -func (c *cache) Put(key any, payload any, ttl time.Duration) { - c.mapping.Store(key, &element{ +func (c *cache[K, V]) Put(key K, payload V, ttl time.Duration) { + c.mapping.Store(key, &element[V]{ Payload: payload, Expired: time.Now().Add(ttl), }) } // Get element in Cache, and drop when it expired -func (c *cache) Get(key any) any { +func (c *cache[K, V]) Get(key K) V { item, exist := c.mapping.Load(key) if !exist { - return nil + return getZero[V]() } - elm := item.(*element) + elm := item.(*element[V]) // expired if time.Since(elm.Expired) > 0 { c.mapping.Delete(key) - return nil + return getZero[V]() } return elm.Payload } // GetWithExpire element in Cache with Expire Time -func (c *cache) GetWithExpire(key any) (payload any, expired time.Time) { +func (c *cache[K, V]) GetWithExpire(key K) (payload V, expired time.Time) { item, exist := c.mapping.Load(key) if !exist { return } - elm := item.(*element) + elm := item.(*element[V]) // expired if time.Since(elm.Expired) > 0 { c.mapping.Delete(key) @@ -59,10 +59,10 @@ func (c *cache) GetWithExpire(key any) (payload any, expired time.Time) { return elm.Payload, elm.Expired } -func (c *cache) cleanup() { +func (c *cache[K, V]) cleanup() { c.mapping.Range(func(k, v any) bool { key := k.(string) - elm := v.(*element) + elm := v.(*element[V]) if time.Since(elm.Expired) > 0 { c.mapping.Delete(key) } @@ -70,12 +70,12 @@ func (c *cache) cleanup() { }) } -type janitor struct { +type janitor[K comparable, V any] struct { interval time.Duration stop chan struct{} } -func (j *janitor) process(c *cache) { +func (j *janitor[K, V]) process(c *cache[K, V]) { ticker := time.NewTicker(j.interval) for { select { @@ -88,19 +88,19 @@ func (j *janitor) process(c *cache) { } } -func stopJanitor(c *Cache) { +func stopJanitor[K comparable, V any](c *Cache[K, V]) { c.janitor.stop <- struct{}{} } // New return *Cache -func New(interval time.Duration) *Cache { - j := &janitor{ +func New[K comparable, V any](interval time.Duration) *Cache[K, V] { + j := &janitor[K, V]{ interval: interval, stop: make(chan struct{}), } - c := &cache{janitor: j} + c := &cache[K, V]{janitor: j} go j.process(c) - C := &Cache{c} - runtime.SetFinalizer(C, stopJanitor) + C := &Cache[K, V]{c} + runtime.SetFinalizer(C, stopJanitor[K, V]) return C } diff --git a/common/cache/cache_test.go b/common/cache/cache_test.go index cf4a3914..0945d905 100644 --- a/common/cache/cache_test.go +++ b/common/cache/cache_test.go @@ -11,48 +11,50 @@ import ( func TestCache_Basic(t *testing.T) { interval := 200 * time.Millisecond ttl := 20 * time.Millisecond - c := New(interval) + c := New[string, int](interval) c.Put("int", 1, ttl) - c.Put("string", "a", ttl) + + d := New[string, string](interval) + d.Put("string", "a", ttl) i := c.Get("int") - assert.Equal(t, i.(int), 1, "should recv 1") + assert.Equal(t, i, 1, "should recv 1") - s := c.Get("string") - assert.Equal(t, s.(string), "a", "should recv 'a'") + s := d.Get("string") + assert.Equal(t, s, "a", "should recv 'a'") } func TestCache_TTL(t *testing.T) { interval := 200 * time.Millisecond ttl := 20 * time.Millisecond now := time.Now() - c := New(interval) + c := New[string, int](interval) c.Put("int", 1, ttl) c.Put("int2", 2, ttl) i := c.Get("int") _, expired := c.GetWithExpire("int2") - assert.Equal(t, i.(int), 1, "should recv 1") + assert.Equal(t, i, 1, "should recv 1") assert.True(t, now.Before(expired)) time.Sleep(ttl * 2) i = c.Get("int") j, _ := c.GetWithExpire("int2") - assert.Nil(t, i, "should recv nil") - assert.Nil(t, j, "should recv nil") + assert.True(t, i == 0, "should recv 0") + assert.True(t, j == 0, "should recv 0") } func TestCache_AutoCleanup(t *testing.T) { interval := 10 * time.Millisecond ttl := 15 * time.Millisecond - c := New(interval) + c := New[string, int](interval) c.Put("int", 1, ttl) time.Sleep(ttl * 2) i := c.Get("int") j, _ := c.GetWithExpire("int") - assert.Nil(t, i, "should recv nil") - assert.Nil(t, j, "should recv nil") + assert.True(t, i == 0, "should recv 0") + assert.True(t, j == 0, "should recv 0") } func TestCache_AutoGC(t *testing.T) { @@ -60,7 +62,7 @@ func TestCache_AutoGC(t *testing.T) { go func() { interval := 10 * time.Millisecond ttl := 15 * time.Millisecond - c := New(interval) + c := New[string, int](interval) c.Put("int", 1, ttl) sign <- struct{}{} }() diff --git a/listener/http/proxy.go b/listener/http/proxy.go index 18f1e5d4..e8a805a9 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -15,7 +15,7 @@ import ( "github.com/Dreamacro/clash/log" ) -func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache) { +func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { client := newClient(c.RemoteAddr(), in) defer client.CloseIdleConnections() @@ -98,7 +98,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache) { conn.Close() } -func authenticate(request *http.Request, cache *cache.Cache) *http.Response { +func authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response { authenticator := authStore.Authenticator() if authenticator != nil { credential := parseBasicProxyAuthorization(request) @@ -108,13 +108,13 @@ func authenticate(request *http.Request, cache *cache.Cache) *http.Response { return resp } - var authed any - if authed = cache.Get(credential); authed == nil { + var authed bool + if authed = cache.Get(credential); !authed { user, pass, err := decodeBasicProxyAuthorization(credential) authed = err == nil && authenticator.Verify(user, pass) cache.Put(credential, authed, time.Minute) } - if !authed.(bool) { + if !authed { log.Infoln("Auth failed from %s", request.RemoteAddr) return responseWith(request, http.StatusForbidden) diff --git a/listener/http/server.go b/listener/http/server.go index bfdd9f1b..6b966143 100644 --- a/listener/http/server.go +++ b/listener/http/server.go @@ -40,9 +40,9 @@ func NewWithAuthenticate(addr string, in chan<- C.ConnContext, authenticate bool return nil, err } - var c *cache.Cache + var c *cache.Cache[string, bool] if authenticate { - c = cache.New(time.Second * 30) + c = cache.New[string, bool](time.Second * 30) } hl := &Listener{ diff --git a/listener/mixed/mixed.go b/listener/mixed/mixed.go index 57fd055e..14a81bc3 100644 --- a/listener/mixed/mixed.go +++ b/listener/mixed/mixed.go @@ -16,7 +16,7 @@ import ( type Listener struct { listener net.Listener addr string - cache *cache.Cache + cache *cache.Cache[string, bool] closed bool } @@ -45,7 +45,7 @@ func New(addr string, in chan<- C.ConnContext) (*Listener, error) { ml := &Listener{ listener: l, addr: addr, - cache: cache.New(30 * time.Second), + cache: cache.New[string, bool](30 * time.Second), } go func() { for { @@ -63,7 +63,7 @@ func New(addr string, in chan<- C.ConnContext) (*Listener, error) { return ml, nil } -func handleConn(conn net.Conn, in chan<- C.ConnContext, cache *cache.Cache) { +func handleConn(conn net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { conn.(*net.TCPConn).SetKeepAlive(true) bufConn := N.NewBufferedConn(conn) From a8646082a39eec13dc36e58916890b8302c5b1de Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Wed, 6 Apr 2022 01:07:08 +0800 Subject: [PATCH 12/14] Refactor: queue use generics --- adapter/adapter.go | 12 ++++-------- common/queue/queue.go | 31 ++++++++++++++++++------------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/adapter/adapter.go b/adapter/adapter.go index 1548a5f6..986cd045 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -23,7 +23,7 @@ var errCanceled error type Proxy struct { C.ProxyAdapter - history *queue.Queue + history *queue.Queue[C.DelayHistory] alive *atomic.Bool } @@ -65,7 +65,7 @@ func (p *Proxy) DelayHistory() []C.DelayHistory { queue := p.history.Copy() histories := []C.DelayHistory{} for _, item := range queue { - histories = append(histories, item.(C.DelayHistory)) + histories = append(histories, item) } return histories } @@ -78,11 +78,7 @@ func (p *Proxy) LastDelay() (delay uint16) { return max } - last := p.history.Last() - if last == nil { - return max - } - history := last.(C.DelayHistory) + history := p.history.Last() if history.Delay == 0 { return max } @@ -166,7 +162,7 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) { } func NewProxy(adapter C.ProxyAdapter) *Proxy { - return &Proxy{adapter, queue.New(10), atomic.NewBool(true)} + return &Proxy{adapter, queue.New[C.DelayHistory](10), atomic.NewBool(true)} } func urlToMetadata(rawURL string) (addr C.Metadata, err error) { diff --git a/common/queue/queue.go b/common/queue/queue.go index 60257f56..4755cb35 100644 --- a/common/queue/queue.go +++ b/common/queue/queue.go @@ -5,13 +5,13 @@ import ( ) // Queue is a simple concurrent safe queue -type Queue struct { - items []any +type Queue[T any] struct { + items []T lock sync.RWMutex } // Put add the item to the queue. -func (q *Queue) Put(items ...any) { +func (q *Queue[T]) Put(items ...T) { if len(items) == 0 { return } @@ -22,9 +22,9 @@ func (q *Queue) Put(items ...any) { } // Pop returns the head of items. -func (q *Queue) Pop() any { +func (q *Queue[T]) Pop() T { if len(q.items) == 0 { - return nil + return GetZero[T]() } q.lock.Lock() @@ -35,9 +35,9 @@ func (q *Queue) Pop() any { } // Last returns the last of item. -func (q *Queue) Last() any { +func (q *Queue[T]) Last() T { if len(q.items) == 0 { - return nil + return GetZero[T]() } q.lock.RLock() @@ -47,8 +47,8 @@ func (q *Queue) Last() any { } // Copy get the copy of queue. -func (q *Queue) Copy() []any { - items := []any{} +func (q *Queue[T]) Copy() []T { + items := []T{} q.lock.RLock() items = append(items, q.items...) q.lock.RUnlock() @@ -56,7 +56,7 @@ func (q *Queue) Copy() []any { } // Len returns the number of items in this queue. -func (q *Queue) Len() int64 { +func (q *Queue[T]) Len() int64 { q.lock.Lock() defer q.lock.Unlock() @@ -64,8 +64,13 @@ func (q *Queue) Len() int64 { } // New is a constructor for a new concurrent safe queue. -func New(hint int64) *Queue { - return &Queue{ - items: make([]any, 0, hint), +func New[T any](hint int64) *Queue[T] { + return &Queue[T]{ + items: make([]T, 0, hint), } } + +func GetZero[T any]() T { + var result T + return result +} From 5a27ebd1b3cc9e55b284da8f2ac33e7cde6e687f Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Wed, 6 Apr 2022 04:25:53 +0800 Subject: [PATCH 13/14] Refactor: DomainTrie use generics --- component/fakeip/pool.go | 4 ++-- component/fakeip/pool_test.go | 4 ++-- component/resolver/resolver.go | 14 ++++++++------ component/trie/domain.go | 24 ++++++++++++------------ component/trie/domain_test.go | 28 ++++++++++++++-------------- component/trie/node.go | 23 ++++++++++++++--------- dns/filters.go | 6 +++--- dns/middleware.go | 13 +++++++------ dns/policy.go | 30 ++++++++++++++++++++++++++++++ dns/resolver.go | 14 ++++++++------ 10 files changed, 100 insertions(+), 60 deletions(-) create mode 100644 dns/policy.go diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index a55e5463..afc1691b 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -28,7 +28,7 @@ type Pool struct { broadcast uint32 offset uint32 mux sync.Mutex - host *trie.DomainTrie + host *trie.DomainTrie[bool] ipnet *net.IPNet store store } @@ -138,7 +138,7 @@ func uintToIP(v uint32) net.IP { type Options struct { IPNet *net.IPNet - Host *trie.DomainTrie + Host *trie.DomainTrie[bool] // Size sets the maximum number of entries in memory // and does not work if Persistence is true diff --git a/component/fakeip/pool_test.go b/component/fakeip/pool_test.go index 86e80a2d..b4add98c 100644 --- a/component/fakeip/pool_test.go +++ b/component/fakeip/pool_test.go @@ -100,8 +100,8 @@ func TestPool_CycleUsed(t *testing.T) { func TestPool_Skip(t *testing.T) { _, ipnet, _ := net.ParseCIDR("192.168.0.1/29") - tree := trie.New() - tree.Insert("example.com", tree) + tree := trie.New[bool]() + tree.Insert("example.com", true) pools, tempfile, err := createPools(Options{ IPNet: ipnet, Size: 10, diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index e1100a31..3c8ba384 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -5,6 +5,7 @@ import ( "errors" "math/rand" "net" + "net/netip" "strings" "time" @@ -23,7 +24,7 @@ var ( DisableIPv6 = true // DefaultHosts aim to resolve hosts - DefaultHosts = trie.New() + DefaultHosts = trie.New[netip.Addr]() // DefaultDNSTimeout defined the default dns request timeout DefaultDNSTimeout = time.Second * 5 @@ -48,8 +49,8 @@ func ResolveIPv4(host string) (net.IP, error) { func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) { if node := DefaultHosts.Search(host); node != nil { - if ip := node.Data.(net.IP).To4(); ip != nil { - return ip, nil + if ip := node.Data; ip.Is4() { + return ip.AsSlice(), nil } } @@ -92,8 +93,8 @@ func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { } if node := DefaultHosts.Search(host); node != nil { - if ip := node.Data.(net.IP).To16(); ip != nil { - return ip, nil + if ip := node.Data; ip.Is6() { + return ip.AsSlice(), nil } } @@ -128,7 +129,8 @@ func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { // ResolveIPWithResolver same as ResolveIP, but with a resolver func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) { if node := DefaultHosts.Search(host); node != nil { - return node.Data.(net.IP), nil + ip := node.Data + return ip.Unmap().AsSlice(), nil } if r != nil { diff --git a/component/trie/domain.go b/component/trie/domain.go index 8915eda3..16dd9ae9 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -17,8 +17,8 @@ var ErrInvalidDomain = errors.New("invalid domain") // DomainTrie contains the main logic for adding and searching nodes for domain segments. // support wildcard domain (e.g *.google.com) -type DomainTrie struct { - root *Node +type DomainTrie[T comparable] struct { + root *Node[T] } func ValidAndSplitDomain(domain string) ([]string, bool) { @@ -51,7 +51,7 @@ func ValidAndSplitDomain(domain string) ([]string, bool) { // 3. subdomain.*.example.com // 4. .example.com // 5. +.example.com -func (t *DomainTrie) Insert(domain string, data any) error { +func (t *DomainTrie[T]) Insert(domain string, data T) error { parts, valid := ValidAndSplitDomain(domain) if !valid { return ErrInvalidDomain @@ -68,13 +68,13 @@ func (t *DomainTrie) Insert(domain string, data any) error { return nil } -func (t *DomainTrie) insert(parts []string, data any) { +func (t *DomainTrie[T]) insert(parts []string, data T) { node := t.root // reverse storage domain part to save space for i := len(parts) - 1; i >= 0; i-- { part := parts[i] if !node.hasChild(part) { - node.addChild(part, newNode(nil)) + node.addChild(part, newNode(getZero[T]())) } node = node.getChild(part) @@ -88,7 +88,7 @@ func (t *DomainTrie) insert(parts []string, data any) { // 1. static part // 2. wildcard domain // 2. dot wildcard domain -func (t *DomainTrie) Search(domain string) *Node { +func (t *DomainTrie[T]) Search(domain string) *Node[T] { parts, valid := ValidAndSplitDomain(domain) if !valid || parts[0] == "" { return nil @@ -96,26 +96,26 @@ func (t *DomainTrie) Search(domain string) *Node { n := t.search(t.root, parts) - if n == nil || n.Data == nil { + if n == nil || n.Data == getZero[T]() { return nil } return n } -func (t *DomainTrie) search(node *Node, parts []string) *Node { +func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] { if len(parts) == 0 { return node } if c := node.getChild(parts[len(parts)-1]); c != nil { - if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != nil { + if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() { return n } } if c := node.getChild(wildcard); c != nil { - if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != nil { + if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() { return n } } @@ -124,6 +124,6 @@ func (t *DomainTrie) search(node *Node, parts []string) *Node { } // New returns a new, empty Trie. -func New() *DomainTrie { - return &DomainTrie{root: newNode(nil)} +func New[T comparable]() *DomainTrie[T] { + return &DomainTrie[T]{root: newNode[T](getZero[T]())} } diff --git a/component/trie/domain_test.go b/component/trie/domain_test.go index 4322699a..ced44d03 100644 --- a/component/trie/domain_test.go +++ b/component/trie/domain_test.go @@ -1,16 +1,16 @@ package trie import ( - "net" + "net/netip" "testing" "github.com/stretchr/testify/assert" ) -var localIP = net.IP{127, 0, 0, 1} +var localIP = netip.AddrFrom4([4]byte{127, 0, 0, 1}) func TestTrie_Basic(t *testing.T) { - tree := New() + tree := New[netip.Addr]() domains := []string{ "example.com", "google.com", @@ -23,7 +23,7 @@ func TestTrie_Basic(t *testing.T) { node := tree.Search("example.com") assert.NotNil(t, node) - assert.True(t, node.Data.(net.IP).Equal(localIP)) + assert.True(t, node.Data == localIP) assert.NotNil(t, tree.Insert("", localIP)) assert.Nil(t, tree.Search("")) assert.NotNil(t, tree.Search("localhost")) @@ -31,7 +31,7 @@ func TestTrie_Basic(t *testing.T) { } func TestTrie_Wildcard(t *testing.T) { - tree := New() + tree := New[netip.Addr]() domains := []string{ "*.example.com", "sub.*.example.com", @@ -64,7 +64,7 @@ func TestTrie_Wildcard(t *testing.T) { } func TestTrie_Priority(t *testing.T) { - tree := New() + tree := New[int]() domains := []string{ ".dev", "example.dev", @@ -79,18 +79,18 @@ func TestTrie_Priority(t *testing.T) { } for idx, domain := range domains { - tree.Insert(domain, idx) + tree.Insert(domain, idx+1) } - assertFn("test.dev", 0) - assertFn("foo.bar.dev", 0) - assertFn("example.dev", 1) - assertFn("foo.example.dev", 2) - assertFn("test.example.dev", 3) + assertFn("test.dev", 1) + assertFn("foo.bar.dev", 1) + assertFn("example.dev", 2) + assertFn("foo.example.dev", 3) + assertFn("test.example.dev", 4) } func TestTrie_Boundary(t *testing.T) { - tree := New() + tree := New[netip.Addr]() tree.Insert("*.dev", localIP) assert.NotNil(t, tree.Insert(".", localIP)) @@ -99,7 +99,7 @@ func TestTrie_Boundary(t *testing.T) { } func TestTrie_WildcardBoundary(t *testing.T) { - tree := New() + tree := New[netip.Addr]() tree.Insert("+.*", localIP) tree.Insert("stun.*.*.*", localIP) diff --git a/component/trie/node.go b/component/trie/node.go index 67ef64a4..1545d880 100644 --- a/component/trie/node.go +++ b/component/trie/node.go @@ -1,26 +1,31 @@ package trie // Node is the trie's node -type Node struct { - children map[string]*Node - Data any +type Node[T comparable] struct { + children map[string]*Node[T] + Data T } -func (n *Node) getChild(s string) *Node { +func (n *Node[T]) getChild(s string) *Node[T] { return n.children[s] } -func (n *Node) hasChild(s string) bool { +func (n *Node[T]) hasChild(s string) bool { return n.getChild(s) != nil } -func (n *Node) addChild(s string, child *Node) { +func (n *Node[T]) addChild(s string, child *Node[T]) { n.children[s] = child } -func newNode(data any) *Node { - return &Node{ +func newNode[T comparable](data T) *Node[T] { + return &Node[T]{ Data: data, - children: map[string]*Node{}, + children: map[string]*Node[T]{}, } } + +func getZero[T comparable]() T { + var result T + return result +} diff --git a/dns/filters.go b/dns/filters.go index 61b1917b..6f316198 100644 --- a/dns/filters.go +++ b/dns/filters.go @@ -35,13 +35,13 @@ type fallbackDomainFilter interface { } type domainFilter struct { - tree *trie.DomainTrie + tree *trie.DomainTrie[bool] } func NewDomainFilter(domains []string) *domainFilter { - df := domainFilter{tree: trie.New()} + df := domainFilter{tree: trie.New[bool]()} for _, domain := range domains { - df.tree.Insert(domain, "") + df.tree.Insert(domain, true) } return &df } diff --git a/dns/middleware.go b/dns/middleware.go index 5958fe93..7259df66 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -2,6 +2,7 @@ package dns import ( "net" + "net/netip" "strings" "time" @@ -20,7 +21,7 @@ type ( middleware func(next handler) handler ) -func withHosts(hosts *trie.DomainTrie) middleware { +func withHosts(hosts *trie.DomainTrie[netip.Addr]) middleware { return func(next handler) handler { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] @@ -34,19 +35,19 @@ func withHosts(hosts *trie.DomainTrie) middleware { return next(ctx, r) } - ip := record.Data.(net.IP) + ip := record.Data msg := r.Copy() - if v4 := ip.To4(); v4 != nil && q.Qtype == D.TypeA { + if ip.Is4() && q.Qtype == D.TypeA { rr := &D.A{} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} - rr.A = v4 + rr.A = ip.AsSlice() msg.Answer = []D.RR{rr} - } else if v6 := ip.To16(); v6 != nil && q.Qtype == D.TypeAAAA { + } else if ip.Is6() && q.Qtype == D.TypeAAAA { rr := &D.AAAA{} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL} - rr.AAAA = v6 + rr.AAAA = ip.AsSlice() msg.Answer = []D.RR{rr} } else { diff --git a/dns/policy.go b/dns/policy.go new file mode 100644 index 00000000..a8b423e1 --- /dev/null +++ b/dns/policy.go @@ -0,0 +1,30 @@ +package dns + +type Policy struct { + data []dnsClient +} + +func (p *Policy) GetData() []dnsClient { + return p.data +} + +func (p *Policy) Compare(p2 *Policy) int { + if p2 == nil { + return 1 + } + l1 := len(p.data) + l2 := len(p2.data) + if l1 == l2 { + return 0 + } + if l1 > l2 { + return 1 + } + return -1 +} + +func NewPolicy(data []dnsClient) *Policy { + return &Policy{ + data: data, + } +} diff --git a/dns/resolver.go b/dns/resolver.go index 9968eaf0..afa0d99a 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "net" + "net/netip" "strings" "time" @@ -33,14 +34,14 @@ type result struct { type Resolver struct { ipv6 bool - hosts *trie.DomainTrie + hosts *trie.DomainTrie[netip.Addr] main []dnsClient fallback []dnsClient fallbackDomainFilters []fallbackDomainFilter fallbackIPFilters []fallbackIPFilter group singleflight.Group lruCache *cache.LruCache[string, *D.Msg] - policy *trie.DomainTrie + policy *trie.DomainTrie[*Policy] proxyServer []dnsClient } @@ -194,7 +195,8 @@ func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { return nil } - return record.Data.([]dnsClient) + p := record.Data + return p.GetData() } func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { @@ -330,7 +332,7 @@ type Config struct { EnhancedMode C.DNSMode FallbackFilter FallbackFilter Pool *fakeip.Pool - Hosts *trie.DomainTrie + Hosts *trie.DomainTrie[netip.Addr] Policy map[string]NameServer } @@ -356,9 +358,9 @@ func NewResolver(config Config) *Resolver { } if len(config.Policy) != 0 { - r.policy = trie.New() + r.policy = trie.New[*Policy]() for domain, nameserver := range config.Policy { - r.policy.Insert(domain, transform([]NameServer{nameserver}, defaultResolver)) + r.policy.Insert(domain, NewPolicy(transform([]NameServer{nameserver}, defaultResolver))) } } From f036e06f6f70c790f7d1dc51e3949d13e98efa5a Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Sun, 10 Apr 2022 03:59:27 +0800 Subject: [PATCH 14/14] Feature: MITM rewrite --- README.md | 40 ++- adapter/inbound/mitm.go | 22 ++ adapter/outbound/http.go | 4 + adapter/outbound/mitm.go | 68 ++++ common/cert/cert.go | 282 ++++++++++++++++ common/cert/cert_test.go | 76 +++++ common/cert/storage.go | 32 ++ component/geodata/memconservative/cache.go | 4 +- component/geodata/standard/standard.go | 2 +- config/config.go | 89 ++++- constant/adapters.go | 3 + constant/metadata.go | 4 + constant/path.go | 8 +- constant/rewrite.go | 82 +++++ constant/rule.go | 3 + dns/middleware.go | 15 +- go.mod | 10 +- go.sum | 16 +- hub/executor/executor.go | 18 +- hub/route/configs.go | 2 + listener/http/proxy.go | 20 +- listener/http/utils.go | 8 +- listener/listener.go | 109 +++++++ listener/mitm/client.go | 54 ++++ listener/mitm/proxy.go | 357 +++++++++++++++++++++ listener/mitm/server.go | 90 ++++++ listener/mitm/session.go | 56 ++++ listener/mitm/utils.go | 100 ++++++ rewrite/base.go | 72 +++++ rewrite/handler.go | 202 ++++++++++++ rewrite/parser.go | 78 +++++ rewrite/parser_test.go | 56 ++++ rewrite/rewrite.go | 89 +++++ rewrite/util.go | 28 ++ rule/parser.go | 2 + rule/user_gent.go | 52 +++ test/go.mod | 8 +- test/go.sum | 16 +- tunnel/statistic/tracker.go | 3 +- tunnel/tunnel.go | 35 +- 40 files changed, 2144 insertions(+), 71 deletions(-) create mode 100644 adapter/inbound/mitm.go create mode 100644 adapter/outbound/mitm.go create mode 100644 common/cert/cert.go create mode 100644 common/cert/cert_test.go create mode 100644 common/cert/storage.go create mode 100644 constant/rewrite.go create mode 100644 listener/mitm/client.go create mode 100644 listener/mitm/proxy.go create mode 100644 listener/mitm/server.go create mode 100644 listener/mitm/session.go create mode 100644 listener/mitm/utils.go create mode 100644 rewrite/base.go create mode 100644 rewrite/handler.go create mode 100644 rewrite/parser.go create mode 100644 rewrite/parser_test.go create mode 100644 rewrite/rewrite.go create mode 100644 rewrite/util.go create mode 100644 rule/user_gent.go diff --git a/README.md b/README.md index 2eca8bbe..be6503cb 100644 --- a/README.md +++ b/README.md @@ -36,12 +36,44 @@ Documentations are now moved to [GitHub Wiki](https://github.com/Dreamacro/clash/wiki). ## Advanced usage for this branch +### MITM configuration +A root CA certificate is required, the +MITM proxy server will generate a CA certificate file and a CA private key file in your Clash home directory, you can use your own certificate replace it. + +Need to install and trust the CA certificate on the client device, open this URL http://mitm.clash/cert.crt by the web browser to install the CA certificate, the host name 'mitm.clash' was always been hijacked. + +NOTE: this feature cannot work on tls pinning + +WARNING: DO NOT USE THIS FEATURE TO BREAK LOCAL LAWS + +```yaml +# Port of MITM proxy server on the local end +mitm-port: 7894 + +# Man-In-The-Middle attack +mitm: + hosts: # use for others proxy type. E.g: TUN, socks + - +.example.com + rules: # rewrite rules + - '^https?://www\.example\.com/1 url reject' # The "reject" returns HTTP status code 404 with no content. + - '^https?://www\.example\.com/2 url reject-200' # The "reject-200" returns HTTP status code 200 with no content. + - '^https?://www\.example\.com/3 url reject-img' # The "reject-img" returns HTTP status code 200 with content of 1px png. + - '^https?://www\.example\.com/4 url reject-dict' # The "reject-dict" returns HTTP status code 200 with content of empty json object. + - '^https?://www\.example\.com/5 url reject-array' # The "reject-array" returns HTTP status code 200 with content of empty json array. + - '^https?://www\.example\.com/(6) url 302 https://www.example.com/new-$1' + - '^https?://www\.(example)\.com/7 url 307 https://www.$1.com/new-7' + - '^https?://www\.example\.com/8 url request-header (\r\n)User-Agent:.+(\r\n) request-header $1User-Agent: haha-wriohoh$2' # The "request-header" works for all the http headers not just one single header, so you can match two or more headers including CRLF in one regular expression. + - '^https?://www\.example\.com/9 url request-body "pos_2":\[.*\],"pos_3" request-body "pos_2":[{"xx": "xx"}],"pos_3"' + - '^https?://www\.example\.com/10 url response-header (\r\n)Tracecode:.+(\r\n) response-header $1Tracecode: 88888888888$2' + - '^https?://www\.example\.com/11 url response-body "errmsg":"ok" response-body "errmsg":"not-ok"' +``` + ### DNS configuration Support resolve ip with a proxy tunnel. Support `geosite` with `fallback-filter`. -Use curl -X POST controllerip:port/cache/fakeip/flush to flush persistence fakeip +Use `curl -X POST controllerip:port/cache/fakeip/flush` to flush persistence fakeip ```yaml dns: enable: true @@ -85,6 +117,7 @@ tun: ``` ### Rules configuration - Support rule `GEOSITE`. +- Support rule `USER-AGENT`. - Support `multiport` condition for rule `SRC-PORT` and `DST-PORT`. - Support `network` condition for all rules. - Support `process` condition for all rules. @@ -104,7 +137,10 @@ rules: # multiport condition for rules SRC-PORT and DST-PORT - DST-PORT,123/136/137-139,DIRECT,udp - + + # USER-AGENT payload cannot include the comma character, '*' meaning any character. + - USER-AGENT,*example*,PROXY + # rule GEOSITE - GEOSITE,category-ads-all,REJECT - GEOSITE,icloud@cn,DIRECT diff --git a/adapter/inbound/mitm.go b/adapter/inbound/mitm.go new file mode 100644 index 00000000..db3645ab --- /dev/null +++ b/adapter/inbound/mitm.go @@ -0,0 +1,22 @@ +package inbound + +import ( + "net" + + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/context" + "github.com/Dreamacro/clash/transport/socks5" +) + +// NewMitm receive mitm request and return MitmContext +func NewMitm(target socks5.Addr, source net.Addr, userAgent string, conn net.Conn) *context.ConnContext { + metadata := parseSocksAddr(target) + metadata.NetWork = C.TCP + metadata.Type = C.MITM + metadata.UserAgent = userAgent + if ip, port, err := parseAddr(source.String()); err == nil { + metadata.SrcIP = ip + metadata.SrcPort = port + } + return context.NewConnContext(conn, metadata) +} diff --git a/adapter/outbound/http.go b/adapter/outbound/http.go index 44dc705a..ff87af6f 100644 --- a/adapter/outbound/http.go +++ b/adapter/outbound/http.go @@ -89,6 +89,10 @@ func (h *Http) shakeHand(metadata *C.Metadata, rw io.ReadWriter) error { req.Header.Add("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) } + if metadata.Type == C.MITM { + req.Header.Add("Origin-Request-Source-Address", metadata.SourceAddress()) + } + if err := req.Write(rw); err != nil { return err } diff --git a/adapter/outbound/mitm.go b/adapter/outbound/mitm.go new file mode 100644 index 00000000..80577cd9 --- /dev/null +++ b/adapter/outbound/mitm.go @@ -0,0 +1,68 @@ +package outbound + +import ( + "context" + "errors" + + "github.com/Dreamacro/clash/component/dialer" + "github.com/Dreamacro/clash/component/trie" + C "github.com/Dreamacro/clash/constant" + + "go.uber.org/atomic" +) + +var ( + errIgnored = errors.New("not match in mitm host lists") + httpProxyClient = NewHttp(HttpOption{}) + + MiddlemanServerAddress = atomic.NewString("") + MiddlemanRewriteHosts *trie.DomainTrie[bool] +) + +type Mitm struct { + *Base +} + +// DialContext implements C.ProxyAdapter +func (d *Mitm) DialContext(ctx context.Context, metadata *C.Metadata, _ ...dialer.Option) (C.Conn, error) { + addr := MiddlemanServerAddress.Load() + if addr == "" || MiddlemanRewriteHosts == nil { + return nil, errIgnored + } + + if MiddlemanRewriteHosts.Search(metadata.String()) == nil { + return nil, errIgnored + } + + metadata.Type = C.MITM + + if metadata.Host != "" { + metadata.AddrType = C.AtypDomainName + metadata.DstIP = nil + } + + c, err := dialer.DialContext(ctx, "tcp", addr, []dialer.Option{dialer.WithInterface(""), dialer.WithRoutingMark(0)}...) + if err != nil { + return nil, err + } + + tcpKeepAlive(c) + + defer safeConnClose(c, err) + + c, err = httpProxyClient.StreamConn(c, metadata) + if err != nil { + return nil, err + } + + return NewConn(c, d), nil +} + +func NewMitm() *Mitm { + return &Mitm{ + Base: &Base{ + name: "Mitm", + tp: C.Mitm, + }, + } +} diff --git a/common/cert/cert.go b/common/cert/cert.go new file mode 100644 index 00000000..3c931665 --- /dev/null +++ b/common/cert/cert.go @@ -0,0 +1,282 @@ +package cert + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "os" + "sync/atomic" + "time" +) + +var currentSerialNumber = time.Now().Unix() + +type Config struct { + ca *x509.Certificate + caPrivateKey *rsa.PrivateKey + + roots *x509.CertPool + + privateKey *rsa.PrivateKey + + validity time.Duration + keyID []byte + organization string + + certsStorage CertsStorage +} + +type CertsStorage interface { + Get(key string) (*tls.Certificate, bool) + + Set(key string, cert *tls.Certificate) +} + +type CertsCache struct { + certsCache map[string]*tls.Certificate +} + +func (c *CertsCache) Get(key string) (*tls.Certificate, bool) { + v, ok := c.certsCache[key] + return v, ok +} + +func (c *CertsCache) Set(key string, cert *tls.Certificate) { + c.certsCache[key] = cert +} + +func NewAuthority(name, organization string, validity time.Duration) (*x509.Certificate, *rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + pub := privateKey.Public() + + pkixPub, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + return nil, nil, err + } + h := sha1.New() + _, err = h.Write(pkixPub) + if err != nil { + return nil, nil, err + } + keyID := h.Sum(nil) + + serial := atomic.AddInt64(¤tSerialNumber, 1) + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(serial), + Subject: pkix.Name{ + CommonName: name, + Organization: []string{organization}, + }, + SubjectKeyId: keyID, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + NotBefore: time.Now().Add(-validity), + NotAfter: time.Now().Add(validity), + DNSNames: []string{name}, + IsCA: true, + } + + raw, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, privateKey) + if err != nil { + return nil, nil, err + } + + x509c, err := x509.ParseCertificate(raw) + if err != nil { + return nil, nil, err + } + + return x509c, privateKey, nil +} + +func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey, storage CertsStorage) (*Config, error) { + roots := x509.NewCertPool() + roots.AddCert(ca) + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + pub := privateKey.Public() + + pkixPub, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + return nil, err + } + h := sha1.New() + _, err = h.Write(pkixPub) + if err != nil { + return nil, err + } + keyID := h.Sum(nil) + + if storage == nil { + storage = &CertsCache{certsCache: make(map[string]*tls.Certificate)} + } + + return &Config{ + ca: ca, + caPrivateKey: caPrivateKey, + privateKey: privateKey, + keyID: keyID, + validity: time.Hour, + organization: "Clash", + certsStorage: storage, + roots: roots, + }, nil +} + +func (c *Config) GetCA() *x509.Certificate { + return c.ca +} + +func (c *Config) SetOrganization(organization string) { + c.organization = organization +} + +func (c *Config) SetValidity(validity time.Duration) { + c.validity = validity +} + +func (c *Config) NewTLSConfigForHost(hostname string) *tls.Config { + tlsConfig := &tls.Config{ + GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + host := clientHello.ServerName + if host == "" { + host = hostname + } + + return c.GetOrCreateCert(host) + }, + NextProtos: []string{"http/1.1"}, + } + + tlsConfig.InsecureSkipVerify = true + + return tlsConfig +} + +func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certificate, error) { + host, _, err := net.SplitHostPort(hostname) + if err == nil { + hostname = host + } + + tlsCertificate, ok := c.certsStorage.Get(hostname) + if ok { + if _, err = tlsCertificate.Leaf.Verify(x509.VerifyOptions{ + DNSName: hostname, + Roots: c.roots, + }); err == nil { + return tlsCertificate, nil + } + } + + serial := atomic.AddInt64(¤tSerialNumber, 1) + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(serial), + Subject: pkix.Name{ + CommonName: hostname, + Organization: []string{c.organization}, + }, + SubjectKeyId: c.keyID, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + NotBefore: time.Now().Add(-c.validity), + NotAfter: time.Now().Add(c.validity), + } + + if ip := net.ParseIP(hostname); ip != nil { + ips = append(ips, ip) + } else { + tmpl.DNSNames = []string{hostname} + } + + tmpl.IPAddresses = ips + + raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.privateKey.Public(), c.caPrivateKey) + if err != nil { + return nil, err + } + + x509c, err := x509.ParseCertificate(raw) + if err != nil { + return nil, err + } + + tlsCertificate = &tls.Certificate{ + Certificate: [][]byte{raw, c.ca.Raw}, + PrivateKey: c.privateKey, + Leaf: x509c, + } + + c.certsStorage.Set(hostname, tlsCertificate) + return tlsCertificate, nil +} + +// GenerateAndSave generate CA private key and CA certificate and dump them to file +func GenerateAndSave(caPath string, caKeyPath string) error { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().Unix()), + Subject: pkix.Name{ + Country: []string{"US"}, + CommonName: "Clash Root CA", + Organization: []string{"Clash Trust Services"}, + }, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + NotBefore: time.Now().Add(-(time.Hour * 24 * 60)), + NotAfter: time.Now().Add(time.Hour * 24 * 365 * 25), + BasicConstraintsValid: true, + IsCA: true, + } + + caRaw, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, privateKey.Public(), privateKey) + if err != nil { + return err + } + + caOut, err := os.OpenFile(caPath, os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + return err + } + defer func(caOut *os.File) { + _ = caOut.Close() + }(caOut) + + if err = pem.Encode(caOut, &pem.Block{Type: "CERTIFICATE", Bytes: caRaw}); err != nil { + return err + } + + caKeyOut, err := os.OpenFile(caKeyPath, os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + return err + } + defer func(caKeyOut *os.File) { + _ = caKeyOut.Close() + }(caKeyOut) + + if err = pem.Encode(caKeyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}); err != nil { + return err + } + + return nil +} diff --git a/common/cert/cert_test.go b/common/cert/cert_test.go new file mode 100644 index 00000000..42265613 --- /dev/null +++ b/common/cert/cert_test.go @@ -0,0 +1,76 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "net" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCert(t *testing.T) { + ca, privateKey, err := NewAuthority("Clash ca", "Clash", 24*time.Hour) + + assert.Nil(t, err) + assert.NotNil(t, ca) + assert.NotNil(t, privateKey) + + c, err := NewConfig(ca, privateKey, nil) + assert.Nil(t, err) + + c.SetValidity(20 * time.Hour) + c.SetOrganization("Test Organization") + + conf := c.NewTLSConfigForHost("example.org") + assert.Equal(t, []string{"http/1.1"}, conf.NextProtos) + assert.True(t, conf.InsecureSkipVerify) + + // Test generating a certificate + clientHello := &tls.ClientHelloInfo{ + ServerName: "example.org", + } + tlsCert, err := conf.GetCertificate(clientHello) + assert.Nil(t, err) + assert.NotNil(t, tlsCert) + + // Assert certificate details + x509c := tlsCert.Leaf + assert.Equal(t, "example.org", x509c.Subject.CommonName) + assert.Nil(t, x509c.VerifyHostname("example.org")) + assert.Equal(t, []string{"Test Organization"}, x509c.Subject.Organization) + assert.NotNil(t, x509c.SubjectKeyId) + assert.True(t, x509c.BasicConstraintsValid) + assert.True(t, x509c.KeyUsage&x509.KeyUsageKeyEncipherment == x509.KeyUsageKeyEncipherment) + assert.True(t, x509c.KeyUsage&x509.KeyUsageDigitalSignature == x509.KeyUsageDigitalSignature) + assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, x509c.ExtKeyUsage) + assert.Equal(t, []string{"example.org"}, x509c.DNSNames) + assert.True(t, x509c.NotBefore.Before(time.Now().Add(-2*time.Hour))) + assert.True(t, x509c.NotAfter.After(time.Now().Add(2*time.Hour))) + + // Check that certificate is cached + tlsCert2, err := c.GetOrCreateCert("example.org") + assert.Nil(t, err) + assert.True(t, tlsCert == tlsCert2) + + // Check the certificate for an IP + tlsCertForIP, err := c.GetOrCreateCert("192.168.0.1:443") + assert.Nil(t, err) + x509c = tlsCertForIP.Leaf + assert.Equal(t, 1, len(x509c.IPAddresses)) + assert.True(t, net.ParseIP("192.168.0.1").Equal(x509c.IPAddresses[0])) +} + +func TestGenerateAndSave(t *testing.T) { + caPath := "ca.crt" + caKeyPath := "ca.key" + + err := GenerateAndSave(caPath, caKeyPath) + + assert.Nil(t, err) + + _ = os.Remove(caPath) + _ = os.Remove(caKeyPath) +} diff --git a/common/cert/storage.go b/common/cert/storage.go new file mode 100644 index 00000000..61663e73 --- /dev/null +++ b/common/cert/storage.go @@ -0,0 +1,32 @@ +package cert + +import ( + "crypto/tls" + "time" + + "github.com/Dreamacro/clash/common/cache" +) + +var TTL = time.Hour * 2 + +// AutoGCCertsStorage cache with the generated certificates, auto released after TTL +type AutoGCCertsStorage struct { + certsCache *cache.Cache[string, *tls.Certificate] +} + +// Get gets the certificate from the storage +func (c *AutoGCCertsStorage) Get(key string) (*tls.Certificate, bool) { + ca := c.certsCache.Get(key) + return ca, ca != nil +} + +// Set saves the certificate to the storage +func (c *AutoGCCertsStorage) Set(key string, cert *tls.Certificate) { + c.certsCache.Put(key, cert, TTL) +} + +func NewAutoGCCertsStorage() *AutoGCCertsStorage { + return &AutoGCCertsStorage{ + certsCache: cache.New[string, *tls.Certificate](TTL), + } +} diff --git a/component/geodata/memconservative/cache.go b/component/geodata/memconservative/cache.go index 2981e5c0..3a94d352 100644 --- a/component/geodata/memconservative/cache.go +++ b/component/geodata/memconservative/cache.go @@ -33,7 +33,7 @@ func (g GeoIPCache) Set(key string, value *router.GeoIP) { } func (g GeoIPCache) Unmarshal(filename, code string) (*router.GeoIP, error) { - asset := C.Path.GetAssetLocation(filename) + asset := C.Path.Resolve(filename) idx := strings.ToLower(asset + ":" + code) if g.Has(idx) { return g.Get(idx), nil @@ -98,7 +98,7 @@ func (g GeoSiteCache) Set(key string, value *router.GeoSite) { } func (g GeoSiteCache) Unmarshal(filename, code string) (*router.GeoSite, error) { - asset := C.Path.GetAssetLocation(filename) + asset := C.Path.Resolve(filename) idx := strings.ToLower(asset + ":" + code) if g.Has(idx) { return g.Get(idx), nil diff --git a/component/geodata/standard/standard.go b/component/geodata/standard/standard.go index 0febbc08..86a5791d 100644 --- a/component/geodata/standard/standard.go +++ b/component/geodata/standard/standard.go @@ -26,7 +26,7 @@ func ReadFile(path string) ([]byte, error) { } func ReadAsset(file string) ([]byte, error) { - return ReadFile(C.Path.GetAssetLocation(file)) + return ReadFile(C.Path.Resolve(file)) } func loadIP(filename, country string) ([]*router.CIDR, error) { diff --git a/config/config.go b/config/config.go index eedc1959..e86a5e42 100644 --- a/config/config.go +++ b/config/config.go @@ -25,6 +25,7 @@ import ( "github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/listener/tun/ipstack/commons" "github.com/Dreamacro/clash/log" + rewrites "github.com/Dreamacro/clash/rewrite" R "github.com/Dreamacro/clash/rule" T "github.com/Dreamacro/clash/tunnel" @@ -49,6 +50,7 @@ type Inbound struct { RedirPort int `json:"redir-port"` TProxyPort int `json:"tproxy-port"` MixedPort int `json:"mixed-port"` + MitmPort int `json:"mitm-port"` Authentication []string `json:"authentication"` AllowLan bool `json:"allow-lan"` BindAddress string `json:"bind-address"` @@ -72,7 +74,7 @@ type DNS struct { EnhancedMode C.DNSMode `yaml:"enhanced-mode"` DefaultNameserver []dns.NameServer `yaml:"default-nameserver"` FakeIPRange *fakeip.Pool - Hosts *trie.DomainTrie + Hosts *trie.DomainTrie[netip.Addr] NameServerPolicy map[string]dns.NameServer ProxyServerNameserver []dns.NameServer } @@ -107,6 +109,12 @@ type IPTables struct { InboundInterface string `yaml:"inbound-interface" json:"inbound-interface"` } +// Mitm config +type Mitm struct { + Hosts *trie.DomainTrie[bool] `yaml:"hosts" json:"hosts"` + Rules C.RewriteRule `yaml:"rules" json:"rules"` +} + // Experimental config type Experimental struct{} @@ -115,9 +123,10 @@ type Config struct { General *General Tun *Tun IPTables *IPTables + Mitm *Mitm DNS *DNS Experimental *Experimental - Hosts *trie.DomainTrie + Hosts *trie.DomainTrie[netip.Addr] Profile *Profile Rules []C.Rule Users []auth.AuthUser @@ -157,12 +166,18 @@ type RawTun struct { AutoRoute bool `yaml:"auto-route" json:"auto-route"` } +type RawMitm struct { + Hosts []string `yaml:"hosts" json:"hosts"` + Rules []string `yaml:"rules" json:"rules"` +} + type RawConfig struct { Port int `yaml:"port"` SocksPort int `yaml:"socks-port"` RedirPort int `yaml:"redir-port"` TProxyPort int `yaml:"tproxy-port"` MixedPort int `yaml:"mixed-port"` + MitmPort int `yaml:"mitm-port"` Authentication []string `yaml:"authentication"` AllowLan bool `yaml:"allow-lan"` BindAddress string `yaml:"bind-address"` @@ -180,6 +195,7 @@ type RawConfig struct { DNS RawDNS `yaml:"dns"` Tun RawTun `yaml:"tun"` IPTables IPTables `yaml:"iptables"` + MITM RawMitm `yaml:"mitm"` Experimental Experimental `yaml:"experimental"` Profile Profile `yaml:"profile"` Proxy []map[string]any `yaml:"proxies"` @@ -240,6 +256,10 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) { "tls://223.5.5.5:853", }, }, + MITM: RawMitm{ + Hosts: []string{}, + Rules: []string{}, + }, Profile: Profile{ StoreSelected: true, }, @@ -298,6 +318,12 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) { } config.DNS = dnsCfg + mitm, err := parseMitm(rawCfg.MITM) + if err != nil { + return nil, err + } + config.Mitm = mitm + config.Users = parseAuthentication(rawCfg.Authentication) return config, nil @@ -322,6 +348,7 @@ func parseGeneral(cfg *RawConfig) (*General, error) { RedirPort: cfg.RedirPort, TProxyPort: cfg.TProxyPort, MixedPort: cfg.MixedPort, + MitmPort: cfg.MitmPort, AllowLan: cfg.AllowLan, BindAddress: cfg.BindAddress, }, @@ -501,24 +528,29 @@ func parseRules(cfg *RawConfig, proxies map[string]C.Proxy) ([]C.Rule, error) { return rules, nil } -func parseHosts(cfg *RawConfig) (*trie.DomainTrie, error) { - tree := trie.New() +func parseHosts(cfg *RawConfig) (*trie.DomainTrie[netip.Addr], error) { + tree := trie.New[netip.Addr]() // add default hosts - if err := tree.Insert("localhost", net.IP{127, 0, 0, 1}); err != nil { + if err := tree.Insert("localhost", netip.AddrFrom4([4]byte{127, 0, 0, 1})); err != nil { log.Errorln("insert localhost to host error: %s", err.Error()) } if len(cfg.Hosts) != 0 { for domain, ipStr := range cfg.Hosts { - ip := net.ParseIP(ipStr) - if ip == nil { + ip, err := netip.ParseAddr(ipStr) + if err != nil { return nil, fmt.Errorf("%s is not a valid IP", ipStr) } _ = tree.Insert(domain, ip) } } + // add mitm.clash hosts + if err := tree.Insert("mitm.clash", netip.AddrFrom4([4]byte{8, 8, 9, 9})); err != nil { + log.Errorln("insert mitm.clash to host error: %s", err.Error()) + } + return tree, nil } @@ -652,7 +684,7 @@ func parseFallbackGeoSite(countries []string, rules []C.Rule) ([]*router.DomainM return sites, nil } -func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS, error) { +func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[netip.Addr], rules []C.Rule) (*DNS, error) { cfg := rawCfg.DNS if cfg.Enable && len(cfg.NameServer) == 0 { return nil, fmt.Errorf("if DNS configuration is turned on, NameServer cannot be empty") @@ -705,10 +737,10 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS, return nil, err } - var host *trie.DomainTrie + var host *trie.DomainTrie[bool] // fake ip skip host filter if len(cfg.FakeIPFilter) != 0 { - host = trie.New() + host = trie.New[bool]() for _, domain := range cfg.FakeIPFilter { _ = host.Insert(domain, true) } @@ -716,7 +748,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS, if len(dnsCfg.Fallback) != 0 { if host == nil { - host = trie.New() + host = trie.New[bool]() } for _, fb := range dnsCfg.Fallback { if net.ParseIP(fb.Addr) != nil { @@ -803,3 +835,38 @@ func parseTun(rawTun RawTun, general *General) (*Tun, error) { AutoRoute: rawTun.AutoRoute, }, nil } + +func parseMitm(rawMitm RawMitm) (*Mitm, error) { + var ( + req []C.Rewrite + res []C.Rewrite + ) + + for _, line := range rawMitm.Rules { + rule, err := rewrites.ParseRewrite(line) + if err != nil { + return nil, fmt.Errorf("parse rewrite rule failure: %w", err) + } + + if rule.RuleType() == C.MitmResponseHeader || rule.RuleType() == C.MitmResponseBody { + res = append(res, rule) + } else { + req = append(req, rule) + } + } + + hosts := trie.New[bool]() + + if len(rawMitm.Hosts) != 0 { + for _, domain := range rawMitm.Hosts { + _ = hosts.Insert(domain, true) + } + } + + _ = hosts.Insert("mitm.clash", true) + + return &Mitm{ + Hosts: hosts, + Rules: rewrites.NewRewriteRules(req, res), + }, nil +} diff --git a/constant/adapters.go b/constant/adapters.go index 2898d9c7..40849422 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -13,6 +13,7 @@ import ( const ( Direct AdapterType = iota Reject + Mitm Shadowsocks ShadowsocksR @@ -129,6 +130,8 @@ func (at AdapterType) String() string { return "Direct" case Reject: return "Reject" + case Mitm: + return "Mitm" case Shadowsocks: return "Shadowsocks" diff --git a/constant/metadata.go b/constant/metadata.go index 3da67201..70ed909b 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -23,6 +23,7 @@ const ( REDIR TPROXY TUN + MITM ) type NetWork int @@ -58,6 +59,8 @@ func (t Type) String() string { return "TProxy" case TUN: return "Tun" + case MITM: + return "Mitm" default: return "Unknown" } @@ -80,6 +83,7 @@ type Metadata struct { DNSMode DNSMode `json:"dnsMode"` Process string `json:"process"` ProcessPath string `json:"processPath"` + UserAgent string `json:"userAgent"` } func (m *Metadata) RemoteAddress() string { diff --git a/constant/path.go b/constant/path.go index 4580b25c..b3167edc 100644 --- a/constant/path.go +++ b/constant/path.go @@ -71,6 +71,10 @@ func (p *path) GeoSite() string { return P.Join(p.homeDir, "geosite.dat") } -func (p *path) GetAssetLocation(file string) string { - return P.Join(p.homeDir, file) +func (p *path) RootCA() string { + return p.Resolve("mitm_ca.crt") +} + +func (p *path) CAKey() string { + return p.Resolve("mitm_ca.key") } diff --git a/constant/rewrite.go b/constant/rewrite.go new file mode 100644 index 00000000..06adde35 --- /dev/null +++ b/constant/rewrite.go @@ -0,0 +1,82 @@ +package constant + +import ( + "regexp" +) + +var RewriteTypeMapping = map[string]RewriteType{ + MitmReject.String(): MitmReject, + MitmReject200.String(): MitmReject200, + MitmRejectImg.String(): MitmRejectImg, + MitmRejectDict.String(): MitmRejectDict, + MitmRejectArray.String(): MitmRejectArray, + Mitm302.String(): Mitm302, + Mitm307.String(): Mitm307, + MitmRequestHeader.String(): MitmRequestHeader, + MitmRequestBody.String(): MitmRequestBody, + MitmResponseHeader.String(): MitmResponseHeader, + MitmResponseBody.String(): MitmResponseBody, +} + +const ( + MitmReject RewriteType = iota + MitmReject200 + MitmRejectImg + MitmRejectDict + MitmRejectArray + + Mitm302 + Mitm307 + + MitmRequestHeader + MitmRequestBody + + MitmResponseHeader + MitmResponseBody +) + +type RewriteType int + +func (rt RewriteType) String() string { + switch rt { + case MitmReject: + return "reject" // 404 + case MitmReject200: + return "reject-200" + case MitmRejectImg: + return "reject-img" + case MitmRejectDict: + return "reject-dict" + case MitmRejectArray: + return "reject-array" + case Mitm302: + return "302" + case Mitm307: + return "307" + case MitmRequestHeader: + return "request-header" + case MitmRequestBody: + return "request-body" + case MitmResponseHeader: + return "response-header" + case MitmResponseBody: + return "response-body" + default: + return "Unknown" + } +} + +type Rewrite interface { + ID() string + URLRegx() *regexp.Regexp + RuleType() RewriteType + RuleRegx() *regexp.Regexp + RulePayload() string + ReplaceURLPayload([]string) string + ReplaceSubPayload(string) string +} + +type RewriteRule interface { + SearchInRequest(func(Rewrite) bool) bool + SearchInResponse(func(Rewrite) bool) bool +} diff --git a/constant/rule.go b/constant/rule.go index 23f421aa..d59658c9 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -13,6 +13,7 @@ const ( DstPort Process ProcessPath + UserAgent MATCH ) @@ -42,6 +43,8 @@ func (rt RuleType) String() string { return "Process" case ProcessPath: return "ProcessPath" + case UserAgent: + return "UserAgent" case MATCH: return "Match" default: diff --git a/dns/middleware.go b/dns/middleware.go index 7259df66..4091fa9e 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -21,7 +21,7 @@ type ( middleware func(next handler) handler ) -func withHosts(hosts *trie.DomainTrie[netip.Addr]) middleware { +func withHosts(hosts *trie.DomainTrie[netip.Addr], mapping *cache.LruCache[string, string]) middleware { return func(next handler) handler { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] @@ -30,23 +30,28 @@ func withHosts(hosts *trie.DomainTrie[netip.Addr]) middleware { return next(ctx, r) } - record := hosts.Search(strings.TrimRight(q.Name, ".")) + qName := strings.TrimRight(q.Name, ".") + record := hosts.Search(qName) if record == nil { return next(ctx, r) } ip := record.Data + if mapping != nil { + mapping.SetWithExpire(ip.Unmap().String(), qName, time.Now().Add(time.Second*5)) + } + msg := r.Copy() if ip.Is4() && q.Qtype == D.TypeA { rr := &D.A{} - rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 1} rr.A = ip.AsSlice() msg.Answer = []D.RR{rr} } else if ip.Is6() && q.Qtype == D.TypeAAAA { rr := &D.AAAA{} - rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL} + rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 1} rr.AAAA = ip.AsSlice() msg.Answer = []D.RR{rr} @@ -177,7 +182,7 @@ func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler { middlewares := []middleware{} if resolver.hosts != nil { - middlewares = append(middlewares, withHosts(resolver.hosts)) + middlewares = append(middlewares, withHosts(resolver.hosts, mapper.mapping)) } if mapper.mode == C.DNSFakeIP { diff --git a/go.mod b/go.mod index 5f4d8fbd..31d291dc 100644 --- a/go.mod +++ b/go.mod @@ -18,10 +18,11 @@ require ( go.etcd.io/bbolt v1.3.6 go.uber.org/atomic v1.9.0 go.uber.org/automaxprocs v1.4.0 - golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd - golang.org/x/net v0.0.0-20220225172249-27dd8689420f + golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 + golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 + golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f + golang.org/x/text v0.3.8-0.20220124021120-d1c84af989ab golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 golang.zx2c4.com/wireguard v0.0.0-20220318042302-193cf8d6a5d6 golang.zx2c4.com/wireguard/windows v0.5.4-0.20220317000008-6432784c2469 @@ -37,8 +38,7 @@ require ( github.com/oschwald/maxminddb-golang v1.8.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/u-root/uio v0.0.0-20210528114334-82958018845c // indirect - golang.org/x/mod v0.5.1 // indirect - golang.org/x/text v0.3.8-0.20220124021120-d1c84af989ab // indirect + golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect golang.org/x/tools v0.1.9 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect diff --git a/go.sum b/go.sum index 40b7a038..479f39bf 100644 --- a/go.sum +++ b/go.sum @@ -81,11 +81,11 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210317152858-513c2a44f670/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38= -golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 h1:iU7T1X1J6yxDr0rda54sWGkHgOp5XJrqm79gcNlC2VM= +golang.org/x/crypto v0.0.0-20220408190544-5352b0902921/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38= -golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= +golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 h1:LQmS1nU0twXLA96Kt7U9qtHJEbBk3z6Q0V4UXjZkpr4= +golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190419010253-1f3472d942ba/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -99,8 +99,8 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 h1:EN5+DfgmRMvRUrMGERW2gQl3Vc+Z7ZMnI/xdEpPSf0c= +golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -125,8 +125,8 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 h1:A9i04dxx7Cribqbs8jf3FQLogkL/CV2YN7hj9KWJCkc= -golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f h1:8w7RhxzTVgUzw/AH/9mUV5q0vMgy40SQRursCcfmkCw= +golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 887c473a..5166ba04 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -3,12 +3,14 @@ package executor import ( "fmt" "net" + "net/netip" "os" "runtime" "strconv" "sync" "github.com/Dreamacro/clash/adapter" + "github.com/Dreamacro/clash/adapter/outbound" "github.com/Dreamacro/clash/adapter/outboundgroup" "github.com/Dreamacro/clash/component/auth" "github.com/Dreamacro/clash/component/dialer" @@ -71,12 +73,17 @@ func ApplyConfig(cfg *config.Config, force bool) { mux.Lock() defer mux.Unlock() - log.SetLevel(log.DEBUG) + if cfg.General.LogLevel == log.DEBUG { + log.SetLevel(log.DEBUG) + } else { + log.SetLevel(log.INFO) + } updateUsers(cfg.Users) updateProxies(cfg.Proxies, cfg.Providers) updateRules(cfg.Rules) updateHosts(cfg.Hosts) + updateMitm(cfg.Mitm) updateProfile(cfg) updateDNS(cfg.DNS, cfg.Tun) updateGeneral(cfg.General, force) @@ -101,6 +108,7 @@ func GetGeneral() *config.General { RedirPort: ports.RedirPort, TProxyPort: ports.TProxyPort, MixedPort: ports.MixedPort, + MitmPort: ports.MitmPort, Authentication: authenticator, AllowLan: P.AllowLan(), BindAddress: P.BindAddress(), @@ -168,7 +176,7 @@ func updateDNS(c *config.DNS, t *config.Tun) { } } -func updateHosts(tree *trie.DomainTrie) { +func updateHosts(tree *trie.DomainTrie[netip.Addr]) { resolver.DefaultHosts = tree } @@ -225,6 +233,7 @@ func updateGeneral(general *config.General, force bool) { P.ReCreateRedir(general.RedirPort, tcpIn, udpIn) P.ReCreateTProxy(general.TProxyPort, tcpIn, udpIn) P.ReCreateMixed(general.MixedPort, tcpIn, udpIn) + P.ReCreateMitm(general.MitmPort, tcpIn) } func updateUsers(users []auth.AuthUser) { @@ -330,6 +339,11 @@ func updateIPTables(cfg *config.Config) { log.Infoln("[IPTABLES] Setting iptables completed") } +func updateMitm(mitm *config.Mitm) { + outbound.MiddlemanRewriteHosts = mitm.Hosts + tunnel.UpdateRewrites(mitm.Rules) +} + func Shutdown() { P.Cleanup() tproxy.CleanupTProxyIPTables() diff --git a/hub/route/configs.go b/hub/route/configs.go index 3e36c054..a930c32b 100644 --- a/hub/route/configs.go +++ b/hub/route/configs.go @@ -30,6 +30,7 @@ type configSchema struct { RedirPort *int `json:"redir-port"` TProxyPort *int `json:"tproxy-port"` MixedPort *int `json:"mixed-port"` + MitmPort *int `json:"mitm-port"` Tun *config.Tun `json:"tun"` AllowLan *bool `json:"allow-lan"` BindAddress *string `json:"bind-address"` @@ -77,6 +78,7 @@ func patchConfigs(w http.ResponseWriter, r *http.Request) { P.ReCreateRedir(pointerOrDefault(general.RedirPort, ports.RedirPort), tcpIn, udpIn) P.ReCreateTProxy(pointerOrDefault(general.TProxyPort, ports.TProxyPort), tcpIn, udpIn) P.ReCreateMixed(pointerOrDefault(general.MixedPort, ports.MixedPort), tcpIn, udpIn) + P.ReCreateMitm(pointerOrDefault(general.MitmPort, ports.MitmPort), tcpIn) if general.Mode != nil { tunnel.SetMode(*general.Mode) diff --git a/listener/http/proxy.go b/listener/http/proxy.go index e8a805a9..d29f80f5 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -42,7 +42,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, var resp *http.Response if !trusted { - resp = authenticate(request, cache) + resp = Authenticate(request, cache) trusted = resp == nil } @@ -66,19 +66,19 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, request.RequestURI = "" - removeHopByHopHeaders(request.Header) - removeExtraHTTPHostPort(request) + RemoveHopByHopHeaders(request.Header) + RemoveExtraHTTPHostPort(request) if request.URL.Scheme == "" || request.URL.Host == "" { - resp = responseWith(request, http.StatusBadRequest) + resp = ResponseWith(request, http.StatusBadRequest) } else { resp, err = client.Do(request) if err != nil { - resp = responseWith(request, http.StatusBadGateway) + resp = ResponseWith(request, http.StatusBadGateway) } } - removeHopByHopHeaders(resp.Header) + RemoveHopByHopHeaders(resp.Header) } if keepAlive { @@ -98,12 +98,12 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, conn.Close() } -func authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response { +func Authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response { authenticator := authStore.Authenticator() if authenticator != nil { credential := parseBasicProxyAuthorization(request) if credential == "" { - resp := responseWith(request, http.StatusProxyAuthRequired) + resp := ResponseWith(request, http.StatusProxyAuthRequired) resp.Header.Set("Proxy-Authenticate", "Basic") return resp } @@ -117,14 +117,14 @@ func authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http if !authed { log.Infoln("Auth failed from %s", request.RemoteAddr) - return responseWith(request, http.StatusForbidden) + return ResponseWith(request, http.StatusForbidden) } } return nil } -func responseWith(request *http.Request, statusCode int) *http.Response { +func ResponseWith(request *http.Request, statusCode int) *http.Response { return &http.Response{ StatusCode: statusCode, Status: http.StatusText(statusCode), diff --git a/listener/http/utils.go b/listener/http/utils.go index 74b12005..0e7c7535 100644 --- a/listener/http/utils.go +++ b/listener/http/utils.go @@ -8,8 +8,8 @@ import ( "strings" ) -// removeHopByHopHeaders remove hop-by-hop header -func removeHopByHopHeaders(header http.Header) { +// RemoveHopByHopHeaders remove hop-by-hop header +func RemoveHopByHopHeaders(header http.Header) { // Strip hop-by-hop header based on RFC: // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1 // https://www.mnot.net/blog/2011/07/11/what_proxies_must_do @@ -32,9 +32,9 @@ func removeHopByHopHeaders(header http.Header) { } } -// removeExtraHTTPHostPort remove extra host port (example.com:80 --> example.com) +// RemoveExtraHTTPHostPort remove extra host port (example.com:80 --> example.com) // It resolves the behavior of some HTTP servers that do not handle host:80 (e.g. baidu.com) -func removeExtraHTTPHostPort(req *http.Request) { +func RemoveExtraHTTPHostPort(req *http.Request) { host := req.Host if host == "" { host = req.URL.Host diff --git a/listener/listener.go b/listener/listener.go index 46157e5d..1bda3a8d 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -1,6 +1,9 @@ package proxy import ( + "crypto/rsa" + "crypto/tls" + "crypto/x509" "fmt" "net" "os" @@ -8,9 +11,12 @@ import ( "sync" "github.com/Dreamacro/clash/adapter/inbound" + "github.com/Dreamacro/clash/adapter/outbound" + "github.com/Dreamacro/clash/common/cert" "github.com/Dreamacro/clash/config" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/listener/http" + "github.com/Dreamacro/clash/listener/mitm" "github.com/Dreamacro/clash/listener/mixed" "github.com/Dreamacro/clash/listener/redir" "github.com/Dreamacro/clash/listener/socks" @@ -18,6 +24,8 @@ import ( "github.com/Dreamacro/clash/listener/tun" "github.com/Dreamacro/clash/listener/tun/ipstack" "github.com/Dreamacro/clash/log" + rewrites "github.com/Dreamacro/clash/rewrite" + "github.com/Dreamacro/clash/tunnel" ) var ( @@ -34,6 +42,7 @@ var ( mixedListener *mixed.Listener mixedUDPLister *socks.UDPListener tunStackListener ipstack.Stack + mitmListener *mitm.Listener // lock for recreate function socksMux sync.Mutex @@ -42,6 +51,7 @@ var ( tproxyMux sync.Mutex mixedMux sync.Mutex tunMux sync.Mutex + mitmMux sync.Mutex ) type Ports struct { @@ -50,6 +60,7 @@ type Ports struct { RedirPort int `json:"redir-port"` TProxyPort int `json:"tproxy-port"` MixedPort int `json:"mixed-port"` + MitmPort int `json:"mitm-port"` } func AllowLan() bool { @@ -331,6 +342,85 @@ func ReCreateTun(tunConf *config.Tun, tunAddressPrefix string, tcpIn chan<- C.Co tunStackListener, err = tun.New(tunConf, tunAddressPrefix, tcpIn, udpIn) } +func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) { + mitmMux.Lock() + defer mitmMux.Unlock() + + var err error + defer func() { + if err != nil { + log.Errorln("Start MITM server error: %s", err.Error()) + } + }() + + addr := genAddr(bindAddress, port, allowLan) + + if mitmListener != nil { + if mitmListener.RawAddress() == addr { + return + } + outbound.MiddlemanServerAddress.Store("") + tunnel.MitmOutbound = nil + _ = mitmListener.Close() + mitmListener = nil + } + + if portIsZero(addr) { + return + } + + if err = initCert(); err != nil { + return + } + + var ( + rootCACert tls.Certificate + x509c *x509.Certificate + certOption *cert.Config + ) + + rootCACert, err = tls.LoadX509KeyPair(C.Path.RootCA(), C.Path.CAKey()) + if err != nil { + return + } + + privateKey := rootCACert.PrivateKey.(*rsa.PrivateKey) + + x509c, err = x509.ParseCertificate(rootCACert.Certificate[0]) + if err != nil { + return + } + + certOption, err = cert.NewConfig( + x509c, + privateKey, + cert.NewAutoGCCertsStorage(), + ) + if err != nil { + return + } + + certOption.SetValidity(cert.TTL << 3) + certOption.SetOrganization("Clash ManInTheMiddle Proxy Services") + + opt := &mitm.Option{ + Addr: addr, + ApiHost: "mitm.clash", + CertConfig: certOption, + Handler: &rewrites.RewriteHandler{}, + } + + mitmListener, err = mitm.New(opt, tcpIn) + if err != nil { + return + } + + outbound.MiddlemanServerAddress.Store(mitmListener.Address()) + tunnel.MitmOutbound = outbound.NewMitm() + + log.Infoln("Mitm proxy listening at: %s", mitmListener.Address()) +} + // GetPorts return the ports of proxy servers func GetPorts() *Ports { ports := &Ports{} @@ -365,6 +455,12 @@ func GetPorts() *Ports { ports.MixedPort = port } + if mitmListener != nil { + _, portStr, _ := net.SplitHostPort(mitmListener.Address()) + port, _ := strconv.Atoi(portStr) + ports.MitmPort = port + } + return ports } @@ -387,6 +483,19 @@ func genAddr(host string, port int, allowLan bool) string { return fmt.Sprintf("127.0.0.1:%d", port) } +func initCert() error { + if _, err := os.Stat(C.Path.RootCA()); os.IsNotExist(err) { + log.Infoln("Can't find mitm_ca.crt, start generate") + err = cert.GenerateAndSave(C.Path.RootCA(), C.Path.CAKey()) + if err != nil { + return err + } + log.Infoln("Generated CA private key and CA certificate finish") + } + + return nil +} + func Cleanup() { if tunStackListener != nil { _ = tunStackListener.Close() diff --git a/listener/mitm/client.go b/listener/mitm/client.go new file mode 100644 index 00000000..278de173 --- /dev/null +++ b/listener/mitm/client.go @@ -0,0 +1,54 @@ +package mitm + +import ( + "context" + "crypto/tls" + "errors" + "net" + "net/http" + "time" + + "github.com/Dreamacro/clash/adapter/inbound" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/transport/socks5" +) + +var ErrCertUnsupported = errors.New("tls: client cert unsupported") + +func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http.Client { + return &http.Client{ + Transport: &http.Transport{ + // excepted HTTP/2 + TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), + // from http.DefaultTransport + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{ + GetClientCertificate: func(info *tls.CertificateRequestInfo) (certificate *tls.Certificate, e error) { + return nil, ErrCertUnsupported + }, + }, + DialContext: func(context context.Context, network, address string) (net.Conn, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, errors.New("unsupported network " + network) + } + + dstAddr := socks5.ParseAddr(address) + if dstAddr == nil { + return nil, socks5.ErrAddressNotSupported + } + + left, right := net.Pipe() + + in <- inbound.NewMitm(dstAddr, source, userAgent, right) + + return left, nil + }, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } +} diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go new file mode 100644 index 00000000..a0d3bab6 --- /dev/null +++ b/listener/mitm/proxy.go @@ -0,0 +1,357 @@ +package mitm + +import ( + "bytes" + "crypto/tls" + "encoding/pem" + "errors" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Dreamacro/clash/adapter/inbound" + "github.com/Dreamacro/clash/common/cache" + N "github.com/Dreamacro/clash/common/net" + C "github.com/Dreamacro/clash/constant" + httpL "github.com/Dreamacro/clash/listener/http" +) + +func HandleConn(c net.Conn, opt *Option, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { + var ( + source net.Addr + client *http.Client + ) + + defer func() { + if client != nil { + client.CloseIdleConnections() + } + }() + +startOver: + if tc, ok := c.(*net.TCPConn); ok { + _ = tc.SetKeepAlive(true) + } + + var conn *N.BufferedConn + if bufConn, ok := c.(*N.BufferedConn); ok { + conn = bufConn + } else { + conn = N.NewBufferedConn(c) + } + + trusted := cache == nil // disable authenticate if cache is nil + +readLoop: + for { + _ = conn.SetDeadline(time.Now().Add(30 * time.Second)) // use SetDeadline instead of Proxy-Connection keep-alive + + request, err := httpL.ReadRequest(conn.Reader()) + if err != nil { + handleError(opt, nil, err) + break readLoop + } + + var response *http.Response + + session := NewSession(conn, request, response) + + source = parseSourceAddress(session.request, c, source) + request.RemoteAddr = source.String() + + if !trusted { + response = httpL.Authenticate(request, cache) + + trusted = response == nil + } + + if trusted { + if session.request.Method == http.MethodConnect { + // Manual writing to support CONNECT for http 1.0 (workaround for uplay client) + if _, err = fmt.Fprintf(session.conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.request.ProtoMajor, session.request.ProtoMinor, http.StatusOK, "Connection established"); err != nil { + handleError(opt, session, err) + break readLoop // close connection + } + + if couldBeWithManInTheMiddleAttack(session.request.URL.Host, opt) { + b := make([]byte, 1) + if _, err = session.conn.Read(b); err != nil { + handleError(opt, session, err) + break readLoop // close connection + } + + buf := make([]byte, session.conn.(*N.BufferedConn).Buffered()) + _, _ = session.conn.Read(buf) + + mc := &MultiReaderConn{ + Conn: session.conn, + reader: io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), session.conn), + } + + // 22 is the TLS handshake. + // https://tools.ietf.org/html/rfc5246#section-6.2.1 + if b[0] == 22 { + // TODO serve by generic host name maybe better? + tlsConn := tls.Server(mc, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) + + // Handshake with the local client + if err = tlsConn.Handshake(); err != nil { + handleError(opt, session, err) + break readLoop // close connection + } + + c = tlsConn + goto startOver // hijack and decrypt tls connection + } + + // maybe it's the others encrypted connection + in <- inbound.NewHTTPS(request, mc) + } + + // maybe it's a http connection + goto readLoop + } + + // hijack api + if getHostnameWithoutPort(session.request) == opt.ApiHost { + if err = handleApiRequest(session, opt); err != nil { + handleError(opt, session, err) + break readLoop + } + return + } + + prepareRequest(c, session.request) + + // hijack custom request and write back custom response if necessary + if opt.Handler != nil { + newReq, newRes := opt.Handler.HandleRequest(session) + if newReq != nil { + session.request = newReq + } + if newRes != nil { + session.response = newRes + + if err = writeResponse(session, false); err != nil { + handleError(opt, session, err) + break readLoop + } + return + } + } + + httpL.RemoveHopByHopHeaders(session.request.Header) + httpL.RemoveExtraHTTPHostPort(request) + + session.request.RequestURI = "" + + if session.request.URL.Scheme == "" || session.request.URL.Host == "" { + session.response = session.NewErrorResponse(errors.New("invalid URL")) + } else { + client = newClientBySourceAndUserAgentIfNil(client, session.request, source, in) + + // send the request to remote server + session.response, err = client.Do(session.request) + + if err != nil { + handleError(opt, session, err) + session.response = session.NewErrorResponse(err) + if errors.Is(err, ErrCertUnsupported) || strings.Contains(err.Error(), "x509: ") { + // TODO block unsupported host? + } + } + } + } + + if err = writeResponseWithHandler(session, opt); err != nil { + handleError(opt, session, err) + break readLoop // close connection + } + } + + _ = conn.Close() +} + +func writeResponseWithHandler(session *Session, opt *Option) error { + if opt.Handler != nil { + res := opt.Handler.HandleResponse(session) + + if res != nil { + body := res.Body + defer func(body io.ReadCloser) { + _ = body.Close() + }(body) + + session.response = res + } + } + + return writeResponse(session, true) +} + +func writeResponse(session *Session, keepAlive bool) error { + httpL.RemoveHopByHopHeaders(session.response.Header) + + if keepAlive { + session.response.Header.Set("Connection", "keep-alive") + session.response.Header.Set("Keep-Alive", "timeout=25") + } + + // session.response.Close = !keepAlive // let handler do it + + return session.response.Write(session.conn) +} + +func handleApiRequest(session *Session, opt *Option) error { + if opt.CertConfig != nil && strings.ToLower(session.request.URL.Path) == "/cert.crt" { + b := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: opt.CertConfig.GetCA().Raw, + }) + + session.response = session.NewResponse(http.StatusOK, bytes.NewReader(b)) + + defer func(body io.ReadCloser) { + _ = body.Close() + }(session.response.Body) + + session.response.Close = true + session.response.Header.Set("Content-Type", "application/x-x509-ca-cert") + session.response.ContentLength = int64(len(b)) + + return session.response.Write(session.conn) + } + + b := ` + + Clash ManInTheMiddle Proxy Services - 404 Not Found + + +

Not Found

+

The requested URL %s was not found on this server.

+ + +` + if opt.Handler != nil { + if opt.Handler.HandleApiRequest(session) { + return nil + } + } + + b = fmt.Sprintf(b, session.request.URL.Path) + + session.response = session.NewResponse(http.StatusNotFound, bytes.NewReader([]byte(b))) + + defer func(body io.ReadCloser) { + _ = body.Close() + }(session.response.Body) + + session.response.Close = true + session.response.Header.Set("Content-Type", "text/html;charset=utf-8") + session.response.ContentLength = int64(len(b)) + + return session.response.Write(session.conn) +} + +func handleError(opt *Option, session *Session, err error) { + if opt.Handler != nil { + opt.Handler.HandleError(session, err) + return + } + + // log.Errorln("[MITM] process mitm error: %v", err) +} + +func prepareRequest(conn net.Conn, request *http.Request) { + host := request.Header.Get("Host") + if host != "" { + request.Host = host + } + + if request.URL.Host == "" { + request.URL.Host = request.Host + } + + request.URL.Scheme = "http" + + if tlsConn, ok := conn.(*tls.Conn); ok { + cs := tlsConn.ConnectionState() + request.TLS = &cs + + request.URL.Scheme = "https" + } + + if request.Header.Get("Accept-Encoding") != "" { + request.Header.Set("Accept-Encoding", "gzip") + } +} + +func couldBeWithManInTheMiddleAttack(hostname string, opt *Option) bool { + if opt.CertConfig == nil { + return false + } + + if _, port, err := net.SplitHostPort(hostname); err == nil && (port == "443" || port == "8443") { + return true + } + + return false +} + +func getHostnameWithoutPort(req *http.Request) string { + host := req.Host + if host == "" { + host = req.URL.Host + } + + if pHost, _, err := net.SplitHostPort(host); err == nil { + host = pHost + } + + return host +} + +func parseSourceAddress(req *http.Request, c net.Conn, source net.Addr) net.Addr { + if source != nil { + return source + } + + sourceAddress := req.Header.Get("Origin-Request-Source-Address") + if sourceAddress == "" { + return c.RemoteAddr() + } + + req.Header.Del("Origin-Request-Source-Address") + + host, port, err := net.SplitHostPort(sourceAddress) + if err != nil { + return c.RemoteAddr() + } + + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return c.RemoteAddr() + } + + if ip := net.ParseIP(host); ip != nil { + return &net.TCPAddr{ + IP: ip, + Port: int(p), + } + } + + return c.RemoteAddr() +} + +func newClientBySourceAndUserAgentIfNil(cli *http.Client, req *http.Request, source net.Addr, in chan<- C.ConnContext) *http.Client { + if cli != nil { + return cli + } + + return newClient(source, req.Header.Get("User-Agent"), in) +} diff --git a/listener/mitm/server.go b/listener/mitm/server.go new file mode 100644 index 00000000..d7699b81 --- /dev/null +++ b/listener/mitm/server.go @@ -0,0 +1,90 @@ +package mitm + +import ( + "crypto/tls" + "net" + "net/http" + "time" + + "github.com/Dreamacro/clash/common/cache" + "github.com/Dreamacro/clash/common/cert" + C "github.com/Dreamacro/clash/constant" +) + +type Handler interface { + HandleRequest(*Session) (*http.Request, *http.Response) // Session.Response maybe nil + HandleResponse(*Session) *http.Response + HandleApiRequest(*Session) bool + HandleError(*Session, error) // Session maybe nil +} + +type Option struct { + Addr string + ApiHost string + + TLSConfig *tls.Config + CertConfig *cert.Config + + Handler Handler +} + +type Listener struct { + *Option + + listener net.Listener + addr string + closed bool +} + +// RawAddress implements C.Listener +func (l *Listener) RawAddress() string { + return l.addr +} + +// Address implements C.Listener +func (l *Listener) Address() string { + return l.listener.Addr().String() +} + +// Close implements C.Listener +func (l *Listener) Close() error { + l.closed = true + return l.listener.Close() +} + +// New the MITM proxy actually is a type of HTTP proxy +func New(option *Option, in chan<- C.ConnContext) (*Listener, error) { + return NewWithAuthenticate(option, in, false) +} + +func NewWithAuthenticate(option *Option, in chan<- C.ConnContext, authenticate bool) (*Listener, error) { + l, err := net.Listen("tcp", option.Addr) + if err != nil { + return nil, err + } + + var c *cache.Cache[string, bool] + if authenticate { + c = cache.New[string, bool](time.Second * 30) + } + + hl := &Listener{ + listener: l, + addr: option.Addr, + Option: option, + } + go func() { + for { + conn, err1 := hl.listener.Accept() + if err1 != nil { + if hl.closed { + break + } + continue + } + go HandleConn(conn, option, in, c) + } + }() + + return hl, nil +} diff --git a/listener/mitm/session.go b/listener/mitm/session.go new file mode 100644 index 00000000..2572d879 --- /dev/null +++ b/listener/mitm/session.go @@ -0,0 +1,56 @@ +package mitm + +import ( + "fmt" + "io" + "net" + "net/http" + + C "github.com/Dreamacro/clash/constant" +) + +var serverName = fmt.Sprintf("Clash server (%s)", C.Version) + +type Session struct { + conn net.Conn + request *http.Request + response *http.Response + + props map[string]any +} + +func (s *Session) Request() *http.Request { + return s.request +} + +func (s *Session) Response() *http.Response { + return s.response +} + +func (s *Session) GetProperties(key string) (any, bool) { + v, ok := s.props[key] + return v, ok +} + +func (s *Session) SetProperties(key string, val any) { + s.props[key] = val +} + +func (s *Session) NewResponse(code int, body io.Reader) *http.Response { + res := NewResponse(code, body, s.request) + res.Header.Set("Server", serverName) + return res +} + +func (s *Session) NewErrorResponse(err error) *http.Response { + return NewErrorResponse(s.request, err) +} + +func NewSession(conn net.Conn, request *http.Request, response *http.Response) *Session { + return &Session{ + conn: conn, + request: request, + response: response, + props: map[string]any{}, + } +} diff --git a/listener/mitm/utils.go b/listener/mitm/utils.go new file mode 100644 index 00000000..7d681d42 --- /dev/null +++ b/listener/mitm/utils.go @@ -0,0 +1,100 @@ +package mitm + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "time" + + "golang.org/x/text/encoding/charmap" + "golang.org/x/text/transform" +) + +type MultiReaderConn struct { + net.Conn + reader io.Reader +} + +func (c *MultiReaderConn) Read(buf []byte) (int, error) { + return c.reader.Read(buf) +} + +func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { + if body == nil { + body = &bytes.Buffer{} + } + + rc, ok := body.(io.ReadCloser) + if !ok { + rc = ioutil.NopCloser(body) + } + + res := &http.Response{ + StatusCode: code, + Status: fmt.Sprintf("%d %s", code, http.StatusText(code)), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + Body: rc, + Request: req, + } + + if req != nil { + res.Close = req.Close + res.Proto = req.Proto + res.ProtoMajor = req.ProtoMajor + res.ProtoMinor = req.ProtoMinor + } + + return res +} + +func NewErrorResponse(req *http.Request, err error) *http.Response { + res := NewResponse(http.StatusBadGateway, nil, req) + res.Close = true + + date := res.Header.Get("Date") + if date == "" { + date = time.Now().Format(http.TimeFormat) + } + + w := fmt.Sprintf(`199 "clash" %q %q`, err.Error(), date) + res.Header.Add("Warning", w) + res.Header.Set("Server", serverName) + return res +} + +func ReadDecompressedBody(res *http.Response) ([]byte, error) { + rBody := res.Body + if res.Header.Get("Content-Encoding") == "gzip" { + gzReader, err := gzip.NewReader(rBody) + if err != nil { + return nil, err + } + rBody = gzReader + + defer func(gzReader *gzip.Reader) { + _ = gzReader.Close() + }(gzReader) + } + return ioutil.ReadAll(rBody) +} + +func DecodeLatin1(reader io.Reader) (string, error) { + r := transform.NewReader(reader, charmap.ISO8859_1.NewDecoder()) + b, err := ioutil.ReadAll(r) + if err != nil { + return "", err + } + + return string(b), nil +} + +func EncodeLatin1(str string) ([]byte, error) { + return charmap.ISO8859_1.NewEncoder().Bytes([]byte(str)) +} diff --git a/rewrite/base.go b/rewrite/base.go new file mode 100644 index 00000000..29ba0dc2 --- /dev/null +++ b/rewrite/base.go @@ -0,0 +1,72 @@ +package rewrites + +import ( + "bytes" + "io" + "io/ioutil" + + C "github.com/Dreamacro/clash/constant" +) + +var ( + EmptyDict = NewResponseBody([]byte("{}")) + EmptyArray = NewResponseBody([]byte("[]")) + OnePixelPNG = NewResponseBody([]byte{0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x06, 0x00, 0x00, 0x00, 0x1f, 0x15, 0xc4, 0x89, 0x00, 0x00, 0x00, 0x11, 0x49, 0x44, 0x41, 0x54, 0x78, 0x9c, 0x62, 0x62, 0x60, 0x60, 0x60, 0x00, 0x04, 0x00, 0x00, 0xff, 0xff, 0x00, 0x0f, 0x00, 0x03, 0xfe, 0x8f, 0xeb, 0xcf, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae, 0x42, 0x60, 0x82}) +) + +type Body interface { + Body() io.ReadCloser + ContentLength() int64 +} + +type ResponseBody struct { + data []byte + length int64 +} + +func (r *ResponseBody) Body() io.ReadCloser { + return ioutil.NopCloser(bytes.NewReader(r.data)) +} + +func (r *ResponseBody) ContentLength() int64 { + return r.length +} + +func NewResponseBody(data []byte) *ResponseBody { + return &ResponseBody{ + data: data, + length: int64(len(data)), + } +} + +type RewriteRules struct { + request []C.Rewrite + response []C.Rewrite +} + +func (rr *RewriteRules) SearchInRequest(do func(C.Rewrite) bool) bool { + for _, v := range rr.request { + if do(v) { + return true + } + } + return false +} + +func (rr *RewriteRules) SearchInResponse(do func(C.Rewrite) bool) bool { + for _, v := range rr.response { + if do(v) { + return true + } + } + return false +} + +func NewRewriteRules(req []C.Rewrite, res []C.Rewrite) *RewriteRules { + return &RewriteRules{ + request: req, + response: res, + } +} + +var _ C.RewriteRule = (*RewriteRules)(nil) diff --git a/rewrite/handler.go b/rewrite/handler.go new file mode 100644 index 00000000..ddbafeb9 --- /dev/null +++ b/rewrite/handler.go @@ -0,0 +1,202 @@ +package rewrites + +import ( + "bufio" + "bytes" + "errors" + "io" + "io/ioutil" + "net/http" + "net/textproto" + "strconv" + "strings" + + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/listener/mitm" + "github.com/Dreamacro/clash/tunnel" +) + +var _ mitm.Handler = (*RewriteHandler)(nil) + +type RewriteHandler struct{} + +func (*RewriteHandler) HandleRequest(session *mitm.Session) (*http.Request, *http.Response) { + var ( + request = session.Request() + response *http.Response + ) + + rule, sub, found := matchRewriteRule(request.URL.String(), true) + if !found { + return nil, nil + } + + switch rule.RuleType() { + case C.MitmReject: + response = session.NewResponse(http.StatusNotFound, nil) + response.Header.Set("Content-Type", "text/html; charset=utf-8") + case C.MitmReject200: + response = session.NewResponse(http.StatusOK, nil) + response.Header.Set("Content-Type", "text/html; charset=utf-8") + case C.MitmRejectImg: + response = session.NewResponse(http.StatusOK, OnePixelPNG.Body()) + response.Header.Set("Content-Type", "image/png") + response.ContentLength = OnePixelPNG.ContentLength() + case C.MitmRejectDict: + response = session.NewResponse(http.StatusOK, EmptyDict.Body()) + response.Header.Set("Content-Type", "application/json; charset=utf-8") + response.ContentLength = EmptyDict.ContentLength() + case C.MitmRejectArray: + response = session.NewResponse(http.StatusOK, EmptyArray.Body()) + response.Header.Set("Content-Type", "application/json; charset=utf-8") + response.ContentLength = EmptyArray.ContentLength() + case C.Mitm302: + response = session.NewResponse(http.StatusFound, nil) + response.Header.Set("Location", rule.ReplaceURLPayload(sub)) + case C.Mitm307: + response = session.NewResponse(http.StatusTemporaryRedirect, nil) + response.Header.Set("Location", rule.ReplaceURLPayload(sub)) + case C.MitmRequestHeader: + if len(request.Header) == 0 { + return nil, nil + } + + rawHeader := &bytes.Buffer{} + oldHeader := request.Header + if err := oldHeader.Write(rawHeader); err != nil { + return nil, nil + } + + newRawHeader := rule.ReplaceSubPayload(rawHeader.String()) + tb := textproto.NewReader(bufio.NewReader(strings.NewReader(newRawHeader))) + newHeader, err := tb.ReadMIMEHeader() + if err != nil && !errors.Is(err, io.EOF) { + return nil, nil + } + request.Header = http.Header(newHeader) + case C.MitmRequestBody: + if !CanRewriteBody(request.ContentLength, request.Header.Get("Content-Type")) { + return nil, nil + } + + buf := make([]byte, request.ContentLength) + _, err := io.ReadFull(request.Body, buf) + if err != nil { + return nil, nil + } + + newBody := rule.ReplaceSubPayload(string(buf)) + request.Body = io.NopCloser(strings.NewReader(newBody)) + request.ContentLength = int64(len(newBody)) + default: + found = false + } + + if found { + if response != nil { + response.Close = true + } + return request, response + } + return nil, nil +} + +func (*RewriteHandler) HandleResponse(session *mitm.Session) *http.Response { + var ( + request = session.Request() + response = session.Response() + ) + + rule, _, found := matchRewriteRule(request.URL.String(), false) + found = found && rule.RuleRegx() != nil + if !found { + return nil + } + + switch rule.RuleType() { + case C.MitmResponseHeader: + if len(response.Header) == 0 { + return nil + } + + rawHeader := &bytes.Buffer{} + oldHeader := response.Header + if err := oldHeader.Write(rawHeader); err != nil { + return nil + } + + newRawHeader := rule.ReplaceSubPayload(rawHeader.String()) + tb := textproto.NewReader(bufio.NewReader(strings.NewReader(newRawHeader))) + newHeader, err := tb.ReadMIMEHeader() + if err != nil && !errors.Is(err, io.EOF) { + return nil + } + + response.Header = http.Header(newHeader) + response.Header.Set("Content-Length", strconv.FormatInt(response.ContentLength, 10)) + case C.MitmResponseBody: + if !CanRewriteBody(response.ContentLength, response.Header.Get("Content-Type")) { + return nil + } + + b, err := mitm.ReadDecompressedBody(response) + _ = response.Body.Close() + if err != nil { + return nil + } + + body, err := mitm.DecodeLatin1(bytes.NewReader(b)) + if err != nil { + return nil + } + + newBody := rule.ReplaceSubPayload(body) + + modifiedBody, err := mitm.EncodeLatin1(newBody) + if err != nil { + return nil + } + + response.Body = ioutil.NopCloser(bytes.NewReader(modifiedBody)) + response.Header.Del("Content-Encoding") + response.ContentLength = int64(len(modifiedBody)) + default: + found = false + } + + if found { + return response + } + return nil +} + +func (h *RewriteHandler) HandleApiRequest(*mitm.Session) bool { + return false +} + +// HandleError session maybe nil +func (h *RewriteHandler) HandleError(*mitm.Session, error) {} + +func matchRewriteRule(url string, isRequest bool) (rr C.Rewrite, sub []string, found bool) { + rewrites := tunnel.Rewrites() + if isRequest { + found = rewrites.SearchInRequest(func(r C.Rewrite) bool { + sub = r.URLRegx().FindStringSubmatch(url) + if len(sub) != 0 { + rr = r + return true + } + return false + }) + } else { + found = rewrites.SearchInResponse(func(r C.Rewrite) bool { + if r.URLRegx().FindString(url) != "" { + rr = r + return true + } + return false + }) + } + + return +} diff --git a/rewrite/parser.go b/rewrite/parser.go new file mode 100644 index 00000000..f97134d3 --- /dev/null +++ b/rewrite/parser.go @@ -0,0 +1,78 @@ +package rewrites + +import ( + "regexp" + "strings" + + C "github.com/Dreamacro/clash/constant" +) + +func ParseRewrite(line string) (C.Rewrite, error) { + url, others, found := strings.Cut(strings.TrimSpace(line), "url") + if !found { + return nil, errInvalid + } + + var ( + urlRegx *regexp.Regexp + ruleType *C.RewriteType + ruleRegx *regexp.Regexp + rulePayload string + + err error + ) + + urlRegx, err = regexp.Compile(strings.Trim(url, " ")) + if err != nil { + return nil, err + } + + others = strings.Trim(others, " ") + first := strings.Split(others, " ")[0] + for k, v := range C.RewriteTypeMapping { + if k == others { + ruleType = &v + break + } + + if k != first { + continue + } + + rs := trimArr(strings.Split(others, k)) + l := len(rs) + if l > 2 { + continue + } + + if l == 1 { + ruleType = &v + rulePayload = rs[0] + break + } else { + ruleRegx, err = regexp.Compile(rs[0]) + if err != nil { + return nil, err + } + + ruleType = &v + rulePayload = rs[1] + break + } + } + + if ruleType == nil { + return nil, errInvalid + } + + return NewRewriteRule(urlRegx, *ruleType, ruleRegx, rulePayload), nil +} + +func trimArr(arr []string) (r []string) { + for _, e := range arr { + if s := strings.Trim(e, " "); s != "" { + r = append(r, s) + } + } + return +} diff --git a/rewrite/parser_test.go b/rewrite/parser_test.go new file mode 100644 index 00000000..58d1149a --- /dev/null +++ b/rewrite/parser_test.go @@ -0,0 +1,56 @@ +package rewrites + +import ( + "bytes" + "fmt" + "image" + "image/color" + "image/draw" + "image/png" + "regexp" + "testing" + + "github.com/Dreamacro/clash/constant" + + "github.com/stretchr/testify/assert" +) + +func TestParseRewrite(t *testing.T) { + line0 := `^https?://example\.com/resource1/3/ url reject-dict` + line1 := `^https?://example\.com/(resource2)/ url 307 https://example.com/new-$1` + line2 := `^https?://example\.com/resource4/ url request-header (\r\n)User-Agent:.+(\r\n) request-header $1User-Agent: Fuck-Who$2` + line3 := `should be error` + + c0, err0 := ParseRewrite(line0) + c1, err1 := ParseRewrite(line1) + c2, err2 := ParseRewrite(line2) + _, err3 := ParseRewrite(line3) + + assert.NotNil(t, err3) + + assert.Nil(t, err0) + assert.Equal(t, c0.RuleType(), constant.MitmRejectDict) + + assert.Nil(t, err1) + assert.Equal(t, c1.RuleType(), constant.Mitm307) + assert.Equal(t, c1.URLRegx(), regexp.MustCompile(`^https?://example\.com/(resource2)/`)) + assert.Equal(t, c1.RulePayload(), "https://example.com/new-$1") + + assert.Nil(t, err2) + assert.Equal(t, c2.RuleType(), constant.MitmRequestHeader) + assert.Equal(t, c2.RuleRegx(), regexp.MustCompile(`(\r\n)User-Agent:.+(\r\n)`)) + assert.Equal(t, c2.RulePayload(), "$1User-Agent: Fuck-Who$2") +} + +func Test1PxPNG(t *testing.T) { + m := image.NewRGBA(image.Rect(0, 0, 1, 1)) + + draw.Draw(m, m.Bounds(), &image.Uniform{C: color.Transparent}, image.Point{}, draw.Src) + + buf := &bytes.Buffer{} + + assert.Nil(t, png.Encode(buf, m)) + + fmt.Printf("len: %d\n", buf.Len()) + fmt.Printf("% #x\n", buf.Bytes()) +} diff --git a/rewrite/rewrite.go b/rewrite/rewrite.go new file mode 100644 index 00000000..d88d4efe --- /dev/null +++ b/rewrite/rewrite.go @@ -0,0 +1,89 @@ +package rewrites + +import ( + "errors" + "regexp" + "strconv" + "strings" + + C "github.com/Dreamacro/clash/constant" + + "github.com/gofrs/uuid" +) + +var errInvalid = errors.New("invalid rewrite rule") + +type RewriteRule struct { + id string + urlRegx *regexp.Regexp + ruleType C.RewriteType + ruleRegx *regexp.Regexp + rulePayload string +} + +func (r *RewriteRule) ID() string { + return r.id +} + +func (r *RewriteRule) URLRegx() *regexp.Regexp { + return r.urlRegx +} + +func (r *RewriteRule) RuleType() C.RewriteType { + return r.ruleType +} + +func (r *RewriteRule) RuleRegx() *regexp.Regexp { + return r.ruleRegx +} + +func (r *RewriteRule) RulePayload() string { + return r.rulePayload +} + +func (r *RewriteRule) ReplaceURLPayload(matchSub []string) string { + url := r.rulePayload + + l := len(matchSub) + if l < 2 { + return url + } + + for i := 1; i < l; i++ { + url = strings.ReplaceAll(url, "$"+strconv.Itoa(i), matchSub[i]) + } + return url +} + +func (r *RewriteRule) ReplaceSubPayload(oldData string) string { + payload := r.rulePayload + if r.ruleRegx == nil { + return oldData + } + + sub := r.ruleRegx.FindStringSubmatch(oldData) + l := len(sub) + + if l == 0 { + return oldData + } + + for i := 1; i < l; i++ { + payload = strings.ReplaceAll(payload, "$"+strconv.Itoa(i), sub[i]) + } + + return strings.ReplaceAll(oldData, sub[0], payload) +} + +func NewRewriteRule(urlRegx *regexp.Regexp, ruleType C.RewriteType, ruleRegx *regexp.Regexp, rulePayload string) *RewriteRule { + id, _ := uuid.NewV4() + return &RewriteRule{ + id: id.String(), + urlRegx: urlRegx, + ruleType: ruleType, + ruleRegx: ruleRegx, + rulePayload: rulePayload, + } +} + +var _ C.Rewrite = (*RewriteRule)(nil) diff --git a/rewrite/util.go b/rewrite/util.go new file mode 100644 index 00000000..a12e4fa9 --- /dev/null +++ b/rewrite/util.go @@ -0,0 +1,28 @@ +package rewrites + +import ( + "strings" +) + +var allowContentType = []string{ + "text/", + "application/xhtml", + "application/xml", + "application/atom+xml", + "application/json", + "application/x-www-form-urlencoded", +} + +func CanRewriteBody(contentLength int64, contentType string) bool { + if contentLength <= 0 { + return false + } + + for _, v := range allowContentType { + if strings.HasPrefix(contentType, v) { + return true + } + } + + return false +} diff --git a/rule/parser.go b/rule/parser.go index 3374108e..5dc8dba4 100644 --- a/rule/parser.go +++ b/rule/parser.go @@ -37,6 +37,8 @@ func ParseRule(tp, payload, target string, params []string) (C.Rule, error) { parsed, parseErr = NewProcess(payload, target, true) case "PROCESS-PATH": parsed, parseErr = NewProcess(payload, target, false) + case "USER-AGENT": + parsed, parseErr = NewUserAgent(payload, target) case "MATCH": parsed = NewMatch(target) default: diff --git a/rule/user_gent.go b/rule/user_gent.go new file mode 100644 index 00000000..162188e5 --- /dev/null +++ b/rule/user_gent.go @@ -0,0 +1,52 @@ +package rules + +import ( + "strings" + + C "github.com/Dreamacro/clash/constant" +) + +type UserAgent struct { + *Base + ua string + adapter string +} + +func (d *UserAgent) RuleType() C.RuleType { + return C.UserAgent +} + +func (d *UserAgent) Match(metadata *C.Metadata) bool { + if metadata.Type != C.MITM || metadata.UserAgent == "" { + return false + } + + return strings.Contains(metadata.UserAgent, d.ua) +} + +func (d *UserAgent) Adapter() string { + return d.adapter +} + +func (d *UserAgent) Payload() string { + return d.ua +} + +func (d *UserAgent) ShouldResolveIP() bool { + return false +} + +func NewUserAgent(ua string, adapter string) (*UserAgent, error) { + ua = strings.Trim(ua, "*") + if ua == "" { + return nil, errPayload + } + + return &UserAgent{ + Base: &Base{}, + ua: ua, + adapter: adapter, + }, nil +} + +var _ C.Rule = (*UserAgent)(nil) diff --git a/test/go.mod b/test/go.mod index b4ec09f0..94f109dc 100644 --- a/test/go.mod +++ b/test/go.mod @@ -8,7 +8,7 @@ require ( github.com/docker/go-connections v0.4.0 github.com/miekg/dns v1.1.47 github.com/stretchr/testify v1.7.1 - golang.org/x/net v0.0.0-20220225172249-27dd8689420f + golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 ) replace github.com/Dreamacro/clash => ../ @@ -39,10 +39,10 @@ require ( github.com/xtls/go v0.0.0-20210920065950-d4af136d3672 // indirect go.etcd.io/bbolt v1.3.6 // indirect go.uber.org/atomic v1.9.0 // indirect - golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect - golang.org/x/mod v0.5.1 // indirect + golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 // indirect + golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect - golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 // indirect + golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f // indirect golang.org/x/text v0.3.8-0.20220124021120-d1c84af989ab // indirect golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect golang.org/x/tools v0.1.9 // indirect diff --git a/test/go.sum b/test/go.sum index f1ffa840..991673db 100644 --- a/test/go.sum +++ b/test/go.sum @@ -913,8 +913,8 @@ golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWP golang.org/x/crypto v0.0.0-20210317152858-513c2a44f670/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38= -golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220408190544-5352b0902921 h1:iU7T1X1J6yxDr0rda54sWGkHgOp5XJrqm79gcNlC2VM= +golang.org/x/crypto v0.0.0-20220408190544-5352b0902921/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -950,8 +950,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38= -golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= +golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 h1:LQmS1nU0twXLA96Kt7U9qtHJEbBk3z6Q0V4UXjZkpr4= +golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1012,8 +1012,8 @@ golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210825183410-e898025ed96a/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3 h1:EN5+DfgmRMvRUrMGERW2gQl3Vc+Z7ZMnI/xdEpPSf0c= +golang.org/x/net v0.0.0-20220407224826-aac1ed45d8e3/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1144,8 +1144,8 @@ golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211116061358-0a5406a5449c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 h1:A9i04dxx7Cribqbs8jf3FQLogkL/CV2YN7hj9KWJCkc= -golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f h1:8w7RhxzTVgUzw/AH/9mUV5q0vMgy40SQRursCcfmkCw= +golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 6fd8b3e7..1b24c107 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -80,8 +80,7 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R } manager.Join(t) - conn = NewSniffing(t, metadata) - return conn + return NewSniffing(t, metadata) } type udpTracker struct { diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index b816871c..d2cb95be 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -26,6 +26,7 @@ var ( udpQueue = make(chan *inbound.PacketAdapter, 200) natTable = nat.New() rules []C.Rule + rewrites C.RewriteRule proxies = make(map[string]C.Proxy) providers map[string]provider.ProxyProvider configMux sync.RWMutex @@ -35,6 +36,9 @@ var ( // default timeout for UDP session udpTimeout = 60 * time.Second + + // MitmOutbound mitm proxy adapter + MitmOutbound C.ProxyAdapter ) func init() { @@ -91,6 +95,18 @@ func SetMode(m TunnelMode) { mode = m } +// Rewrites return all rewrites +func Rewrites() C.RewriteRule { + return rewrites +} + +// UpdateRewrites handle update rewrites +func UpdateRewrites(newRewrites C.RewriteRule) { + configMux.Lock() + rewrites = newRewrites + configMux.Unlock() +} + // processUDP starts a loop to handle udp packet func processUDP() { queue := udpQueue @@ -142,7 +158,7 @@ func preHandleMetadata(metadata *C.Metadata) error { metadata.DNSMode = C.DNSFakeIP } else if node := resolver.DefaultHosts.Search(host); node != nil { // redir-host should lookup the hosts - metadata.DstIP = node.Data.(net.IP) + metadata.DstIP = node.Data.AsSlice() } } else if resolver.IsFakeIP(metadata.DstIP) { return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) @@ -281,14 +297,24 @@ func handleTCPConn(connCtx C.ConnContext) { return } + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) + defer cancel() + if MitmOutbound != nil && metadata.Type != C.MITM { + if remoteConn, err1 := MitmOutbound.DialContext(ctx, metadata); err1 == nil { + remoteConn = statistic.NewSniffing(remoteConn, metadata) + defer remoteConn.Close() + + handleSocket(connCtx, remoteConn) + return + } + } + proxy, rule, err := resolveMetadata(connCtx, metadata) if err != nil { log.Warnln("[Metadata] parse failed: %s", err.Error()) return } - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout) - defer cancel() remoteConn, err := proxy.DialContext(ctx, metadata.Pure()) if err != nil { if rule == nil { @@ -326,8 +352,7 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { var resolved bool if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { - ip := node.Data.(net.IP) - metadata.DstIP = ip + metadata.DstIP = node.Data.AsSlice() resolved = true }