summaryrefslogtreecommitdiff
path: root/vendor/github.com/testcontainers/testcontainers-go/modules/postgres/postgres.go
blob: f03adc7e163671bd0dd162880b532aaeb189c17f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
package postgres

import (
	"context"
	"database/sql"
	_ "embed"
	"errors"
	"fmt"
	"io"
	"net/url"
	"path/filepath"
	"strings"

	"github.com/testcontainers/testcontainers-go"
	"github.com/testcontainers/testcontainers-go/log"
)

// Module defaults: the superuser credentials Run starts the container with, and the name of the
// template database Snapshot writes to when no other name is configured.
const (
	defaultSnapshotName = "migrated_template"
	defaultPassword     = "postgres"
	defaultUser         = "postgres"
)

// embeddedCustomEntrypoint holds the contents of resources/customEntrypoint.sh, embedded at build time.
// WithSSLCert copies it into the container and runs it (via `sh`) as the container entrypoint.
//
//go:embed resources/customEntrypoint.sh
var embeddedCustomEntrypoint string

// PostgresContainer represents the postgres container type used in the module.
// It embeds the started testcontainers.Container and records the database, credentials and snapshot
// settings it was started with; ConnectionString, Snapshot and Restore rely on these values.
type PostgresContainer struct {
	testcontainers.Container

	// Database name and superuser credentials, as used to build connection strings.
	dbName, user, password string
	// snapshotName is the snapshot Restore uses when no WithSnapshotName option is passed;
	// a successful Snapshot call updates it.
	snapshotName string
	// sqlDriverName is passed to sql.Open() to connect to the database when making or restoring snapshots.
	// This can be set if your app imports a different postgres driver, f.ex. "pgx"
	sqlDriverName string
}

// MustConnectionString behaves like ConnectionString but panics instead of returning an error,
// e.g. when the address of the mapped port cannot be determined.
func (c *PostgresContainer) MustConnectionString(ctx context.Context, args ...string) string {
	connStr, err := c.ConnectionString(ctx, args...)
	if err == nil {
		return connStr
	}
	panic(err)
}

// ConnectionString returns a postgres:// URL for the container's database. The host and port are taken
// from the endpoint mapped for the container's 5432/tcp port.
//
// args are extra connection parameters in the usual key=value form, e.g. "connect_timeout=10" or
// "application_name=myapp". They are joined with "&" and appended verbatim as the query string. A
// trailing "?" is always present, even when no args are given.
//
// The user name, password and database name are percent-encoded. Credentials containing URL delimiters
// such as '@', ':', '/', '?', '#' or '%' therefore still produce a well-formed URL.
func (c *PostgresContainer) ConnectionString(ctx context.Context, args ...string) (string, error) {
	endpoint, err := c.PortEndpoint(ctx, "5432/tcp", "")
	if err != nil {
		return "", err
	}

	// Userinfo.String renders "user:password" with both parts escaped for the userinfo component.
	userInfo := url.UserPassword(c.user, c.password).String()
	extraArgs := strings.Join(args, "&")
	connStr := fmt.Sprintf("postgres://%s@%s/%s?%s", userInfo, endpoint, url.PathEscape(c.dbName), extraArgs)
	return connStr, nil
}

// WithConfigFile copies the postgres configuration file found at the host path cfg into the container
// as /etc/postgresql.conf (mode 0755). It also appends "-c config_file=/etc/postgresql.conf" to the
// container command, so the server is started with that file.
func WithConfigFile(cfg string) testcontainers.CustomizeRequestOption {
	return func(req *testcontainers.GenericContainerRequest) error {
		req.Files = append(req.Files, testcontainers.ContainerFile{
			HostFilePath:      cfg,
			ContainerFilePath: "/etc/postgresql.conf",
			FileMode:          0o755,
		})
		req.Cmd = append(req.Cmd, "-c", "config_file=/etc/postgresql.conf")
		return nil
	}
}

