diff --git a/common/cert/cert.go b/common/cert/cert.go index f8e826e2..274ef55e 100644 --- a/common/cert/cert.go +++ b/common/cert/cert.go @@ -152,9 +152,11 @@ func (c *Config) NewTLSConfigForHost(hostname string) *tls.Config { } func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certificate, error) { + var leaf *x509.Certificate tlsCertificate, ok := c.certsStorage.Get(hostname) if ok { - if _, err := tlsCertificate.Leaf.Verify(x509.VerifyOptions{ + leaf = tlsCertificate.Leaf + if _, err := leaf.Verify(x509.VerifyOptions{ DNSName: hostname, Roots: c.roots, }); err == nil { @@ -163,28 +165,40 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica } var ( - key = hostname - topHost = hostname - dnsNames []string + key = hostname + topHost = hostname + wildcardHost = "*." + hostname + dnsNames []string ) if ip := net.ParseIP(hostname); ip != nil { ips = append(ips, ip) } else { - parts := strings.Split(topHost, ".") + parts := strings.Split(hostname, ".") l := len(parts) - if l >= 2 { - for i := l - 2; i >= 0; i-- { - topHost = strings.Join(parts[i:], ".") - dnsNames = append(dnsNames, topHost, "*."+topHost) - } - - topHost = strings.Join(parts[l-2:], ".") - key = "+." + topHost - } else { - dnsNames = append(dnsNames, topHost) + if leaf != nil { + dnsNames = append(dnsNames, leaf.DNSNames...) } + + if l > 2 { + topIndex := l - 2 + topHost = strings.Join(parts[topIndex:], ".") + + for i := topIndex; i > 0; i-- { + wildcardHost = "*." + strings.Join(parts[i:], ".") + + if i == topIndex && (len(dnsNames) == 0 || dnsNames[0] != topHost) { + dnsNames = append(dnsNames, topHost, wildcardHost) + } else if !hasDnsNames(dnsNames, wildcardHost) { + dnsNames = append(dnsNames, wildcardHost) + } + } + } else { + dnsNames = append(dnsNames, topHost, wildcardHost) + } + + key = "+." + topHost } serial := atomic.AddInt64(¤tSerialNumber, 1) @@ -279,3 +293,12 @@ func GenerateAndSave(caPath string, caKeyPath string) error { return nil } + +func hasDnsNames(dnsNames []string, hostname string) bool { + for _, name := range dnsNames { + if name == hostname { + return true + } + } + return false +} diff --git a/common/cert/cert_test.go b/common/cert/cert_test.go index de8202d9..c237c588 100644 --- a/common/cert/cert_test.go +++ b/common/cert/cert_test.go @@ -24,13 +24,13 @@ func TestCert(t *testing.T) { c.SetValidity(20 * time.Hour) c.SetOrganization("Test Organization") - conf := c.NewTLSConfigForHost("abc.example.org") + conf := c.NewTLSConfigForHost("example.org") assert.Equal(t, []string{"http/1.1"}, conf.NextProtos) assert.True(t, conf.InsecureSkipVerify) // Test generating a certificate clientHello := &tls.ClientHelloInfo{ - ServerName: "abc.example.org", + ServerName: "example.org", } tlsCert, err := conf.GetCertificate(clientHello) assert.Nil(t, err) @@ -41,22 +41,41 @@ func TestCert(t *testing.T) { assert.Equal(t, "example.org", x509c.Subject.CommonName) assert.Nil(t, x509c.VerifyHostname("example.org")) assert.Nil(t, x509c.VerifyHostname("abc.example.org")) - assert.Nil(t, x509c.VerifyHostname("efg.abc.example.org")) assert.Equal(t, []string{"Test Organization"}, x509c.Subject.Organization) assert.NotNil(t, x509c.SubjectKeyId) assert.True(t, x509c.BasicConstraintsValid) assert.True(t, x509c.KeyUsage&x509.KeyUsageKeyEncipherment == x509.KeyUsageKeyEncipherment) assert.True(t, x509c.KeyUsage&x509.KeyUsageDigitalSignature == x509.KeyUsageDigitalSignature) assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, x509c.ExtKeyUsage) - assert.Equal(t, []string{"example.org", "*.example.org", "abc.example.org", "*.abc.example.org"}, x509c.DNSNames) + assert.Equal(t, []string{"example.org", "*.example.org"}, x509c.DNSNames) assert.True(t, x509c.NotBefore.Before(time.Now().Add(-2*time.Hour))) assert.True(t, x509c.NotAfter.After(time.Now().Add(2*time.Hour))) // Check that certificate is cached - tlsCert2, err := c.GetOrCreateCert("example.org") + tlsCert2, err := c.GetOrCreateCert("abc.example.org") assert.Nil(t, err) assert.True(t, tlsCert == tlsCert2) + // Check that certificate is new + _, _ = c.GetOrCreateCert("a.b.c.d.e.f.g.h.i.j.example.org") + tlsCert3, err := c.GetOrCreateCert("m.k.l.example.org") + x509c = tlsCert3.Leaf + assert.Nil(t, err) + assert.False(t, tlsCert == tlsCert3) + assert.Equal(t, []string{"example.org", "*.example.org", "*.j.example.org", "*.i.j.example.org", "*.h.i.j.example.org", "*.g.h.i.j.example.org", "*.f.g.h.i.j.example.org", "*.e.f.g.h.i.j.example.org", "*.d.e.f.g.h.i.j.example.org", "*.c.d.e.f.g.h.i.j.example.org", "*.b.c.d.e.f.g.h.i.j.example.org", "*.l.example.org", "*.k.l.example.org"}, x509c.DNSNames) + + // Check that certificate is cached + tlsCert4, err := c.GetOrCreateCert("xyz.example.org") + x509c = tlsCert4.Leaf + assert.Nil(t, err) + assert.True(t, tlsCert3 == tlsCert4) + assert.Nil(t, x509c.VerifyHostname("example.org")) + assert.Nil(t, x509c.VerifyHostname("jkf.example.org")) + assert.Nil(t, x509c.VerifyHostname("n.j.example.org")) + assert.Nil(t, x509c.VerifyHostname("c.i.j.example.org")) + assert.Nil(t, x509c.VerifyHostname("m.l.example.org")) + assert.Error(t, x509c.VerifyHostname("m.l.jkf.example.org")) + // Check the certificate for an IP tlsCertForIP, err := c.GetOrCreateCert("192.168.0.1") x509c = tlsCertForIP.Leaf @@ -64,6 +83,7 @@ func TestCert(t *testing.T) { assert.Equal(t, 1, len(x509c.IPAddresses)) assert.True(t, net.ParseIP("192.168.0.1").Equal(x509c.IPAddresses[0])) + // Check that certificate is cached tlsCertForIP2, err := c.GetOrCreateCert("192.168.0.1") x509c = tlsCertForIP2.Leaf assert.Nil(t, err)