From 62bc75af8ae01613929699410b0eae2958792d26 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Mon, 25 Apr 2022 05:02:24 +0800 Subject: [PATCH] Chore: signature wildcard certificates --- common/cert/cert.go | 65 ++++++++++++++++++------------------- common/cert/cert_test.go | 20 ++++++++---- common/cert/storage.go | 30 ++++++++--------- listener/listener.go | 3 +- listener/mitm/client.go | 8 +++-- listener/mitm/proxy.go | 27 +++++---------- listener/mitm/session.go | 3 ++ listener/tun/tun_adapter.go | 1 + 8 files changed, 80 insertions(+), 77 deletions(-) diff --git a/common/cert/cert.go b/common/cert/cert.go index 3c931665..f8e826e2 100644 --- a/common/cert/cert.go +++ b/common/cert/cert.go @@ -11,6 +11,7 @@ import ( "math/big" "net" "os" + "strings" "sync/atomic" "time" ) @@ -38,19 +39,6 @@ type CertsStorage interface { Set(key string, cert *tls.Certificate) } -type CertsCache struct { - certsCache map[string]*tls.Certificate -} - -func (c *CertsCache) Get(key string) (*tls.Certificate, bool) { - v, ok := c.certsCache[key] - return v, ok -} - -func (c *CertsCache) Set(key string, cert *tls.Certificate) { - c.certsCache[key] = cert -} - func NewAuthority(name, organization string, validity time.Duration) (*x509.Certificate, *rsa.PrivateKey, error) { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { @@ -100,7 +88,7 @@ func NewAuthority(name, organization string, validity time.Duration) (*x509.Cert return x509c, privateKey, nil } -func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey, storage CertsStorage) (*Config, error) { +func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*Config, error) { roots := x509.NewCertPool() roots.AddCert(ca) @@ -121,10 +109,6 @@ func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey, storage Certs } keyID := h.Sum(nil) - if storage == nil { - storage = &CertsCache{certsCache: make(map[string]*tls.Certificate)} - } - return &Config{ ca: ca, caPrivateKey: caPrivateKey, @@ -132,7 +116,7 @@ func NewConfig(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey, storage Certs keyID: keyID, validity: time.Hour, organization: "Clash", - certsStorage: storage, + certsStorage: NewDomainTrieCertsStorage(), roots: roots, }, nil } @@ -168,14 +152,9 @@ func (c *Config) NewTLSConfigForHost(hostname string) *tls.Config { } func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certificate, error) { - host, _, err := net.SplitHostPort(hostname) - if err == nil { - hostname = host - } - tlsCertificate, ok := c.certsStorage.Get(hostname) if ok { - if _, err = tlsCertificate.Leaf.Verify(x509.VerifyOptions{ + if _, err := tlsCertificate.Leaf.Verify(x509.VerifyOptions{ DNSName: hostname, Roots: c.roots, }); err == nil { @@ -183,12 +162,37 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica } } + var ( + key = hostname + topHost = hostname + dnsNames []string + ) + + if ip := net.ParseIP(hostname); ip != nil { + ips = append(ips, ip) + } else { + parts := strings.Split(topHost, ".") + l := len(parts) + + if l >= 2 { + for i := l - 2; i >= 0; i-- { + topHost = strings.Join(parts[i:], ".") + dnsNames = append(dnsNames, topHost, "*."+topHost) + } + + topHost = strings.Join(parts[l-2:], ".") + key = "+." + topHost + } else { + dnsNames = append(dnsNames, topHost) + } + } + serial := atomic.AddInt64(¤tSerialNumber, 1) tmpl := &x509.Certificate{ SerialNumber: big.NewInt(serial), Subject: pkix.Name{ - CommonName: hostname, + CommonName: topHost, Organization: []string{c.organization}, }, SubjectKeyId: c.keyID, @@ -199,12 +203,7 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica NotAfter: time.Now().Add(c.validity), } - if ip := net.ParseIP(hostname); ip != nil { - ips = append(ips, ip) - } else { - tmpl.DNSNames = []string{hostname} - } - + tmpl.DNSNames = dnsNames tmpl.IPAddresses = ips raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.privateKey.Public(), c.caPrivateKey) @@ -223,7 +222,7 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica Leaf: x509c, } - c.certsStorage.Set(hostname, tlsCertificate) + c.certsStorage.Set(key, tlsCertificate) return tlsCertificate, nil } diff --git a/common/cert/cert_test.go b/common/cert/cert_test.go index 42265613..de8202d9 100644 --- a/common/cert/cert_test.go +++ b/common/cert/cert_test.go @@ -18,19 +18,19 @@ func TestCert(t *testing.T) { assert.NotNil(t, ca) assert.NotNil(t, privateKey) - c, err := NewConfig(ca, privateKey, nil) + c, err := NewConfig(ca, privateKey) assert.Nil(t, err) c.SetValidity(20 * time.Hour) c.SetOrganization("Test Organization") - conf := c.NewTLSConfigForHost("example.org") + conf := c.NewTLSConfigForHost("abc.example.org") assert.Equal(t, []string{"http/1.1"}, conf.NextProtos) assert.True(t, conf.InsecureSkipVerify) // Test generating a certificate clientHello := &tls.ClientHelloInfo{ - ServerName: "example.org", + ServerName: "abc.example.org", } tlsCert, err := conf.GetCertificate(clientHello) assert.Nil(t, err) @@ -40,13 +40,15 @@ func TestCert(t *testing.T) { x509c := tlsCert.Leaf assert.Equal(t, "example.org", x509c.Subject.CommonName) assert.Nil(t, x509c.VerifyHostname("example.org")) + assert.Nil(t, x509c.VerifyHostname("abc.example.org")) + assert.Nil(t, x509c.VerifyHostname("efg.abc.example.org")) assert.Equal(t, []string{"Test Organization"}, x509c.Subject.Organization) assert.NotNil(t, x509c.SubjectKeyId) assert.True(t, x509c.BasicConstraintsValid) assert.True(t, x509c.KeyUsage&x509.KeyUsageKeyEncipherment == x509.KeyUsageKeyEncipherment) assert.True(t, x509c.KeyUsage&x509.KeyUsageDigitalSignature == x509.KeyUsageDigitalSignature) assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, x509c.ExtKeyUsage) - assert.Equal(t, []string{"example.org"}, x509c.DNSNames) + assert.Equal(t, []string{"example.org", "*.example.org", "abc.example.org", "*.abc.example.org"}, x509c.DNSNames) assert.True(t, x509c.NotBefore.Before(time.Now().Add(-2*time.Hour))) assert.True(t, x509c.NotAfter.After(time.Now().Add(2*time.Hour))) @@ -56,11 +58,17 @@ func TestCert(t *testing.T) { assert.True(t, tlsCert == tlsCert2) // Check the certificate for an IP - tlsCertForIP, err := c.GetOrCreateCert("192.168.0.1:443") - assert.Nil(t, err) + tlsCertForIP, err := c.GetOrCreateCert("192.168.0.1") x509c = tlsCertForIP.Leaf + assert.Nil(t, err) assert.Equal(t, 1, len(x509c.IPAddresses)) assert.True(t, net.ParseIP("192.168.0.1").Equal(x509c.IPAddresses[0])) + + tlsCertForIP2, err := c.GetOrCreateCert("192.168.0.1") + x509c = tlsCertForIP2.Leaf + assert.Nil(t, err) + assert.True(t, tlsCertForIP == tlsCertForIP2) + assert.Nil(t, x509c.VerifyHostname("192.168.0.1")) } func TestGenerateAndSave(t *testing.T) { diff --git a/common/cert/storage.go b/common/cert/storage.go index 61663e73..a55d065c 100644 --- a/common/cert/storage.go +++ b/common/cert/storage.go @@ -2,31 +2,31 @@ package cert import ( "crypto/tls" - "time" - "github.com/Dreamacro/clash/common/cache" + "github.com/Dreamacro/clash/component/trie" ) -var TTL = time.Hour * 2 - -// AutoGCCertsStorage cache with the generated certificates, auto released after TTL -type AutoGCCertsStorage struct { - certsCache *cache.Cache[string, *tls.Certificate] +// DomainTrieCertsStorage cache wildcard certificates +type DomainTrieCertsStorage struct { + certsCache *trie.DomainTrie[*tls.Certificate] } // Get gets the certificate from the storage -func (c *AutoGCCertsStorage) Get(key string) (*tls.Certificate, bool) { - ca := c.certsCache.Get(key) - return ca, ca != nil +func (c *DomainTrieCertsStorage) Get(key string) (*tls.Certificate, bool) { + ca := c.certsCache.Search(key) + if ca == nil { + return nil, false + } + return ca.Data, true } // Set saves the certificate to the storage -func (c *AutoGCCertsStorage) Set(key string, cert *tls.Certificate) { - c.certsCache.Put(key, cert, TTL) +func (c *DomainTrieCertsStorage) Set(key string, cert *tls.Certificate) { + _ = c.certsCache.Insert(key, cert) } -func NewAutoGCCertsStorage() *AutoGCCertsStorage { - return &AutoGCCertsStorage{ - certsCache: cache.New[string, *tls.Certificate](TTL), +func NewDomainTrieCertsStorage() *DomainTrieCertsStorage { + return &DomainTrieCertsStorage{ + certsCache: trie.New[*tls.Certificate](), } } diff --git a/listener/listener.go b/listener/listener.go index f75a282f..bda55e71 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -395,13 +395,12 @@ func ReCreateMitm(port int, tcpIn chan<- C.ConnContext) { certOption, err = cert.NewConfig( x509c, privateKey, - cert.NewAutoGCCertsStorage(), ) if err != nil { return } - certOption.SetValidity(time.Hour * 24 * 90) + certOption.SetValidity(time.Hour * 24 * 365 * 2) // 2 years certOption.SetOrganization("Clash ManInTheMiddle Proxy Services") opt := &mitm.Option{ diff --git a/listener/mitm/client.go b/listener/mitm/client.go index b20d8586..a01c65d8 100644 --- a/listener/mitm/client.go +++ b/listener/mitm/client.go @@ -18,9 +18,11 @@ func newClient(source net.Addr, userAgent string, in chan<- C.ConnContext) *http Transport: &http.Transport{ // excepted HTTP/2 TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), - // from http.DefaultTransport - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, + // only needed 1 connection + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + IdleConnTimeout: 60 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: &tls.Config{ diff --git a/listener/mitm/proxy.go b/listener/mitm/proxy.go index 98323e9a..eb8876c0 100644 --- a/listener/mitm/proxy.go +++ b/listener/mitm/proxy.go @@ -44,7 +44,7 @@ startOver: readLoop: for { // use SetReadDeadline instead of Proxy-Connection keep-alive - if err := conn.SetReadDeadline(time.Now().Add(95 * time.Second)); err != nil { + if err := conn.SetReadDeadline(time.Now().Add(65 * time.Second)); err != nil { break readLoop } @@ -86,7 +86,7 @@ readLoop: // TLS handshake. if b[0] == 0x16 { - tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Host)) + tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.request.URL.Hostname())) // Handshake with the local client if err = tlsConn.Handshake(); err != nil { @@ -167,13 +167,7 @@ readLoop: func writeResponseWithHandler(session *Session, opt *Option) error { if opt.Handler != nil { res := opt.Handler.HandleResponse(session) - if res != nil { - body := res.Body - defer func(body io.ReadCloser) { - _ = body.Close() - }(body) - session.response = res } } @@ -186,7 +180,7 @@ func writeResponse(session *Session, keepAlive bool) error { if keepAlive { session.response.Header.Set("Connection", "keep-alive") - session.response.Header.Set("Keep-Alive", "timeout=90") + session.response.Header.Set("Keep-Alive", "timeout=60") } return session.writeResponse() @@ -201,10 +195,6 @@ func handleApiRequest(session *Session, opt *Option) error { session.response = session.NewResponse(http.StatusOK, bytes.NewReader(b)) - defer func(body io.ReadCloser) { - _ = body.Close() - }(session.response.Body) - session.response.Close = true session.response.Header.Set("Content-Type", "application/x-x509-ca-cert") session.response.ContentLength = int64(len(b)) @@ -230,11 +220,6 @@ func handleApiRequest(session *Session, opt *Option) error { b = fmt.Sprintf(b, session.request.URL.Path) session.response = session.NewResponse(http.StatusNotFound, bytes.NewReader([]byte(b))) - - defer func(body io.ReadCloser) { - _ = body.Close() - }(session.response.Body) - session.response.Close = true session.response.Header.Set("Content-Type", "text/html;charset=utf-8") session.response.ContentLength = int64(len(b)) @@ -243,6 +228,12 @@ func handleApiRequest(session *Session, opt *Option) error { } func handleError(opt *Option, session *Session, err error) { + if session.response != nil { + defer func() { + _, _ = io.Copy(io.Discard, session.response.Body) + _ = session.response.Body.Close() + }() + } if opt.Handler != nil { opt.Handler.HandleError(session, err) } diff --git a/listener/mitm/session.go b/listener/mitm/session.go index c2622a69..99979a98 100644 --- a/listener/mitm/session.go +++ b/listener/mitm/session.go @@ -43,6 +43,9 @@ func (s *Session) writeResponse() error { if s.response == nil { return ErrInvalidResponse } + defer func(resp *http.Response) { + _ = resp.Body.Close() + }(s.response) return s.response.Write(s.conn) } diff --git a/listener/tun/tun_adapter.go b/listener/tun/tun_adapter.go index 7461a492..98b0467c 100644 --- a/listener/tun/tun_adapter.go +++ b/listener/tun/tun_adapter.go @@ -145,6 +145,7 @@ func setAtLatest(stackType C.TUNStack, devName string) { case "darwin": // _, _ = cmd.ExecCmd("sysctl -w net.inet.ip.forwarding=1") // _, _ = cmd.ExecCmd("sysctl -w net.inet6.ip6.forwarding=1") + _, _ = cmd.ExecCmd("sudo launchctl limit maxfiles 10240 unlimited") case "windows": _, _ = cmd.ExecCmd("ipconfig /renew") case "linux":