diff options
Diffstat (limited to 'vendor/github.com/testcontainers/testcontainers-go/wait/http.go')
| -rw-r--r-- | vendor/github.com/testcontainers/testcontainers-go/wait/http.go | 338 |
1 files changed, 338 insertions, 0 deletions
diff --git a/vendor/github.com/testcontainers/testcontainers-go/wait/http.go b/vendor/github.com/testcontainers/testcontainers-go/wait/http.go new file mode 100644 index 0000000..2c7c655 --- /dev/null +++ b/vendor/github.com/testcontainers/testcontainers-go/wait/http.go @@ -0,0 +1,338 @@ +package wait + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/docker/go-connections/nat" +) + +// Implement interface +var ( + _ Strategy = (*HTTPStrategy)(nil) + _ StrategyTimeout = (*HTTPStrategy)(nil) +) + +type HTTPStrategy struct { + // all Strategies should have a startupTimeout to avoid waiting infinitely + timeout *time.Duration + + // additional properties + Port nat.Port + Path string + StatusCodeMatcher func(status int) bool + ResponseMatcher func(body io.Reader) bool + UseTLS bool + AllowInsecure bool + TLSConfig *tls.Config // TLS config for HTTPS + Method string // http method + Body io.Reader // http request body + Headers map[string]string + ResponseHeadersMatcher func(headers http.Header) bool + PollInterval time.Duration + UserInfo *url.Userinfo + ForceIPv4LocalHost bool +} + +// NewHTTPStrategy constructs a HTTP strategy waiting on port 80 and status code 200 +func NewHTTPStrategy(path string) *HTTPStrategy { + return &HTTPStrategy{ + Port: "", + Path: path, + StatusCodeMatcher: defaultStatusCodeMatcher, + ResponseMatcher: func(_ io.Reader) bool { return true }, + UseTLS: false, + TLSConfig: nil, + Method: http.MethodGet, + Body: nil, + Headers: map[string]string{}, + ResponseHeadersMatcher: func(_ http.Header) bool { return true }, + PollInterval: defaultPollInterval(), + UserInfo: nil, + } +} + +func defaultStatusCodeMatcher(status int) bool { + return status == http.StatusOK +} + +// fluent builders for each property +// since go has neither covariance nor generics, the return type must be the type of the concrete implementation +// this is true for all properties, even the "shared" ones like startupTimeout + +// WithStartupTimeout can be used to change the default startup timeout +func (ws *HTTPStrategy) WithStartupTimeout(timeout time.Duration) *HTTPStrategy { + ws.timeout = &timeout + return ws +} + +// WithPort set the port to wait for. +// Default is the lowest numbered port. +func (ws *HTTPStrategy) WithPort(port nat.Port) *HTTPStrategy { + ws.Port = port + return ws +} + +func (ws *HTTPStrategy) WithStatusCodeMatcher(statusCodeMatcher func(status int) bool) *HTTPStrategy { + ws.StatusCodeMatcher = statusCodeMatcher + return ws +} + +func (ws *HTTPStrategy) WithResponseMatcher(matcher func(body io.Reader) bool) *HTTPStrategy { + ws.ResponseMatcher = matcher + return ws +} + +func (ws *HTTPStrategy) WithTLS(useTLS bool, tlsconf ...*tls.Config) *HTTPStrategy { + ws.UseTLS = useTLS + if useTLS && len(tlsconf) > 0 { + ws.TLSConfig = tlsconf[0] + } + return ws +} + +func (ws *HTTPStrategy) WithAllowInsecure(allowInsecure bool) *HTTPStrategy { + ws.AllowInsecure = allowInsecure + return ws +} + +func (ws *HTTPStrategy) WithMethod(method string) *HTTPStrategy { + ws.Method = method + return ws +} + +func (ws *HTTPStrategy) WithBody(reqdata io.Reader) *HTTPStrategy { + ws.Body = reqdata + return ws +} + +func (ws *HTTPStrategy) WithHeaders(headers map[string]string) *HTTPStrategy { + ws.Headers = headers + return ws +} + +func (ws *HTTPStrategy) WithResponseHeadersMatcher(matcher func(http.Header) bool) *HTTPStrategy { + ws.ResponseHeadersMatcher = matcher + return ws +} + +func (ws *HTTPStrategy) WithBasicAuth(username, password string) *HTTPStrategy { + ws.UserInfo = url.UserPassword(username, password) + return ws +} + +// WithPollInterval can be used to override the default polling interval of 100 milliseconds +func (ws *HTTPStrategy) WithPollInterval(pollInterval time.Duration) *HTTPStrategy { + ws.PollInterval = pollInterval + return ws +} + +// WithForcedIPv4LocalHost forces usage of localhost to be ipv4 127.0.0.1 +// to avoid ipv6 docker bugs https://github.com/moby/moby/issues/42442 https://github.com/moby/moby/issues/42375 +func (ws *HTTPStrategy) WithForcedIPv4LocalHost() *HTTPStrategy { + ws.ForceIPv4LocalHost = true + return ws +} + +// ForHTTP is a convenience method similar to Wait.java +// https://github.com/testcontainers/testcontainers-java/blob/1d85a3834bd937f80aad3a4cec249c027f31aeb4/core/src/main/java/org/testcontainers/containers/wait/strategy/Wait.java +func ForHTTP(path string) *HTTPStrategy { + return NewHTTPStrategy(path) +} + +func (ws *HTTPStrategy) Timeout() *time.Duration { + return ws.timeout +} + +// WaitUntilReady implements Strategy.WaitUntilReady +func (ws *HTTPStrategy) WaitUntilReady(ctx context.Context, target StrategyTarget) error { + timeout := defaultStartupTimeout() + if ws.timeout != nil { + timeout = *ws.timeout + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ipAddress, err := target.Host(ctx) + if err != nil { + return err + } + // to avoid ipv6 docker bugs https://github.com/moby/moby/issues/42442 https://github.com/moby/moby/issues/42375 + if ws.ForceIPv4LocalHost { + ipAddress = strings.Replace(ipAddress, "localhost", "127.0.0.1", 1) + } + + var mappedPort nat.Port + if ws.Port == "" { + // We wait one polling interval before we grab the ports + // otherwise they might not be bound yet on startup. + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(ws.PollInterval): + // Port should now be bound so just continue. + } + + if err := checkTarget(ctx, target); err != nil { + return err + } + + inspect, err := target.Inspect(ctx) + if err != nil { + return err + } + + // Find the lowest numbered exposed tcp port. + var lowestPort nat.Port + var hostPort string + for port, bindings := range inspect.NetworkSettings.Ports { + if len(bindings) == 0 || port.Proto() != "tcp" { + continue + } + + if lowestPort == "" || port.Int() < lowestPort.Int() { + lowestPort = port + hostPort = bindings[0].HostPort + } + } + + if lowestPort == "" { + return errors.New("No exposed tcp ports or mapped ports - cannot wait for status") + } + + mappedPort, _ = nat.NewPort(lowestPort.Proto(), hostPort) + } else { + mappedPort, err = target.MappedPort(ctx, ws.Port) + + for mappedPort == "" { + select { + case <-ctx.Done(): + return fmt.Errorf("%w: %w", ctx.Err(), err) + case <-time.After(ws.PollInterval): + if err := checkTarget(ctx, target); err != nil { + return err + } + + mappedPort, err = target.MappedPort(ctx, ws.Port) + } + } + + if mappedPort.Proto() != "tcp" { + return errors.New("Cannot use HTTP client on non-TCP ports") + } + } + + switch ws.Method { + case http.MethodGet, http.MethodHead, http.MethodPost, + http.MethodPut, http.MethodPatch, http.MethodDelete, + http.MethodConnect, http.MethodOptions, http.MethodTrace: + default: + if ws.Method != "" { + return fmt.Errorf("invalid http method %q", ws.Method) + } + ws.Method = http.MethodGet + } + + tripper := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: ws.TLSConfig, + } + + var proto string + if ws.UseTLS { + proto = "https" + if ws.AllowInsecure { + if ws.TLSConfig == nil { + tripper.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } else { + ws.TLSConfig.InsecureSkipVerify = true + } + } + } else { + proto = "http" + } + + client := http.Client{Transport: tripper, Timeout: time.Second} + address := net.JoinHostPort(ipAddress, strconv.Itoa(mappedPort.Int())) + + endpoint, err := url.Parse(ws.Path) + if err != nil { + return err + } + endpoint.Scheme = proto + endpoint.Host = address + + if ws.UserInfo != nil { + endpoint.User = ws.UserInfo + } + + // cache the body into a byte-slice so that it can be iterated over multiple times + var body []byte + if ws.Body != nil { + body, err = io.ReadAll(ws.Body) + if err != nil { + return err + } + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(ws.PollInterval): + if err := checkTarget(ctx, target); err != nil { + return err + } + req, err := http.NewRequestWithContext(ctx, ws.Method, endpoint.String(), bytes.NewReader(body)) + if err != nil { + return err + } + + for k, v := range ws.Headers { + req.Header.Set(k, v) + } + + resp, err := client.Do(req) + if err != nil { + continue + } + if ws.StatusCodeMatcher != nil && !ws.StatusCodeMatcher(resp.StatusCode) { + _ = resp.Body.Close() + continue + } + if ws.ResponseMatcher != nil && !ws.ResponseMatcher(resp.Body) { + _ = resp.Body.Close() + continue + } + if ws.ResponseHeadersMatcher != nil && !ws.ResponseHeadersMatcher(resp.Header) { + _ = resp.Body.Close() + continue + } + if err := resp.Body.Close(); err != nil { + continue + } + return nil + } + } +} |
