summaryrefslogtreecommitdiff
path: root/vendor/github.com/testcontainers/testcontainers-go/wait/walk.go
blob: 98f5755e140c7987b3759bf8d2298a15c7de25e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
package wait

import (
	"errors"
	"slices"
)

var (
	// ErrVisitStop can be returned from a [VisitFunc] to end the walk early.
	// It is a control signal only: [Walk] swallows it instead of returning it.
	ErrVisitStop = errors.New("stop the walk")

	// ErrVisitRemove can be returned from a [VisitFunc] to delete the node
	// currently being visited. It is a control signal only: [Walk] swallows it
	// instead of returning it.
	ErrVisitRemove = errors.New("remove this strategy")
)

// Legacy names kept so existing callers keep compiling.
var (
	// VisitStop is the former name of [ErrVisitStop].
	//
	// Deprecated: use [ErrVisitStop] instead.
	VisitStop = ErrVisitStop

	// VisitRemove is the former name of [ErrVisitRemove].
	//
	// Deprecated: use [ErrVisitRemove] instead.
	VisitRemove = ErrVisitRemove
)

// VisitFunc is called once for every strategy node reached by [Walk].
//
// Returning nil continues the walk. Returning [ErrVisitStop] ends it, and
// returning [ErrVisitRemove] deletes the node that was just visited.
type VisitFunc func(node Strategy) error

// Walk traverses the strategy tree rooted at *root and calls visit on every
// node, visiting a node before any of its children.
//
// visit may return [ErrVisitRemove] to delete the current node, or
// [ErrVisitStop] to end the traversal early. Walk never returns either of
// these control signals. Any other error from visit aborts the walk and is
// returned to the caller unchanged.
//
// Walk reports an error if root itself is nil.
func Walk(root *Strategy, visit VisitFunc) error {
	if root == nil {
		return errors.New("root strategy is nil")
	}

	err := walk(root, visit)
	switch {
	case err == nil, errors.Is(err, ErrVisitRemove), errors.Is(err, ErrVisitStop):
		// Success, or a control signal that has already been acted upon.
		return nil
	default:
		return err
	}
}

// walk calls visit on *root and, when *root is a [*MultiStrategy], then
// recurses into its children in order. root must be a non-nil pointer; a nil
// *root is treated as an empty tree.
//
// The error returned by visit steers the walk:
//   - nil: continue; children of a MultiStrategy are walked next.
//   - [ErrVisitRemove]: *root is set to nil, its children (if any) are not
//     visited, and the error is returned so that an enclosing MultiStrategy
//     can drop the entry from its slice. The removal is absorbed there and
//     does not propagate further up.
//   - [ErrVisitStop]: returned unchanged, aborting the walk.
//   - any other error: returned unchanged, aborting the walk.
//
// If a child's error matches both ErrVisitRemove and ErrVisitStop (for
// example one built with errors.Join), the child is removed and
// [ErrVisitStop] is returned.
func walk(root *Strategy, visit VisitFunc) error {
	if *root == nil {
		// No strategy.
		return nil
	}

	// Allow the visit function to customize the behaviour of the walk before visiting the children.
	if err := visit(*root); err != nil {
		if errors.Is(err, ErrVisitRemove) {
			*root = nil
		}

		return err
	}

	ms, ok := (*root).(*MultiStrategy)
	if !ok {
		return nil
	}

	// i only advances when the child at i is kept. After a removal the next
	// child has shifted down into position i, so it is visited next.
	for i := 0; i < len(ms.Strategies); {
		err := walk(&ms.Strategies[i], visit)
		switch {
		case err == nil:
			i++
		case errors.Is(err, ErrVisitRemove):
			ms.Strategies = slices.Delete(ms.Strategies, i, i+1)
			if errors.Is(err, ErrVisitStop) {
				return ErrVisitStop
			}
		default:
			return err
		}
	}

	return nil
}