From add8eef1e3962517aca506668c272cae07c34b1c Mon Sep 17 00:00:00 2001 From: Kevin Joiner <10265309+KevinJoiner@users.noreply.github.com> Date: Tue, 24 Sep 2024 09:27:59 -0400 Subject: [PATCH 1/2] Adds new tls options specifically for servers --- internal/tls/docs.go | 67 +++++++++----- internal/tls/type.go | 38 ++++++++ internal/tls/type_test.go | 163 ++++++++++++++++++++++++++++++++++ public/service/config_test.go | 24 +++++ public/service/config_tls.go | 25 +++++- 5 files changed, 295 insertions(+), 22 deletions(-) diff --git a/internal/tls/docs.go b/internal/tls/docs.go index 027f9f6df..ada739a33 100644 --- a/internal/tls/docs.go +++ b/internal/tls/docs.go @@ -1,9 +1,31 @@ package tls -import "github.com/redpanda-data/benthos/v4/internal/docs" +import ( + "github.com/redpanda-data/benthos/v4/internal/docs" +) // FieldSpec returns a spec for a common TLS field. func FieldSpec() docs.FieldSpec { + return baseFieldSpec().WithChildren( + baseCertType("client_certs", "A list of client certificates to use. For each certificate either the fields `cert` and `key`, or `cert_file` and `key_file` should be specified, but not both."), + ) +} + +// ServerFieldSpec returns a spec for a common TLS field used by a server. +func ServerFieldSpec() docs.FieldSpec { + return baseFieldSpec().WithChildren( + baseCertType("server_certs", "A list of server certificates to use. For each certificate either the fields `cert` and `key`, or `cert_file` and `key_file` should be specified, but not both."), + docs.FieldBool("require_mutual_tls", "Whether to require mutual TLS authentication. When enabled the server will require a client certificate to be presented during the handshake.").HasDefault(false), + docs.FieldString( + "client_root_cas", "An optional root certificate authority to use for checking client certs when mTLS is required. This is a string, representing a certificate chain from the parent trusted root certificate, to possible intermediate signing certificates, to the host certificate.", "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----", + ).HasDefault("").Secret(), + docs.FieldString( + "client_root_cas_file", "An optional path of a root certificate authority file to use for checking client certs when mTLS is required. This is a file, often with a .pem extension, containing a certificate chain from the parent trusted root certificate, to possible intermediate signing certificates, to the host certificate.", "./root_cas.pem", + ).HasDefault(""), + ) +} + +func baseFieldSpec() docs.FieldSpec { return docs.FieldObject( "tls", "Custom TLS settings can be used to override system defaults.", ).WithChildren( @@ -26,30 +48,33 @@ func FieldSpec() docs.FieldSpec { docs.FieldString( "root_cas_file", "An optional path of a root certificate authority file to use. This is a file, often with a .pem extension, containing a certificate chain from the parent trusted root certificate, to possible intermediate signing certificates, to the host certificate.", "./root_cas.pem", ).HasDefault(""), + ).Advanced() +} - docs.FieldObject( - "client_certs", "A list of client certificates to use. For each certificate either the fields `cert` and `key`, or `cert_file` and `key_file` should be specified, but not both.", - []any{ - map[string]any{ - "cert": "foo", - "key": "bar", - }, +func baseCertType(name, description string) docs.FieldSpec { + return docs.FieldObject( + name, description, + []any{ + map[string]any{ + "cert": "foo", + "key": "bar", }, - []any{ - map[string]any{ - "cert_file": "./example.pem", - "key_file": "./example.key", - }, + }, + []any{ + map[string]any{ + "cert_file": "./example.pem", + "key_file": "./example.key", }, - ).Array().WithChildren( - docs.FieldString("cert", "A plain text certificate to use.").HasDefault(""), - docs.FieldString("key", "A plain text certificate key to use.").HasDefault("").Secret(), - docs.FieldString("cert_file", "The path of a certificate to use.").HasDefault(""), - docs.FieldString("key_file", "The path of a certificate key to use.").HasDefault(""), - docs.FieldString("password", `A plain text password for when the private key is password encrypted in PKCS#1 or PKCS#8 format. The obsolete `+"`pbeWithMD5AndDES-CBC`"+` algorithm is not supported for the PKCS#8 format. + }, + ).Array().WithChildren( + docs.FieldString("cert", "A plain text certificate to use.").HasDefault(""), + docs.FieldString("key", "A plain text certificate key to use.").HasDefault("").Secret(), + docs.FieldString("cert_file", "The path of a certificate to use.").HasDefault(""), + docs.FieldString("key_file", "The path of a certificate key to use.").HasDefault(""), + docs.FieldString("password", `A plain text password for when the private key is password encrypted in PKCS#1 or PKCS#8 format. The obsolete `+"`pbeWithMD5AndDES-CBC`"+` algorithm is not supported for the PKCS#8 format. Because the obsolete pbeWithMD5AndDES-CBC algorithm does not authenticate the ciphertext, it is vulnerable to padding oracle attacks that can let an attacker recover the plaintext. `, "foo", "${KEY_PASSWORD}").HasDefault("").Secret(), - ).HasDefault([]any{}), - ).Advanced() + ).HasDefault([]any{}) + } diff --git a/internal/tls/type.go b/internal/tls/type.go index 70b660d3c..a9071bdcb 100644 --- a/internal/tls/type.go +++ b/internal/tls/type.go @@ -29,7 +29,11 @@ type Config struct { RootCAsFile string `json:"root_cas_file" yaml:"root_cas_file"` InsecureSkipVerify bool `json:"skip_cert_verify" yaml:"skip_cert_verify"` ClientCertificates []ClientCertConfig `json:"client_certs" yaml:"client_certs"` + ServerCertificates []ClientCertConfig `json:"server_certs" yaml:"server_certs"` EnableRenegotiation bool `json:"enable_renegotiation" yaml:"enable_renegotiation"` + RequireMutualTLS bool `json:"require_mutual_tls" yaml:"require_mutual_tls"` + ClientRootCAS string `json:"client_root_cas" yaml:"client_root_cas"` + ClientRootCASFile string `json:"client_root_cas_file" yaml:"client_root_cas_file"` } // NewConfig creates a new Config with default values. @@ -41,6 +45,7 @@ func NewConfig() Config { InsecureSkipVerify: false, ClientCertificates: []ClientCertConfig{}, EnableRenegotiation: false, + RequireMutualTLS: false, } } @@ -67,6 +72,9 @@ func (c *Config) GetNonToggled(f ifs.FS) (*tls.Config, error) { if c.RootCAs != "" && c.RootCAsFile != "" { return nil, errors.New("only one field between root_cas and root_cas_file can be specified") } + if c.ClientRootCAS != "" && c.ClientRootCASFile != "" { + return nil, errors.New("only one field between client_root_cas and client_root_cas_file can be specified") + } if c.RootCAsFile != "" { caCert, err := ifs.ReadFile(f, c.RootCAsFile) @@ -84,6 +92,22 @@ func (c *Config) GetNonToggled(f ifs.FS) (*tls.Config, error) { tlsConf.RootCAs.AppendCertsFromPEM([]byte(c.RootCAs)) } + if c.ClientRootCASFile != "" { + caCert, err := ifs.ReadFile(f, c.ClientRootCASFile) + if err != nil { + return nil, err + } + initConf() + tlsConf.ClientCAs = x509.NewCertPool() + tlsConf.ClientCAs.AppendCertsFromPEM(caCert) + } + + if c.ClientRootCAS != "" { + initConf() + tlsConf.ClientCAs = x509.NewCertPool() + tlsConf.ClientCAs.AppendCertsFromPEM([]byte(c.ClientRootCAS)) + } + for _, conf := range c.ClientCertificates { cert, err := conf.Load(f) if err != nil { @@ -93,6 +117,15 @@ func (c *Config) GetNonToggled(f ifs.FS) (*tls.Config, error) { tlsConf.Certificates = append(tlsConf.Certificates, cert) } + for _, conf := range c.ServerCertificates { + cert, err := conf.Load(f) + if err != nil { + return nil, err + } + initConf() + tlsConf.Certificates = append(tlsConf.Certificates, cert) + } + if c.EnableRenegotiation { initConf() tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient @@ -103,6 +136,11 @@ func (c *Config) GetNonToggled(f ifs.FS) (*tls.Config, error) { tlsConf.InsecureSkipVerify = true } + if c.RequireMutualTLS { + initConf() + tlsConf.ClientAuth = tls.RequireAndVerifyClientCert + } + return tlsConf, nil } diff --git a/internal/tls/type_test.go b/internal/tls/type_test.go index aeb5a7142..5fd3128a3 100644 --- a/internal/tls/type_test.go +++ b/internal/tls/type_test.go @@ -6,9 +6,15 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "fmt" + "io" "log" "math/big" + "net" + "net/http" + "net/http/httptest" "os" + "strings" "testing" "time" @@ -52,6 +58,67 @@ func createCertificates() (certPem, keyPem []byte) { return certPem, keyPem } +// CreateSignedCertificate generates a certificate signed by a CA. +func CreateSignedCertificate(caCert *x509.Certificate, caKey *rsa.PrivateKey, ipAddress string) (certPEM, keyPEM []byte, err error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + + tml := x509.Certificate{ + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(2, 0, 0), + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{ + CommonName: "localhost", + Organization: []string{"Benthos"}, + }, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP(ipAddress)}, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &tml, caCert, &key.PublicKey, caKey) + if err != nil { + return nil, nil, err + } + + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + return certPEM, keyPEM, nil +} + +// CreateCACertificate generates a CA certificate. +func CreateCACertificate() (caCertPEM, caKeyPEM []byte, caCert *x509.Certificate, caKey *rsa.PrivateKey, err error) { + caKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, nil, nil, err + } + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Benthos CA"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour), // Valid for 10 years + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + IsCA: true, + BasicConstraintsValid: true, + } + + caCertBytes, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + if err != nil { + return nil, nil, nil, nil, err + } + + caCertPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertBytes}) + caKeyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(caKey)}) + + caCert, err = x509.ParseCertificate(caCertBytes) + return caCertPEM, caKeyPEM, caCert, caKey, err +} + type keyPair struct { cert []byte key []byte @@ -260,3 +327,99 @@ func TestCertificateWithNoEncryption(t *testing.T) { t.Errorf("Failed to load certificate %s", err) } } + +func TestRequireMutualTLS(t *testing.T) { + // First create a test server so we can get its IP address for the certificate. + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + serverAddr := server.Listener.Addr().String() + serverIP := strings.Split(serverAddr, ":")[0] + + // Generate CA certificate. + caCertPem, _, caCert, caKey, err := CreateCACertificate() + require.NoError(t, err) + + // Generate server certificate signed by the CA. + serverCert, serverKey, err := CreateSignedCertificate(caCert, caKey, serverIP) + require.NoError(t, err) + + // Setup the server configuration with the server certificate and the CA root. + serverConfig := Config{ + Enabled: true, + ClientRootCAS: string(caCertPem), + RequireMutualTLS: true, + ClientCertificates: []ClientCertConfig{{Cert: string(serverCert), Key: string(serverKey)}}, + } + + // Get the server TLS configuration. + serverTLSConfig, err := serverConfig.GetNonToggled(nil) + require.NoError(t, err) + + // Set the test server's TLS configuration and start it. + server.TLS = serverTLSConfig + server.StartTLS() + defer server.Close() + + // --------------------------------------------------------------------------------------------------- + // Setup the client without a client certificate (to test rejection). + clientConfigWithoutCert := Config{ + Enabled: true, + RootCAs: string(caCertPem), // Use the CA certificate. + ClientCertificates: []ClientCertConfig{}, + } + + // Get the client TLS configuration for the client without a certificate. + clientWithoutTLSConfig, err := clientConfigWithoutCert.GetNonToggled(nil) + require.NoError(t, err) + + // Create a client without a client certificate. + clientWithoutCert := &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: clientWithoutTLSConfig, + }, + } + + // Attempt to connect without a client certificate. + _, err = clientWithoutCert.Get(server.URL) + require.Error(t, err, "Expected error when client does not provide a certificate") + + // --------------------------------------------------------------------------------------------------- + // Setup the client with a client certificate + + // Generate client certificate signed by the CA. + clientCert, clientKey, err := CreateSignedCertificate(caCert, caKey, serverIP) + require.NoError(t, err) + + // Setup the client configuration with the client certificate and the same CA root. + clientConfig := Config{ + Enabled: true, + RootCAs: string(caCertPem), + ClientCertificates: []ClientCertConfig{{Cert: string(clientCert), Key: string(clientKey)}}, + } + + // Get the client TLS configuration with client cert. + clientTLSConfig, err := clientConfig.GetNonToggled(nil) + require.NoError(t, err) + + // Test connection with a client certificate (should succeed). + clientWithCert := &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: clientTLSConfig, + }, + } + + // Attempt to connect with a client certificate. + resp, err := clientWithCert.Get(server.URL) + require.NoError(t, err, "Expected no error when client provides a valid certificate") + + // Read and verify the response body. + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "Hello, client\n", string(body)) + + err = resp.Body.Close() + require.NoError(t, err) +} diff --git a/public/service/config_test.go b/public/service/config_test.go index a91141f1e..207cf13a1 100644 --- a/public/service/config_test.go +++ b/public/service/config_test.go @@ -338,6 +338,30 @@ b: and this assert.True(t, tConf.InsecureSkipVerify) } +func TestServerConfigTLS(t *testing.T) { + spec := NewConfigSpec(). + Field(NewServerTLSField("a")). + Field(NewStringField("b")) + + parsedConfig, err := spec.ParseYAML(` +a: + skip_cert_verify: true +b: and this +`, nil) + require.NoError(t, err) + + _, err = parsedConfig.FieldTLS("b") + require.Error(t, err) + + _, err = parsedConfig.FieldTLS("c") + require.Error(t, err) + + tConf, err := parsedConfig.FieldTLS("a") + require.NoError(t, err) + + assert.True(t, tConf.InsecureSkipVerify) +} + func TestConfigInterpolatedString(t *testing.T) { spec := NewConfigSpec(). Field(NewInterpolatedStringField("a")). diff --git a/public/service/config_tls.go b/public/service/config_tls.go index 867bc81fc..b394e8bf0 100644 --- a/public/service/config_tls.go +++ b/public/service/config_tls.go @@ -15,7 +15,17 @@ import ( // settings for networked components. It is then possible to extract a // *tls.Config from the resulting parsed config with the method FieldTLS. func NewTLSField(name string) *ConfigField { - tf := btls.FieldSpec() + return newTLSField(name, btls.FieldSpec()) +} + +// NewServerTLSField defines a new object type config field that describes TLS +// settings for server side networked components. It is then possible to extract a +// *tls.Config from the resulting parsed config with the method FieldTLS. +func NewServerTLSField(name string) *ConfigField { + return newTLSField(name, btls.ServerFieldSpec()) +} + +func newTLSField(name string, tf docs.FieldSpec) *ConfigField { tf.Name = name var newChildren []docs.FieldSpec for _, f := range tf.Children { @@ -62,6 +72,19 @@ func NewTLSToggledField(name string) *ConfigField { return &ConfigField{field: tf} } +// NewServerTLSToggledField defines a new object type config field that describes +// TLS settings for server side networked components. This field differs from a +// a standard TLSField as it includes a boolean field `enabled` which allows +// users to explicitly configure whether TLS should be enabled or not. +// +// A *tls.Config as well as an enabled boolean value can be extracted from the +// resulting parsed config with the method FieldTLSToggled. +func NewServerTLSToggledField(name string) *ConfigField { + tf := btls.ServerFieldSpec() + tf.Name = name + return &ConfigField{field: tf} +} + // FieldTLSToggled accesses a field from a parsed config that was defined with // NewTLSFieldToggled and returns a *tls.Config and a boolean flag indicating // whether tls is explicitly enabled, or an error if the configuration was From 195d023de9a50c4c1291fc3b3006fc0f6c167a95 Mon Sep 17 00:00:00 2001 From: Kevin Joiner <10265309+KevinJoiner@users.noreply.github.com> Date: Tue, 24 Sep 2024 15:01:19 -0400 Subject: [PATCH 2/2] Add TLS object to http_server --- internal/impl/io/input_http_server.go | 25 +++- internal/impl/io/input_http_server_test.go | 131 +++++++++++++++++++- internal/impl/io/output_http_server.go | 24 +++- internal/impl/io/output_http_server_test.go | 88 +++++++++++++ 4 files changed, 259 insertions(+), 9 deletions(-) diff --git a/internal/impl/io/input_http_server.go b/internal/impl/io/input_http_server.go index 8d3ca22e5..06d93efc9 100644 --- a/internal/impl/io/input_http_server.go +++ b/internal/impl/io/input_http_server.go @@ -57,6 +57,7 @@ const ( hsiFieldResponseStatus = "status" hsiFieldResponseHeaders = "headers" hsiFieldResponseExtractMetadata = "metadata_headers" + hsiFieldTLS = "tls" ) type hsiConfig struct { @@ -72,6 +73,7 @@ type hsiConfig struct { KeyFile string CORS httpserver.CORSConfig Response hsiResponseConfig + TLSConfig *tls.Config } type hsiResponseConfig struct { @@ -128,6 +130,13 @@ func hsiConfigFromParsed(pConf *service.ParsedConfig) (conf hsiConfig, err error if conf.Response, err = hsiResponseConfigFromParsed(pConf.Namespace(hsiFieldResponse)); err != nil { return } + tlsConf, enabled, err := pConf.FieldTLSToggled(hsiFieldTLS) + if err != nil { + return + } + if enabled { + conf.TLSConfig = tlsConf + } return } @@ -244,14 +253,17 @@ You can access these metadata fields using xref:configuration:interpolation.adoc service.NewStringField(hsiFieldRateLimit). Description("An optional xref:components:rate_limits/about.adoc[rate limit] to throttle requests by."). Default(""), + service.NewServerTLSToggledField(hsiFieldTLS), service.NewStringField(hsiFieldCertFile). Description("Enable TLS by specifying a certificate and key file. Only valid with a custom `address`."). Advanced(). - Default(""), + Default(""). + Deprecated(), service.NewStringField(hsiFieldKeyFile). Description("Enable TLS by specifying a certificate and key file. Only valid with a custom `address`."). Advanced(). - Default(""), + Default(""). + Deprecated(), service.NewInternalField(corsSpec), service.NewObjectField(hsiFieldResponse, service.NewInterpolatedStringField(hsiFieldResponseStatus). @@ -375,7 +387,10 @@ func newHTTPServerInput(conf hsiConfig, mgr bundle.NewManagement) (input.Streame var err error if conf.Address != "" { gMux = mux.NewRouter() - server = &http.Server{Addr: conf.Address} + server = &http.Server{ + Addr: conf.Address, + TLSConfig: conf.TLSConfig, + } if server.Handler, err = conf.CORS.WrapHandler(gMux); err != nil { return nil, fmt.Errorf("bad CORS configuration: %w", err) } @@ -877,11 +892,13 @@ func (h *httpServerInput) loop() { if h.server != nil { go func() { - if h.conf.KeyFile != "" || h.conf.CertFile != "" { + if h.conf.TLSConfig != nil || h.conf.KeyFile != "" || h.conf.CertFile != "" { h.log.Info( "Receiving HTTPS messages at: https://%s\n", h.conf.Address+h.conf.Path, ) + + // if TLSConfig.ClientCertificates are set and CertFile or KeyFile are not empty, the server will use the CertFile and KeyFile instead of the ClientCertificates. if err := h.server.ListenAndServeTLS( h.conf.CertFile, h.conf.KeyFile, ); err != http.ErrServerClosed { diff --git a/internal/impl/io/input_http_server_test.go b/internal/impl/io/input_http_server_test.go index bc98acd11..6bb0c27c6 100644 --- a/internal/impl/io/input_http_server_test.go +++ b/internal/impl/io/input_http_server_test.go @@ -3,9 +3,16 @@ package io_test import ( "bytes" "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "fmt" "io" + "math/big" "mime" "mime/multipart" "net" @@ -13,6 +20,7 @@ import ( "net/http/httptest" "net/textproto" "net/url" + "os" "sync" "testing" "time" @@ -1200,7 +1208,6 @@ http_server: Content-Type: application/json foo: '${!json("field1")}' `) - h, err := mgr.NewInput(conf) require.NoError(t, err) @@ -1336,3 +1343,125 @@ http_server: assert.Equal(t, "200 OK", resp.Status) assert.Equal(t, "foo", resp.Header.Get("Access-Control-Allow-Origin")) } + +func TestHTTPServerInputTLSParameters(t *testing.T) { + tCtx, done := context.WithTimeout(context.Background(), time.Minute) + defer done() + + freePort := getFreePort(t) + certFile, keyFile, caCert, err := createCertFiles() + require.NoError(t, err) + t.Cleanup(func() { + os.Remove(certFile.Name()) + os.Remove(keyFile.Name()) + }) + + conf := parseYAMLInputConf(t, ` +http_server: + address: 0.0.0.0:%v + path: /test/tls + allowed_verbs: [ POST ] + tls: + enabled: true + server_certs: + - cert_file: %s + key_file: %s +`, freePort, certFile.Name(), keyFile.Name()) + server, err := mock.NewManager().NewInput(conf) + require.NoError(t, err) + + defer func() { + server.TriggerStopConsuming() + assert.NoError(t, server.WaitForClose(tCtx)) + }() + + rootCA := x509.NewCertPool() + rootCA.AddCert(caCert) + httpClient := http.DefaultClient + httpClient.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: rootCA, + }, + } + var resp *http.Response + inputData := "a bunch of jolly leprechauns await" + go func() { + require.Eventually(t, func() (succeeded bool) { + req, cerr := http.NewRequest(http.MethodPost, fmt.Sprintf("https://localhost:%v/test/tls", freePort), bytes.NewBufferString(inputData)) + require.NoError(t, cerr) + req.Header.Set("Content-Type", "text/plain") + if resp, cerr = httpClient.Do(req); cerr == nil { + succeeded = true + assert.Equal(t, "200 OK", resp.Status) + resp.Body.Close() + } + return + }, time.Second, 50*time.Millisecond) + }() + + readNextMsg := func() (message.Batch, error) { + var tran message.Transaction + select { + case tran = <-server.TransactionChan(): + require.NoError(t, tran.Ack(tCtx, nil)) + case <-time.After(time.Second): + return nil, errors.New("timed out") + } + return tran.Payload, nil + } + + msg, err := readNextMsg() + require.NoError(t, err) + assert.Equal(t, inputData, string(message.GetAllBytes(msg)[0])) +} + +// createCACertificate generates a CA certificate. +func createCertFiles() (certFile, keyFile *os.File, caCert *x509.Certificate, err error) { + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, nil, err + } + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "localhost"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + IsCA: true, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + } + + caCertBytes, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + if err != nil { + return nil, nil, nil, err + } + + caCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertBytes}) + caKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(caKey)}) + + caCert, err = x509.ParseCertificate(caCertBytes) + if err != nil { + return nil, nil, nil, err + } + + certFile, err = os.CreateTemp("", "ca.pem") + if err != nil { + return nil, nil, nil, err + } + _, err = certFile.Write(caCertPEM) + if err != nil { + return nil, nil, nil, err + } + keyFile, err = os.CreateTemp("", "key.pem") + if err != nil { + return nil, nil, nil, err + } + _, err = keyFile.Write(caKeyPEM) + if err != nil { + return nil, nil, nil, err + } + + return certFile, keyFile, caCert, err +} diff --git a/internal/impl/io/output_http_server.go b/internal/impl/io/output_http_server.go index 1875fbfa4..626a81c17 100644 --- a/internal/impl/io/output_http_server.go +++ b/internal/impl/io/output_http_server.go @@ -3,6 +3,7 @@ package io import ( "bytes" "context" + "crypto/tls" "errors" "fmt" "io" @@ -54,6 +55,7 @@ type hsoConfig struct { CertFile string KeyFile string CORS httpserver.CORSConfig + TLSConfig *tls.Config } func hsoConfigFromParsed(pConf *service.ParsedConfig) (conf hsoConfig, err error) { @@ -95,6 +97,14 @@ func hsoConfigFromParsed(pConf *service.ParsedConfig) (conf hsoConfig, err error if conf.CORS, err = corsConfigFromParsed(pConf.Namespace(hsoFieldCORS)); err != nil { return } + + tlsConf, enabled, err := pConf.FieldTLSToggled(hsiFieldTLS) + if err != nil { + return + } + if enabled { + conf.TLSConfig = tlsConf + } return } @@ -136,14 +146,17 @@ Please note, messages are considered delivered as soon as the data is written to Description("The maximum time to wait before a blocking, inactive connection is dropped (only applies to the `path` endpoint)."). Default("5s"). Advanced(), + service.NewServerTLSToggledField(hsiFieldTLS), service.NewStringField(hsoFieldCertFile). Description("Enable TLS by specifying a certificate and key file. Only valid with a custom `address`."). Advanced(). - Default(""), + Default(""). + Deprecated(), service.NewStringField(hsoFieldKeyFile). Description("Enable TLS by specifying a certificate and key file. Only valid with a custom `address`."). Advanced(). - Default(""), + Default(""). + Deprecated(), service.NewInternalField(corsSpec), ) } @@ -209,7 +222,10 @@ func newHTTPServerOutput(conf hsoConfig, mgr bundle.NewManagement) (output.Strea var err error if conf.Address != "" { gMux = mux.NewRouter() - server = &http.Server{Addr: conf.Address} + server = &http.Server{ + Addr: conf.Address, + TLSConfig: conf.TLSConfig, + } if server.Handler, err = conf.CORS.WrapHandler(gMux); err != nil { return nil, fmt.Errorf("bad CORS configuration: %w", err) } @@ -448,7 +464,7 @@ func (h *httpServerOutput) Consume(ts <-chan message.Transaction) error { if h.server != nil { go func() { - if h.conf.KeyFile != "" || h.conf.CertFile != "" { + if h.conf.TLSConfig != nil || h.conf.KeyFile != "" || h.conf.CertFile != "" { h.log.Info( "Serving messages through HTTPS GET request at: https://%s\n", h.conf.Address+h.conf.Path, diff --git a/internal/impl/io/output_http_server_test.go b/internal/impl/io/output_http_server_test.go index 2803c939d..e9e686aa3 100644 --- a/internal/impl/io/output_http_server_test.go +++ b/internal/impl/io/output_http_server_test.go @@ -2,8 +2,11 @@ package io_test import ( "context" + "crypto/tls" + "crypto/x509" "fmt" "net/http" + "os" "testing" "time" @@ -159,3 +162,88 @@ http_server: h.TriggerCloseNow() require.NoError(t, h.WaitForClose(ctx)) } + +func TestHTTPServerOutputTLS(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), time.Second*30) + defer done() + + nTestLoops := 10 + + certFile, keyFile, caCert, err := createCertFiles() + require.NoError(t, err) + t.Cleanup(func() { + os.Remove(certFile.Name()) + os.Remove(keyFile.Name()) + }) + + port := getFreePort(t) + conf := parseYAMLOutputConf(t, ` +http_server: + address: localhost:%v + path: /testpost + tls: + enabled: true + server_certs: + - cert_file: %s + key_file: %s +`, port, certFile.Name(), keyFile.Name()) + + h, err := mock.NewManager().NewOutput(conf) + require.NoError(t, err) + + msgChan := make(chan message.Transaction) + resChan := make(chan error) + + if err = h.Consume(msgChan); err != nil { + t.Error(err) + return + } + + <-time.After(time.Millisecond * 100) + + // Test both single and multipart messages. + for i := 0; i < nTestLoops; i++ { + testStr := fmt.Sprintf("test%v", i) + + go func() { + testMsg := message.QuickBatch([][]byte{[]byte(testStr)}) + select { + case msgChan <- message.NewTransaction(testMsg, resChan): + case <-time.After(time.Second): + t.Error("Timed out waiting for message") + return + } + select { + case resMsg := <-resChan: + if resMsg != nil { + t.Error(resMsg) + } + case <-time.After(time.Second): + t.Error("Timed out waiting for response") + } + }() + + rootCA := x509.NewCertPool() + rootCA.AddCert(caCert) + httpClient := http.DefaultClient + httpClient.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: rootCA, + }, + } + + res, err := httpClient.Get(fmt.Sprintf("https://localhost:%v/testpost", port)) + if err != nil { + t.Error(err) + return + } + res.Body.Close() + if res.StatusCode != 200 { + t.Errorf("Wrong error code returned: %v", res.StatusCode) + return + } + } + + h.TriggerCloseNow() + require.NoError(t, h.WaitForClose(ctx)) +}