summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go
blob: 1b519ed117a4ec372ea71545aed60395c88b8f36 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
package taskrunner

import (
	"context"
	"sync"
)

// TaskRunner is a helper which runs a series of scheduled tasks against a defined
// limit of goroutines.
//
// Tasks are queued with Schedule and executed by runner goroutines that are
// spawned on demand, one per free slot in sem. The first error returned by a
// task (or the cancelation of the context) is recorded in err, cancels ctx, and
// causes any still-queued tasks to be discarded without being run.
type TaskRunner struct {
	// ctx holds the context given to the task runner and annotated with the cancel
	// function. It is the context passed to every task.
	ctx    context.Context
	cancel func()

	// sem is a chan of length `concurrencyLimit` used to ensure the task runner does
	// not exceed the concurrencyLimit with spawned goroutines. A runner goroutine
	// writes one element before it starts and reads it back when it exits because
	// the task queue is empty.
	sem chan struct{}

	// wg counts tasks that have been accepted but not yet finished or discarded.
	// It is incremented when a task is queued and decremented when the task
	// returns or is dropped from the queue due to cancelation.
	wg sync.WaitGroup

	// lock guards tasks and err.
	lock  sync.Mutex
	tasks []TaskFunc // GUARDED_BY(lock)

	// err holds the error returned by any task, if any. If the context is canceled,
	// this err will hold the cancelation error. Only the first error is kept.
	err error // GUARDED_BY(lock)
}

// TaskFunc defines functions representing tasks. A task receives the task
// runner's (cancelable) context and reports failure by returning a non-nil
// error.
type TaskFunc func(ctx context.Context) error

// NewTaskRunner builds a TaskRunner bound to the given parent context and
// concurrency limit. The TaskRunner will never have more than concurrencyLimit
// runner goroutines alive at once; a limit of 0 is treated as 1.
//
// The context handed to tasks is a cancelable child of ctx. It is canceled when
// ctx itself is canceled, and also as soon as any task returns an error, so that
// remaining tasks can stop early.
func NewTaskRunner(ctx context.Context, concurrencyLimit uint16) *TaskRunner {
	// concurrencyLimit is unsigned, so zero is the only value below the minimum.
	if concurrencyLimit == 0 {
		concurrencyLimit = 1
	}

	runnerCtx, cancelFn := context.WithCancel(ctx)

	tr := &TaskRunner{
		ctx:    runnerCtx,
		cancel: cancelFn,
	}
	tr.sem = make(chan struct{}, concurrencyLimit)
	tr.tasks = make([]TaskFunc, 0)
	return tr
}

// Schedule queues f for execution and returns immediately without waiting for
// it to run. It may be called from inside another task. If the task is not
// accepted into the queue, Schedule does nothing further; otherwise it tries to
// start a runner goroutine to process the queue.
func (tr *TaskRunner) Schedule(f TaskFunc) {
	if accepted := tr.addTask(f); !accepted {
		return
	}
	tr.spawnIfAvailable()
}

// spawnIfAvailable starts a new runner goroutine if a concurrency slot is free.
// It never blocks.
//
// If the context has already been canceled, no runner is started; instead any
// queued tasks are discarded (and accounted for on the WaitGroup) right here.
// This matters because the caller has typically just queued a task: when both
// the semaphore send and ctx.Done() are ready, select picks one at random, and
// returning without draining would leave that task queued forever with no
// runner alive to release it, making Wait block indefinitely.
func (tr *TaskRunner) spawnIfAvailable() {
	select {
	case tr.sem <- struct{}{}:
		// A slot was free: the written struct{} is held by the new runner until
		// it exits because the queue is empty.
		go tr.runner()

	case <-tr.ctx.Done():
		// Canceled: nothing will run, so drop whatever is queued so that Wait
		// can return. emptyForCancel takes the lock itself; the lock is not
		// held here.
		tr.emptyForCancel()

	default:
		// Already at the concurrency limit and not canceled; the existing
		// runners will pick up the queued work.
	}
}

