Fix: wildcard certificates

This commit is contained in:
yaling888
2022-04-25 10:54:12 +08:00
parent 62bc75af8a
commit 7115f7e61b
2 changed files with 63 additions and 20 deletions

View File

@ -152,9 +152,11 @@ func (c *Config) NewTLSConfigForHost(hostname string) *tls.Config {
} }
func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certificate, error) { func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certificate, error) {
var leaf *x509.Certificate
tlsCertificate, ok := c.certsStorage.Get(hostname) tlsCertificate, ok := c.certsStorage.Get(hostname)
if ok { if ok {
if _, err := tlsCertificate.Leaf.Verify(x509.VerifyOptions{ leaf = tlsCertificate.Leaf
if _, err := leaf.Verify(x509.VerifyOptions{
DNSName: hostname, DNSName: hostname,
Roots: c.roots, Roots: c.roots,
}); err == nil { }); err == nil {
@ -165,27 +167,39 @@ func (c *Config) GetOrCreateCert(hostname string, ips ...net.IP) (*tls.Certifica
var ( var (
key = hostname key = hostname
topHost = hostname topHost = hostname
wildcardHost = "*." + hostname
dnsNames []string dnsNames []string
) )
if ip := net.ParseIP(hostname); ip != nil { if ip := net.ParseIP(hostname); ip != nil {
ips = append(ips, ip) ips = append(ips, ip)
} else { } else {
parts := strings.Split(topHost, ".") parts := strings.Split(hostname, ".")
l := len(parts) l := len(parts)
if l >= 2 { if leaf != nil {
for i := l - 2; i >= 0; i-- { dnsNames = append(dnsNames, leaf.DNSNames...)
topHost = strings.Join(parts[i:], ".")
dnsNames = append(dnsNames, topHost, "*."+topHost)
} }
topHost = strings.Join(parts[l-2:], ".") if l > 2 {
key = "+." + topHost topIndex := l - 2
} else { topHost = strings.Join(parts[topIndex:], ".")
dnsNames = append(dnsNames, topHost)
for i := topIndex; i > 0; i-- {
wildcardHost = "*." + strings.Join(parts[i:], ".")
if i == topIndex && (len(dnsNames) == 0 || dnsNames[0] != topHost) {
dnsNames = append(dnsNames, topHost, wildcardHost)
} else if !hasDnsNames(dnsNames, wildcardHost) {
dnsNames = append(dnsNames, wildcardHost)
} }
} }
} else {
dnsNames = append(dnsNames, topHost, wildcardHost)
}
key = "+." + topHost
}
serial := atomic.AddInt64(&currentSerialNumber, 1) serial := atomic.AddInt64(&currentSerialNumber, 1)
@ -279,3 +293,12 @@ func GenerateAndSave(caPath string, caKeyPath string) error {
return nil return nil
} }
func hasDnsNames(dnsNames []string, hostname string) bool {
for _, name := range dnsNames {
if name == hostname {
return true
}
}
return false
}

View File

@ -24,13 +24,13 @@ func TestCert(t *testing.T) {
c.SetValidity(20 * time.Hour) c.SetValidity(20 * time.Hour)
c.SetOrganization("Test Organization") c.SetOrganization("Test Organization")
conf := c.NewTLSConfigForHost("abc.example.org") conf := c.NewTLSConfigForHost("example.org")
assert.Equal(t, []string{"http/1.1"}, conf.NextProtos) assert.Equal(t, []string{"http/1.1"}, conf.NextProtos)
assert.True(t, conf.InsecureSkipVerify) assert.True(t, conf.InsecureSkipVerify)
// Test generating a certificate // Test generating a certificate
clientHello := &tls.ClientHelloInfo{ clientHello := &tls.ClientHelloInfo{
ServerName: "abc.example.org", ServerName: "example.org",
} }
tlsCert, err := conf.GetCertificate(clientHello) tlsCert, err := conf.GetCertificate(clientHello)
assert.Nil(t, err) assert.Nil(t, err)
@ -41,22 +41,41 @@ func TestCert(t *testing.T) {
assert.Equal(t, "example.org", x509c.Subject.CommonName) assert.Equal(t, "example.org", x509c.Subject.CommonName)
assert.Nil(t, x509c.VerifyHostname("example.org")) assert.Nil(t, x509c.VerifyHostname("example.org"))
assert.Nil(t, x509c.VerifyHostname("abc.example.org")) assert.Nil(t, x509c.VerifyHostname("abc.example.org"))
assert.Nil(t, x509c.VerifyHostname("efg.abc.example.org"))
assert.Equal(t, []string{"Test Organization"}, x509c.Subject.Organization) assert.Equal(t, []string{"Test Organization"}, x509c.Subject.Organization)
assert.NotNil(t, x509c.SubjectKeyId) assert.NotNil(t, x509c.SubjectKeyId)
assert.True(t, x509c.BasicConstraintsValid) assert.True(t, x509c.BasicConstraintsValid)
assert.True(t, x509c.KeyUsage&x509.KeyUsageKeyEncipherment == x509.KeyUsageKeyEncipherment) assert.True(t, x509c.KeyUsage&x509.KeyUsageKeyEncipherment == x509.KeyUsageKeyEncipherment)
assert.True(t, x509c.KeyUsage&x509.KeyUsageDigitalSignature == x509.KeyUsageDigitalSignature) assert.True(t, x509c.KeyUsage&x509.KeyUsageDigitalSignature == x509.KeyUsageDigitalSignature)
assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, x509c.ExtKeyUsage) assert.Equal(t, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, x509c.ExtKeyUsage)
assert.Equal(t, []string{"example.org", "*.example.org", "abc.example.org", "*.abc.example.org"}, x509c.DNSNames) assert.Equal(t, []string{"example.org", "*.example.org"}, x509c.DNSNames)
assert.True(t, x509c.NotBefore.Before(time.Now().Add(-2*time.Hour))) assert.True(t, x509c.NotBefore.Before(time.Now().Add(-2*time.Hour)))
assert.True(t, x509c.NotAfter.After(time.Now().Add(2*time.Hour))) assert.True(t, x509c.NotAfter.After(time.Now().Add(2*time.Hour)))
// Check that certificate is cached // Check that certificate is cached
tlsCert2, err := c.GetOrCreateCert("example.org") tlsCert2, err := c.GetOrCreateCert("abc.example.org")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, tlsCert == tlsCert2) assert.True(t, tlsCert == tlsCert2)
// Check that certificate is new
_, _ = c.GetOrCreateCert("a.b.c.d.e.f.g.h.i.j.example.org")
tlsCert3, err := c.GetOrCreateCert("m.k.l.example.org")
x509c = tlsCert3.Leaf
assert.Nil(t, err)
assert.False(t, tlsCert == tlsCert3)
assert.Equal(t, []string{"example.org", "*.example.org", "*.j.example.org", "*.i.j.example.org", "*.h.i.j.example.org", "*.g.h.i.j.example.org", "*.f.g.h.i.j.example.org", "*.e.f.g.h.i.j.example.org", "*.d.e.f.g.h.i.j.example.org", "*.c.d.e.f.g.h.i.j.example.org", "*.b.c.d.e.f.g.h.i.j.example.org", "*.l.example.org", "*.k.l.example.org"}, x509c.DNSNames)
// Check that certificate is cached
tlsCert4, err := c.GetOrCreateCert("xyz.example.org")
x509c = tlsCert4.Leaf
assert.Nil(t, err)
assert.True(t, tlsCert3 == tlsCert4)
assert.Nil(t, x509c.VerifyHostname("example.org"))
assert.Nil(t, x509c.VerifyHostname("jkf.example.org"))
assert.Nil(t, x509c.VerifyHostname("n.j.example.org"))
assert.Nil(t, x509c.VerifyHostname("c.i.j.example.org"))
assert.Nil(t, x509c.VerifyHostname("m.l.example.org"))
assert.Error(t, x509c.VerifyHostname("m.l.jkf.example.org"))
// Check the certificate for an IP // Check the certificate for an IP
tlsCertForIP, err := c.GetOrCreateCert("192.168.0.1") tlsCertForIP, err := c.GetOrCreateCert("192.168.0.1")
x509c = tlsCertForIP.Leaf x509c = tlsCertForIP.Leaf
@ -64,6 +83,7 @@ func TestCert(t *testing.T) {
assert.Equal(t, 1, len(x509c.IPAddresses)) assert.Equal(t, 1, len(x509c.IPAddresses))
assert.True(t, net.ParseIP("192.168.0.1").Equal(x509c.IPAddresses[0])) assert.True(t, net.ParseIP("192.168.0.1").Equal(x509c.IPAddresses[0]))
// Check that certificate is cached
tlsCertForIP2, err := c.GetOrCreateCert("192.168.0.1") tlsCertForIP2, err := c.GetOrCreateCert("192.168.0.1")
x509c = tlsCertForIP2.Leaf x509c = tlsCertForIP2.Leaf
assert.Nil(t, err) assert.Nil(t, err)