// WithDatabase sets the name of the initial database created when the container first starts
// (the POSTGRES_DB environment variable). When used with Run, this is also the database that
// ConnectionString, Snapshot and Restore operate on.
//
// If it is not specified, Run uses the database "postgres". That default does not follow
// WithUsername: setting only a user name still connects to "postgres".
func WithDatabase(dbName string) testcontainers.CustomizeRequestOption {
	return func(req *testcontainers.GenericContainerRequest) error {
		// Writing to a nil map panics; requests not built by Run may have no Env yet.
		if req.Env == nil {
			req.Env = make(map[string]string)
		}
		req.Env["POSTGRES_DB"] = dbName

		return nil
	}
}

// WithInitScripts copies each host script into /docker-entrypoint-initdb.d/ in the container, keeping
// its base file name, with mode 0755. Scripts placed there are run when the container starts.
// These init scripts will be executed in sorted name order as defined by the container's current locale, which defaults to en_US.utf8.
// If you need to run your scripts in a specific order, consider using `WithOrderedInitScripts` instead.
func WithInitScripts(scripts ...string) testcontainers.CustomizeRequestOption {
	files := make([]testcontainers.ContainerFile, 0, len(scripts))
	for _, hostPath := range scripts {
		files = append(files, testcontainers.ContainerFile{
			HostFilePath:      hostPath,
			ContainerFilePath: "/docker-entrypoint-initdb.d/" + filepath.Base(hostPath),
			FileMode:          0o755,
		})
	}
	return testcontainers.WithFiles(files...)
}

// WithOrderedInitScripts sets the init scripts to be run when the container starts.
// The scripts will be run in the order that they are provided in this function. Each script is
// copied into /docker-entrypoint-initdb.d/ under the name "NNN-<base name>" (mode 0755). NNN is its
// zero-padded, three-digit position in scripts, so name order matches argument order.
func WithOrderedInitScripts(scripts ...string) testcontainers.CustomizeRequestOption {
	files := make([]testcontainers.ContainerFile, len(scripts))
	for i, hostPath := range scripts {
		files[i] = testcontainers.ContainerFile{
			HostFilePath:      hostPath,
			ContainerFilePath: fmt.Sprintf("/docker-entrypoint-initdb.d/%03d-%s", i, filepath.Base(hostPath)),
			FileMode:          0o755,
		}
	}
	return testcontainers.WithFiles(files...)
}

// WithPassword sets the password of the initial superuser created when the container first starts
// (the POSTGRES_PASSWORD environment variable).
// Run already defaults the password to "postgres", so this option is only needed for a different one.
// The password should not be empty; this option does not validate it.
func WithPassword(password string) testcontainers.CustomizeRequestOption {
	return func(req *testcontainers.GenericContainerRequest) error {
		// Writing to a nil map panics; requests not built by Run may have no Env yet.
		if req.Env == nil {
			req.Env = make(map[string]string)
		}
		req.Env["POSTGRES_PASSWORD"] = password

		return nil
	}
}

// WithUsername sets the name of the initial user created with superuser power when the container
// first starts (the POSTGRES_USER environment variable). It is used in conjunction with WithPassword.
// An empty user is replaced by the default user "postgres", which Run also uses when this option is
// not given.
//
// The database name does not follow the user name. Run defaults POSTGRES_DB to "postgres"; use
// WithDatabase to create and connect to a different database.
func WithUsername(user string) testcontainers.CustomizeRequestOption {
	return func(req *testcontainers.GenericContainerRequest) error {
		// Use a local instead of reassigning the captured parameter, so applying the option never
		// mutates shared closure state.
		name := user
		if name == "" {
			name = defaultUser
		}

		// Writing to a nil map panics; requests not built by Run may have no Env yet.
		if req.Env == nil {
			req.Env = make(map[string]string)
		}
		req.Env["POSTGRES_USER"] = name

		return nil
	}
}

// RunContainer creates an instance of the Postgres container type using the "postgres:16-alpine" image.
//
// Deprecated: use Run instead.
func RunContainer(ctx context.Context, opts ...testcontainers.ContainerCustomizer) (*PostgresContainer, error) {
	return Run(ctx, "postgres:16-alpine", opts...)
}

