feat: plugins/common/tls/config.go: Filter client certificates by DNS names (#9910)
This commit is contained in:
parent
9d5eb7dd68
commit
76251d34f3
|
|
@ -31,6 +31,12 @@ The server TLS configuration provides support for TLS mutual authentication:
|
|||
## enable mutually authenticated TLS connections.
|
||||
# tls_allowed_cacerts = ["/etc/telegraf/clientca.pem"]
|
||||
|
||||
## Set one or more allowed DNS name to enable a whitelist
|
||||
## to verify incoming client certificates.
|
||||
## It will go through all available SAN in the certificate,
|
||||
## if of them matches the request is accepted.
|
||||
# tls_allowed_dns_names = ["client.example.org"]
|
||||
|
||||
## Add service certificate and key.
|
||||
# tls_cert = "/etc/telegraf/cert.pem"
|
||||
# tls_key = "/etc/telegraf/key.pem"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"github.com/influxdata/telegraf/internal/choice"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
|
@ -24,12 +25,13 @@ type ClientConfig struct {
|
|||
|
||||
// ServerConfig represents the standard server TLS config.
|
||||
type ServerConfig struct {
|
||||
TLSCert string `toml:"tls_cert"`
|
||||
TLSKey string `toml:"tls_key"`
|
||||
TLSAllowedCACerts []string `toml:"tls_allowed_cacerts"`
|
||||
TLSCipherSuites []string `toml:"tls_cipher_suites"`
|
||||
TLSMinVersion string `toml:"tls_min_version"`
|
||||
TLSMaxVersion string `toml:"tls_max_version"`
|
||||
TLSCert string `toml:"tls_cert"`
|
||||
TLSKey string `toml:"tls_key"`
|
||||
TLSAllowedCACerts []string `toml:"tls_allowed_cacerts"`
|
||||
TLSCipherSuites []string `toml:"tls_cipher_suites"`
|
||||
TLSMinVersion string `toml:"tls_min_version"`
|
||||
TLSMaxVersion string `toml:"tls_max_version"`
|
||||
TLSAllowedDNSNames []string `toml:"tls_allowed_dns_names"`
|
||||
}
|
||||
|
||||
// TLSConfig returns a tls.Config, may be nil without error if TLS is not
|
||||
|
|
@ -141,6 +143,12 @@ func (c *ServerConfig) TLSConfig() (*tls.Config, error) {
|
|||
"tls min version %q can't be greater than tls max version %q", tlsConfig.MinVersion, tlsConfig.MaxVersion)
|
||||
}
|
||||
|
||||
// Since clientAuth is tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
// there must be certs to validate.
|
||||
if len(c.TLSAllowedCACerts) > 0 && len(c.TLSAllowedDNSNames) > 0 {
|
||||
tlsConfig.VerifyPeerCertificate = c.verifyPeerCertificate
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
|
|
@ -152,8 +160,7 @@ func makeCertPool(certFiles []string) (*x509.CertPool, error) {
|
|||
return nil, fmt.Errorf(
|
||||
"could not read certificate %q: %v", certFile, err)
|
||||
}
|
||||
ok := pool.AppendCertsFromPEM(pem)
|
||||
if !ok {
|
||||
if !pool.AppendCertsFromPEM(pem) {
|
||||
return nil, fmt.Errorf(
|
||||
"could not parse any PEM certificates %q: %v", certFile, err)
|
||||
}
|
||||
|
|
@ -172,3 +179,20 @@ func loadCertificate(config *tls.Config, certFile, keyFile string) error {
|
|||
config.BuildNameToCertificate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ServerConfig) verifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
// The certificate chain is client + intermediate + root.
|
||||
// Let's review the client certificate.
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not validate peer certificate: %v", err)
|
||||
}
|
||||
|
||||
for _, name := range cert.DNSNames {
|
||||
if choice.Contains(name, c.TLSAllowedDNSNames) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("peer certificate not in allowed DNS Name list: %v", cert.DNSNames)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,12 +128,13 @@ func TestServerConfig(t *testing.T) {
|
|||
{
|
||||
name: "success",
|
||||
server: tls.ServerConfig{
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
TLSCipherSuites: []string{pki.CipherSuite()},
|
||||
TLSMinVersion: pki.TLSMinVersion(),
|
||||
TLSMaxVersion: pki.TLSMaxVersion(),
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
TLSCipherSuites: []string{pki.CipherSuite()},
|
||||
TLSAllowedDNSNames: []string{"localhost", "127.0.0.1"},
|
||||
TLSMinVersion: pki.TLSMinVersion(),
|
||||
TLSMaxVersion: pki.TLSMaxVersion(),
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -293,9 +294,10 @@ func TestConnect(t *testing.T) {
|
|||
}
|
||||
|
||||
serverConfig := tls.ServerConfig{
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
TLSAllowedDNSNames: []string{"localhost", "127.0.0.1"},
|
||||
}
|
||||
|
||||
serverTLSConfig, err := serverConfig.TLSConfig()
|
||||
|
|
@ -323,3 +325,46 @@ func TestConnect(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestConnectWrongDNS(t *testing.T) {
|
||||
clientConfig := tls.ClientConfig{
|
||||
TLSCA: pki.CACertPath(),
|
||||
TLSCert: pki.ClientCertPath(),
|
||||
TLSKey: pki.ClientKeyPath(),
|
||||
}
|
||||
|
||||
serverConfig := tls.ServerConfig{
|
||||
TLSCert: pki.ServerCertPath(),
|
||||
TLSKey: pki.ServerKeyPath(),
|
||||
TLSAllowedCACerts: []string{pki.CACertPath()},
|
||||
TLSAllowedDNSNames: []string{"localhos", "127.0.0.2"},
|
||||
}
|
||||
|
||||
serverTLSConfig, err := serverConfig.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
ts.TLS = serverTLSConfig
|
||||
|
||||
ts.StartTLS()
|
||||
defer ts.Close()
|
||||
|
||||
clientTLSConfig, err := clientConfig.TLSConfig()
|
||||
require.NoError(t, err)
|
||||
|
||||
client := http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: clientTLSConfig,
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Get(ts.URL)
|
||||
require.Error(t, err)
|
||||
if resp != nil {
|
||||
err = resp.Body.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue