summaryrefslogtreecommitdiff
path: root/vendor/github.com/testcontainers/testcontainers-go/wait/tls.go
blob: ab904b271e579bb8d272203c274bf15527c07d2b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
package wait

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"time"
)

// Compile-time check that *TLSStrategy satisfies the Strategy interface.
var _ Strategy = (*TLSStrategy)(nil)

// TLSStrategy is a wait strategy that waits for PEM formatted TLS files (root
// CA certificates and/or a client certificate/key pair) to be available in a
// container and loads them into a [tls.Config], which is then available via
// [TLSStrategy.TLSConfig].
//
// Construct one with [ForTLSCert] or [ForTLSRootCAs]; those constructors
// initialize the config and the poll interval.
type TLSStrategy struct {
	// General settings. timeout, when non-nil, overrides the startup timeout
	// of the combined wait; pollInterval is applied to every file check.
	timeout      *time.Duration
	pollInterval time.Duration

	// Custom settings. certFiles, when non-nil, names the client certificate
	// and key files; rootFiles names the CA certificate files.
	certFiles *x509KeyPair
	rootFiles []string

	// State. tlsConfig is populated by WaitUntilReady.
	tlsConfig *tls.Config
}

// x509KeyPair holds the container paths of a PEM formatted certificate file
// and of its matching private key file.
type x509KeyPair struct {
	certPEMFile, keyPEMFile string
}

// ForTLSCert returns a [TLSStrategy] that will add a Certificate to the
// [tls.Config], constructed from the PEM formatted certificate and key file
// pair in the container.
//
// The files are only read when the strategy's WaitUntilReady runs; the
// resulting config is available via [TLSStrategy.TLSConfig].
func ForTLSCert(certPEMFile, keyPEMFile string) *TLSStrategy {
	return &TLSStrategy{
		certFiles: &x509KeyPair{
			certPEMFile: certPEMFile,
			keyPEMFile:  keyPEMFile,
		},
		tlsConfig:    &tls.Config{},
		pollInterval: defaultPollInterval(),
	}
}

// ForTLSRootCAs returns a [TLSStrategy] that sets the root CAs for the
// [tls.Config] using the given PEM formatted files from the container.
//
// The files are only read when the strategy's WaitUntilReady runs; the
// resulting config is available via [TLSStrategy.TLSConfig].
func ForTLSRootCAs(pemFiles ...string) *TLSStrategy {
	return &TLSStrategy{
		rootFiles:    pemFiles,
		tlsConfig:    &tls.Config{},
		pollInterval: defaultPollInterval(),
	}
}

// WithRootCAs replaces the list of container files whose PEM formatted
// certificates are loaded into the RootCAs pool of the [tls.Config].
// Calling it with no arguments clears any previously configured files.
func (ws *TLSStrategy) WithRootCAs(files ...string) *TLSStrategy {
	s := ws
	s.rootFiles = files
	return s
}

// WithCert configures the container files holding the PEM formatted client
// certificate and private key that populate the [tls.Config] Certificates.
// Any previously configured pair is replaced.
func (ws *TLSStrategy) WithCert(certPEMFile, keyPEMFile string) *TLSStrategy {
	pair := x509KeyPair{certPEMFile: certPEMFile, keyPEMFile: keyPEMFile}
	ws.certFiles = &pair
	return ws
}

// WithServerName sets the ServerName field of the [tls.Config] returned by
// [TLSStrategy.TLSConfig].
func (ws *TLSStrategy) WithServerName(serverName string) *TLSStrategy {
	// A zero-value TLSStrategy (not built via ForTLSCert/ForTLSRootCAs) has
	// no config yet; create one instead of dereferencing nil.
	if ws.tlsConfig == nil {
		ws.tlsConfig = &tls.Config{}
	}
	ws.tlsConfig.ServerName = serverName
	return ws
}

// WithStartupTimeout overrides the startup timeout applied to the combined
// file checks performed by WaitUntilReady. When it is never called, the
// combined strategy keeps its own default timeout.
func (ws *TLSStrategy) WithStartupTimeout(startupTimeout time.Duration) *TLSStrategy {
	timeout := startupTimeout
	ws.timeout = &timeout
	return ws
}