// Run creates and starts an instance of the Postgres container type from the image img.
//
// By default the request uses user "postgres", password "postgres" and database "postgres". It exposes
// 5432/tcp and starts the server with `postgres -c fsync=off`. opts are applied in order, and each one
// customizes the container request. Options of this package's Option type also update the module
// settings (SQL driver name and snapshot name) that Snapshot and Restore use.
//
// An error from an option aborts before any container is created. If testcontainers.GenericContainer
// returns a container together with an error, that container is still returned alongside the wrapped error.
func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustomizer) (*PostgresContainer, error) {
	req := testcontainers.ContainerRequest{
		Image: img,
		Env: map[string]string{
			"POSTGRES_USER":     defaultUser,
			"POSTGRES_PASSWORD": defaultPassword,
			"POSTGRES_DB":       defaultUser, // defaults to the default user name, not to a WithUsername value
		},
		ExposedPorts: []string{"5432/tcp"},
		Cmd:          []string{"postgres", "-c", "fsync=off"},
	}

	genericContainerReq := testcontainers.GenericContainerRequest{
		ContainerRequest: req,
		Started:          true,
	}

	// Gather all config options (defaults and then apply provided options)
	settings := defaultOptions()
	for _, opt := range opts {
		if apply, ok := opt.(Option); ok {
			apply(&settings)
		}
		if err := opt.Customize(&genericContainerReq); err != nil {
			return nil, err
		}
	}

	container, err := testcontainers.GenericContainer(ctx, genericContainerReq)
	var c *PostgresContainer
	if container != nil {
		// Read the values from the customized request rather than from req. An option may have
		// replaced the Env map instead of mutating it, and only genericContainerReq reflects that.
		env := genericContainerReq.Env
		c = &PostgresContainer{
			Container:     container,
			dbName:        env["POSTGRES_DB"],
			password:      env["POSTGRES_PASSWORD"],
			user:          env["POSTGRES_USER"],
			sqlDriverName: settings.SQLDriverName,
			snapshotName:  settings.Snapshot,
		}
	}

	if err != nil {
		return c, fmt.Errorf("generic container: %w", err)
	}

	return c, nil
}

// snapshotConfig gathers the settings that SnapshotOption values apply for one Snapshot or Restore call.
type snapshotConfig struct {
	snapshotName string // requested snapshot (template) database name; "" means not requested
}

// SnapshotOption is the type for passing options to the snapshot function of the database.
// An option receives the configuration being built and returns the configuration to continue with.
type SnapshotOption func(cfg *snapshotConfig) *snapshotConfig

// WithSnapshotName adds a specific name to the snapshot database created from the main database defined on the
// container. The snapshot must not have the same name as your main database, otherwise it will be overwritten
func WithSnapshotName(name string) SnapshotOption {
	setName := func(cfg *snapshotConfig) *snapshotConfig {
		cfg.snapshotName = name
		return cfg
	}
	return setName
}

// WithSSLCert configures the Postgres server to run with the provided CA chain, server certificate and
// server key. The three host files are copied into the container, each with mode 0600:
//
//   - caCertFile -> /tmp/testcontainers-go/postgres/ca_cert.pem
//   - certFile   -> /tmp/testcontainers-go/postgres/server.cert
//   - keyFile    -> /tmp/testcontainers-go/postgres/server.key
//
// The module's embedded entrypoint script is copied to /usr/local/bin/docker-entrypoint-ssl.bash and
// replaces the container entrypoint (run with `sh`).
//
// This will not function unless the postgres configuration (e.g. one supplied with WithConfigFile)
// enables SSL and references the certificate paths the entrypoint script sets up.
// NOTE(review): the final paths the conf file must use are defined by resources/customEntrypoint.sh;
// confirm them there.
func WithSSLCert(caCertFile string, certFile string, keyFile string) testcontainers.CustomizeRequestOption {
	const defaultPermission = 0o600

	return func(req *testcontainers.GenericContainerRequest) error {
		const entrypointPath = "/usr/local/bin/docker-entrypoint-ssl.bash"

		req.Files = append(req.Files,
			testcontainers.ContainerFile{
				HostFilePath:      caCertFile,
				ContainerFilePath: "/tmp/testcontainers-go/postgres/ca_cert.pem",
				FileMode:          defaultPermission,
			},
			testcontainers.ContainerFile{
				HostFilePath:      certFile,
				ContainerFilePath: "/tmp/testcontainers-go/postgres/server.cert",
				FileMode:          defaultPermission,
			},
			testcontainers.ContainerFile{
				HostFilePath:      keyFile,
				ContainerFilePath: "/tmp/testcontainers-go/postgres/server.key",
				FileMode:          defaultPermission,
			},
			testcontainers.ContainerFile{
				Reader:            strings.NewReader(embeddedCustomEntrypoint),
				ContainerFilePath: entrypointPath,
				FileMode:          defaultPermission,
			},
		)
		// The script is run through `sh`, so it does not need the executable bit.
		req.Entrypoint = []string{"sh", entrypointPath}

		return nil
	}
}

