summaryrefslogtreecommitdiff
path: root/vendor/github.com/testcontainers/testcontainers-go/wait/tls.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/testcontainers/testcontainers-go/wait/tls.go')
-rw-r--r--vendor/github.com/testcontainers/testcontainers-go/wait/tls.go167
1 files changed, 167 insertions, 0 deletions
diff --git a/vendor/github.com/testcontainers/testcontainers-go/wait/tls.go b/vendor/github.com/testcontainers/testcontainers-go/wait/tls.go
new file mode 100644
index 0000000..ab904b2
--- /dev/null
+++ b/vendor/github.com/testcontainers/testcontainers-go/wait/tls.go
@@ -0,0 +1,167 @@
+package wait
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "io"
+ "time"
+)
+
+// Validate we implement interface.
+var _ Strategy = (*TLSStrategy)(nil)
+
+// TLSStrategy is a strategy for handling TLS.
+type TLSStrategy struct {
+ // General Settings.
+ timeout *time.Duration
+ pollInterval time.Duration
+
+ // Custom Settings.
+ certFiles *x509KeyPair
+ rootFiles []string
+
+ // State.
+ tlsConfig *tls.Config
+}
+
+// x509KeyPair is a pair of certificate and key files.
+type x509KeyPair struct {
+ certPEMFile string
+ keyPEMFile string
+}
+
+// ForTLSCert returns a CertStrategy that will add a Certificate to the [tls.Config]
+// constructed from PEM formatted certificate key file pair in the container.
+func ForTLSCert(certPEMFile, keyPEMFile string) *TLSStrategy {
+ return &TLSStrategy{
+ certFiles: &x509KeyPair{
+ certPEMFile: certPEMFile,
+ keyPEMFile: keyPEMFile,
+ },
+ tlsConfig: &tls.Config{},
+ pollInterval: defaultPollInterval(),
+ }
+}
+
+// ForTLSRootCAs returns a CertStrategy that sets the root CAs for the [tls.Config]
+// using the given PEM formatted files from the container.
+func ForTLSRootCAs(pemFiles ...string) *TLSStrategy {
+ return &TLSStrategy{
+ rootFiles: pemFiles,
+ tlsConfig: &tls.Config{},
+ pollInterval: defaultPollInterval(),
+ }
+}
+
+// WithRootCAs sets the root CAs for the [tls.Config] using the given files from
+// the container.
+func (ws *TLSStrategy) WithRootCAs(files ...string) *TLSStrategy {
+ ws.rootFiles = files
+ return ws
+}
+
+// WithCert sets the [tls.Config] Certificates using the given files from the container.
+func (ws *TLSStrategy) WithCert(certPEMFile, keyPEMFile string) *TLSStrategy {
+ ws.certFiles = &x509KeyPair{
+ certPEMFile: certPEMFile,
+ keyPEMFile: keyPEMFile,
+ }
+ return ws
+}
+
+// WithServerName sets the server for the [tls.Config].
+func (ws *TLSStrategy) WithServerName(serverName string) *TLSStrategy {
+ ws.tlsConfig.ServerName = serverName
+ return ws
+}
+
+// WithStartupTimeout can be used to change the default startup timeout.
+func (ws *TLSStrategy) WithStartupTimeout(startupTimeout time.Duration) *TLSStrategy {
+ ws.timeout = &startupTimeout
+ return ws
+}
+
+// WithPollInterval can be used to override the default polling interval of 100 milliseconds.
+func (ws *TLSStrategy) WithPollInterval(pollInterval time.Duration) *TLSStrategy {
+ ws.pollInterval = pollInterval
+ return ws
+}
+
+// TLSConfig returns the TLS config once the strategy is ready.
+// If the strategy is nil, it returns nil.
+func (ws *TLSStrategy) TLSConfig() *tls.Config {
+ if ws == nil {
+ return nil
+ }
+
+ return ws.tlsConfig
+}
+
+// WaitUntilReady implements the [Strategy] interface.
+// It waits for the CA, client cert and key files to be available in the container and
+// uses them to setup the TLS config.
+func (ws *TLSStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget) error {
+ size := len(ws.rootFiles)
+ if ws.certFiles != nil {
+ size += 2
+ }
+ strategies := make([]Strategy, 0, size)
+ for _, file := range ws.rootFiles {
+ strategies = append(strategies,
+ ForFile(file).WithMatcher(func(r io.Reader) error {
+ buf, err := io.ReadAll(r)
+ if err != nil {
+ return fmt.Errorf("read CA cert file %q: %w", file, err)
+ }
+
+ if ws.tlsConfig.RootCAs == nil {
+ ws.tlsConfig.RootCAs = x509.NewCertPool()
+ }
+
+ if !ws.tlsConfig.RootCAs.AppendCertsFromPEM(buf) {
+ return fmt.Errorf("invalid CA cert file %q", file)
+ }
+
+ return nil
+ }).WithPollInterval(ws.pollInterval),
+ )
+ }
+
+ if ws.certFiles != nil {
+ var certPEMBlock []byte
+ strategies = append(strategies,
+ ForFile(ws.certFiles.certPEMFile).WithMatcher(func(r io.Reader) error {
+ var err error
+ if certPEMBlock, err = io.ReadAll(r); err != nil {
+ return fmt.Errorf("read certificate cert %q: %w", ws.certFiles.certPEMFile, err)
+ }
+
+ return nil
+ }).WithPollInterval(ws.pollInterval),
+ ForFile(ws.certFiles.keyPEMFile).WithMatcher(func(r io.Reader) error {
+ keyPEMBlock, err := io.ReadAll(r)
+ if err != nil {
+ return fmt.Errorf("read certificate key %q: %w", ws.certFiles.keyPEMFile, err)
+ }
+
+ cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
+ if err != nil {
+ return fmt.Errorf("x509 key pair %q %q: %w", ws.certFiles.certPEMFile, ws.certFiles.keyPEMFile, err)
+ }
+
+ ws.tlsConfig.Certificates = []tls.Certificate{cert}
+
+ return nil
+ }).WithPollInterval(ws.pollInterval),
+ )
+ }
+
+ strategy := ForAll(strategies...)
+ if ws.timeout != nil {
+ strategy.WithStartupTimeout(*ws.timeout)
+ }
+
+ return strategy.WaitUntilReady(ctx, target)
+}