// runner is the body of a worker goroutine. It repeatedly pulls tasks off the
// queue and runs them until either the context is canceled or the queue is
// empty.
//
// On cancelation the remaining queued tasks are discarded and the runner exits
// while still holding its semaphore slot. When the queue is empty the runner
// releases its slot and exits, after making sure no task slipped into the queue
// while it was shutting down.
func (tr *TaskRunner) runner() {
	for {
		select {
		case <-tr.ctx.Done():
			// If the context was canceled, mark all the remaining tasks as "Done".
			tr.emptyForCancel()
			return

		default:
			// Select a task from the list, if any.
			task := tr.selectTask()
			if task == nil {
				// If there are no further tasks, then "return" the struct{} by reading
				// it from the channel (freeing a slot potentially for another worker
				// to be spawned later).
				<-tr.sem

				// A concurrent Schedule may have queued a task after selectTask saw
				// an empty queue but before the slot above was released. Its own
				// spawn attempt would then have found the semaphore full and given
				// up, stranding the task. Now that the slot is free, re-check the
				// queue and hand off to a fresh runner if work is pending.
				tr.lock.Lock()
				pending := len(tr.tasks) > 0
				tr.lock.Unlock()
				if pending {
					tr.spawnIfAvailable()
				}
				return
			}

			// Run the task. If an error occurs, store it and cancel any further tasks.
			if err := task(tr.ctx); err != nil {
				tr.storeErrorAndCancel(err)
			}
			tr.wg.Done()
		}
	}
}

// addTask appends f to the task queue and registers it on the WaitGroup,
// reporting whether the task was accepted. Once an error has been recorded on
// the runner, new tasks are rejected and false is returned.
func (tr *TaskRunner) addTask(f TaskFunc) bool {
	tr.lock.Lock()
	defer tr.lock.Unlock()

	accepted := tr.err == nil
	if accepted {
		// Count the task before it becomes visible in the queue; both happen
		// under the lock.
		tr.wg.Add(1)
		tr.tasks = append(tr.tasks, f)
	}
	return accepted
}

// selectTask removes and returns the task at the front of the queue (FIFO
// order), or returns nil when the queue is empty.
func (tr *TaskRunner) selectTask() TaskFunc {
	tr.lock.Lock()
	defer tr.lock.Unlock()

	if len(tr.tasks) > 0 {
		next := tr.tasks[0]
		tr.tasks = tr.tasks[1:]
		return next
	}
	return nil
}

// storeErrorAndCancel records err as the runner's error and cancels the
// runner's context. Only the first error wins: if an error is already stored,
// the call is a no-op.
func (tr *TaskRunner) storeErrorAndCancel(err error) {
	tr.lock.Lock()
	defer tr.lock.Unlock()

	if tr.err != nil {
		return
	}

	tr.err = err
	tr.cancel()
}

// emptyForCancel discards every task still waiting in the queue, marking each
// one as done on the WaitGroup without running it. If no error has been
// recorded yet, the context's error is stored as the runner's error.
func (tr *TaskRunner) emptyForCancel() {
	tr.lock.Lock()
	defer tr.lock.Unlock()

	if tr.err == nil {
		tr.err = tr.ctx.Err()
	}

	// One Done per discarded task, matching the Add(1) made when it was queued.
	for range tr.tasks {
		tr.wg.Done()
	}
	tr.tasks = tr.tasks[len(tr.tasks):]
}

// Wait blocks until every accepted task has either finished running or been
// discarded from the queue, then returns the error recorded on the runner:
// the first error returned by a task, the context's cancelation error, or nil
// if neither occurred.
func (tr *TaskRunner) Wait() error {
	tr.wg.Wait()

	tr.lock.Lock()
	result := tr.err
	tr.lock.Unlock()

	return result
}