// Snapshot copies the current state of the container's database into a template database, which the
// Restore method can later clone back. The snapshot is stored in the database named by
// WithSnapshotName; otherwise the container's current snapshot name (migrated_template by default) is used.
// If a snapshot already exists under the given/default name, it will be overwritten with the new snapshot.
// On success, the chosen name becomes the default used by later Restore calls.
func (c *PostgresContainer) Snapshot(ctx context.Context, opts ...SnapshotOption) error {
	name, err := c.checkSnapshotConfig(opts)
	if err != nil {
		return err
	}

	cmds := []string{
		// A database flagged as a template cannot be dropped, so clear the flag before dropping
		// any previous snapshot with the same name.
		// https://www.postgresql.org/docs/current/manage-ag-templatedbs.html
		fmt.Sprintf(`UPDATE pg_database SET datistemplate = FALSE WHERE datname = '%s'`, name),
		fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, name),
		// Clone the (fully migrated) application database into the snapshot database.
		fmt.Sprintf(`CREATE DATABASE "%s" WITH TEMPLATE "%s" OWNER "%s"`, name, c.dbName, c.user),
		// Flag the clone as a template so Restore can create the application database from it.
		fmt.Sprintf(`ALTER DATABASE "%s" WITH is_template = TRUE`, name),
	}
	if err := c.execCommandsSQL(ctx, cmds...); err != nil {
		return err
	}

	c.snapshotName = name
	return nil
}

// Restore replaces the container's database with a copy of a snapshot. Without options it uses the
// container's current snapshot name, i.e. the last one taken by Snapshot. WithSnapshotName selects a
// different snapshot by name.
//
// The commands run while connected to the "postgres" database. The application database is dropped with
// WITH (FORCE), which terminates other sessions still connected to it (PostgreSQL 13+), and is then
// recreated from the snapshot template.
func (c *PostgresContainer) Restore(ctx context.Context, opts ...SnapshotOption) error {
	name, err := c.checkSnapshotConfig(opts)
	if err != nil {
		return err
	}

	dropApp := fmt.Sprintf(`DROP DATABASE "%s" with (FORCE)`, c.dbName)
	recreateApp := fmt.Sprintf(`CREATE DATABASE "%s" WITH TEMPLATE "%s" OWNER "%s"`, c.dbName, name, c.user)
	return c.execCommandsSQL(ctx, dropApp, recreateApp)
}

// checkSnapshotConfig resolves the snapshot name for a Snapshot or Restore call and validates it.
//
// opts are applied in order, and the last name they set wins. An empty result falls back to the
// container's current snapshot name, and then to defaultSnapshotName.
//
// It returns an error when:
//   - the application database is "postgres": the snapshot helpers drop and recreate the application
//     database while connected to "postgres", so that database cannot be snapshotted or restored;
//   - the snapshot name equals the application database name: Snapshot would drop the application
//     database itself, and Restore would try to recreate it from the database it just dropped.
func (c *PostgresContainer) checkSnapshotConfig(opts []SnapshotOption) (string, error) {
	config := &snapshotConfig{}
	for _, opt := range opts {
		config = opt(config)
	}

	snapshotName := c.snapshotName
	if config.snapshotName != "" {
		snapshotName = config.snapshotName
	}
	if snapshotName == "" {
		// Avoid issuing SQL against a zero-length identifier.
		snapshotName = defaultSnapshotName
	}

	if c.dbName == "postgres" {
		return "", errors.New("cannot restore the postgres system database as it cannot be dropped to be restored")
	}
	if snapshotName == c.dbName {
		return "", fmt.Errorf("snapshot name %q must differ from the database name", snapshotName)
	}
	return snapshotName, nil
}