// WithPollInterval overrides how often each file is polled while waiting.
// Strategies built with ForTLSCert or ForTLSRootCAs start with the value of
// defaultPollInterval (documented as 100 milliseconds).
func (ws *TLSStrategy) WithPollInterval(pollInterval time.Duration) *TLSStrategy {
	s := ws
	s.pollInterval = pollInterval
	return s
}

// TLSConfig returns the [tls.Config] that WaitUntilReady populates.
// The same pointer is returned before and after waiting, so its contents are
// only complete once WaitUntilReady has returned successfully.
// A nil receiver yields nil.
func (ws *TLSStrategy) TLSConfig() *tls.Config {
	if ws != nil {
		return ws.tlsConfig
	}
	return nil
}

// WaitUntilReady implements the [Strategy] interface.
// It waits for the configured root CA files and, when set, the client
// certificate and key files to be available in the container, and uses their
// contents to populate the [tls.Config] returned by [TLSStrategy.TLSConfig]:
// every root CA file is appended to RootCAs, and the certificate/key pair
// replaces Certificates.
//
// The individual file checks are combined with [ForAll], to which the
// optional startup timeout is applied. Each file matcher returns an error
// when its file cannot be read or parsed.
func (ws *TLSStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget) error {
	// A zero-value TLSStrategy (not built via ForTLSCert/ForTLSRootCAs) has
	// no config yet; create one instead of dereferencing nil in the matchers.
	if ws.tlsConfig == nil {
		ws.tlsConfig = &tls.Config{}
	}

	size := len(ws.rootFiles)
	if ws.certFiles != nil {
		size += 2 // certificate file + key file
	}
	strategies := make([]Strategy, 0, size)
	for _, file := range ws.rootFiles {
		// The matcher only runs once the combined strategy is waited on,
		// after this loop has finished. Give each closure its own copy so
		// errors name the right file under pre-Go 1.22 loop semantics too.
		file := file
		strategies = append(strategies,
			ForFile(file).WithMatcher(func(r io.Reader) error {
				buf, err := io.ReadAll(r)
				if err != nil {
					return fmt.Errorf("read CA cert file %q: %w", file, err)
				}

				if ws.tlsConfig.RootCAs == nil {
					ws.tlsConfig.RootCAs = x509.NewCertPool()
				}

				if !ws.tlsConfig.RootCAs.AppendCertsFromPEM(buf) {
					return fmt.Errorf("invalid CA cert file %q", file)
				}

				return nil
			}).WithPollInterval(ws.pollInterval),
		)
	}

	if ws.certFiles != nil {
		// Snapshot the paths so error messages always match the files the
		// strategies below were built for.
		certFile, keyFile := ws.certFiles.certPEMFile, ws.certFiles.keyPEMFile

		// Written by the certificate matcher, consumed by the key matcher.
		var certPEMBlock []byte
		strategies = append(strategies,
			ForFile(certFile).WithMatcher(func(r io.Reader) error {
				var err error
				if certPEMBlock, err = io.ReadAll(r); err != nil {
					return fmt.Errorf("read certificate cert %q: %w", certFile, err)
				}

				return nil
			}).WithPollInterval(ws.pollInterval),
			ForFile(keyFile).WithMatcher(func(r io.Reader) error {
				keyPEMBlock, err := io.ReadAll(r)
				if err != nil {
					return fmt.Errorf("read certificate key %q: %w", keyFile, err)
				}

				// NOTE(review): this assumes ForAll runs its strategies in
				// order, so the certificate matcher above has already stored
				// the certificate. If it has not, say so explicitly instead of
				// surfacing a confusing PEM parse error.
				if certPEMBlock == nil {
					return fmt.Errorf("x509 key pair %q %q: certificate not read before key", certFile, keyFile)
				}

				cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
				if err != nil {
					return fmt.Errorf("x509 key pair %q %q: %w", certFile, keyFile, err)
				}

				ws.tlsConfig.Certificates = []tls.Certificate{cert}

				return nil
			}).WithPollInterval(ws.pollInterval),
		)
	}

	strategy := ForAll(strategies...)
	if ws.timeout != nil {
		strategy.WithStartupTimeout(*ws.timeout)
	}

	return strategy.WaitUntilReady(ctx, target)
}