// execCommandsSQL runs cmds, in order, over a database/sql connection obtained from snapshotConnection.
// If snapshotConnection returns an error, it logs the reason and runs the same commands through psql
// inside the container instead (execCommandsFallback). Execution stops at the first failing command.
func (c *PostgresContainer) execCommandsSQL(ctx context.Context, cmds ...string) error {
	conn, cleanup, connErr := c.snapshotConnection(ctx)
	if connErr != nil {
		log.Printf("Could not connect to database to restore snapshot, falling back to `docker exec psql`: %v", connErr)
		return c.execCommandsFallback(ctx, cmds)
	}
	if cleanup != nil {
		defer cleanup()
	}

	for i := range cmds {
		if _, execErr := conn.ExecContext(ctx, cmds[i]); execErr != nil {
			return fmt.Errorf("could not execute restore command %s: %w", cmds[i], execErr)
		}
	}
	return nil
}

// snapshotConnection opens a single database/sql connection to the container's "postgres" database,
// as the container user, using the driver named by c.sqlDriverName. It connects to "postgres" rather
// than the application database so the application database can be dropped and recreated.
//
// Errors are returned, never panics. They include a missing or unregistered driver (sql.Open fails)
// and an unreachable server, so callers can fall back to another mechanism.
//
// On success, the returned cleanup function must be called (e.g. deferred) once the connection is no
// longer needed. It closes the connection and then the pool. Closing the pool alone is not enough:
// sql.DB.Close only closes idle connections, and a connection held by *sql.Conn would stay open on the
// server. No pool is cached, since establishing this single connection is cheap.
func (c *PostgresContainer) snapshotConnection(ctx context.Context) (*sql.Conn, func(), error) {
	// Connect to the database "postgres" instead of the app one
	c2 := &PostgresContainer{
		Container:     c.Container,
		dbName:        "postgres",
		user:          c.user,
		password:      c.password,
		sqlDriverName: c.sqlDriverName,
	}

	connStr, err := c2.ConnectionString(ctx, "sslmode=disable")
	if err != nil {
		return nil, nil, fmt.Errorf("connection string for snapshot connection: %w", err)
	}

	// Try to use an actual postgres connection, if the driver is loaded
	pool, err := sql.Open(c.sqlDriverName, connStr)
	if err != nil {
		return nil, nil, fmt.Errorf("sql.Open for snapshot connection failed: %w", err)
	}

	closePool := func() {
		if err := pool.Close(); err != nil {
			log.Printf("Could not close database connection pool after restoring snapshot: %v", err)
		}
	}

	conn, err := pool.Conn(ctx)
	if err != nil {
		closePool()
		return nil, nil, fmt.Errorf("DB.Conn for snapshot connection failed: %w", err)
	}

	cleanup := func() {
		// Release the dedicated connection first; otherwise pool.Close leaves it open.
		if err := conn.Close(); err != nil {
			log.Printf("Could not close database connection after restoring snapshot: %v", err)
		}
		closePool()
	}
	return conn, cleanup, nil
}

// execCommandsFallback runs each SQL command, in order, with psql inside the container. It connects to
// the "postgres" database as the container user and stops at the first failure. An Exec error is
// returned as-is. ON_ERROR_STOP=1 makes psql exit non-zero when the SQL fails; in that case the
// command's output is read and included in the returned error.
func (c *PostgresContainer) execCommandsFallback(ctx context.Context, cmds []string) error {
	for _, sqlCmd := range cmds {
		psql := []string{"psql", "-v", "ON_ERROR_STOP=1", "-U", c.user, "-d", "postgres", "-c", sqlCmd}
		code, output, err := c.Exec(ctx, psql)
		if err != nil {
			return err
		}
		if code == 0 {
			continue
		}

		var out strings.Builder
		if _, err := io.Copy(&out, output); err != nil {
			return fmt.Errorf("non-zero exit code for restore command, could not read command output: %w", err)
		}
		return fmt.Errorf("non-zero exit code for restore command: %s", out.String())
	}
	return nil
}