summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go')
-rw-r--r--vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go168
1 files changed, 168 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go b/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go
new file mode 100644
index 0000000..1b519ed
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go
@@ -0,0 +1,168 @@
+package taskrunner
+
+import (
+ "context"
+ "sync"
+)
+
+// TaskRunner is a helper which runs a series of scheduled tasks against a defined
+// limit of goroutines.
+type TaskRunner struct {
+ // ctx holds the context given to the task runner and annotated with the cancel
+ // function.
+ ctx context.Context
+ cancel func()
+
+ // sem is a chan of length `concurrencyLimit` used to ensure the task runner does
+ // not exceed the concurrencyLimit with spawned goroutines.
+ sem chan struct{}
+
+ wg sync.WaitGroup
+
+ lock sync.Mutex
+ tasks []TaskFunc // GUARDED_BY(lock)
+
+ // err holds the error returned by any task, if any. If the context is canceled,
+ // this err will hold the cancelation error.
+ err error // GUARDED_BY(lock)
+}
+
+// TaskFunc defines functions representing tasks.
+type TaskFunc func(ctx context.Context) error
+
+// NewTaskRunner creates a new task runner with the given starting context and
+// concurrency limit. The TaskRunner will schedule no more goroutines that the
+// specified concurrencyLimit. If the given context is canceled, then all tasks
+// started after that point will also be canceled and the error returned. If
+// a task returns an error, the context provided to all tasks is also canceled.
+func NewTaskRunner(ctx context.Context, concurrencyLimit uint16) *TaskRunner {
+ if concurrencyLimit < 1 {
+ concurrencyLimit = 1
+ }
+
+ ctxWithCancel, cancel := context.WithCancel(ctx)
+ return &TaskRunner{
+ ctx: ctxWithCancel,
+ cancel: cancel,
+ sem: make(chan struct{}, concurrencyLimit),
+ tasks: make([]TaskFunc, 0),
+ }
+}
+
+// Schedule schedules a task to be run. This is safe to call from within another
+// task handler function and immediately returns.
+func (tr *TaskRunner) Schedule(f TaskFunc) {
+ if tr.addTask(f) {
+ tr.spawnIfAvailable()
+ }
+}
+
+func (tr *TaskRunner) spawnIfAvailable() {
+ // To spawn a runner, write a struct{} to the sem channel. If the task runner
+ // is already at the concurrency limit, then this chan write will fail,
+ // and nothing will be spawned. This also checks if the context has already
+ // been canceled, in which case nothing needs to be done.
+ select {
+ case tr.sem <- struct{}{}:
+ go tr.runner()
+
+ case <-tr.ctx.Done():
+ return
+
+ default:
+ return
+ }
+}
+
+func (tr *TaskRunner) runner() {
+ for {
+ select {
+ case <-tr.ctx.Done():
+ // If the context was canceled, mark all the remaining tasks as "Done".
+ tr.emptyForCancel()
+ return
+
+ default:
+ // Select a task from the list, if any.
+ task := tr.selectTask()
+ if task == nil {
+ // If there are no further tasks, then "return" the struct{} by reading
+ // it from the channel (freeing a slot potentially for another worker
+ // to be spawned later).
+ <-tr.sem
+ return
+ }
+
+ // Run the task. If an error occurs, store it and cancel any further tasks.
+ err := task(tr.ctx)
+ if err != nil {
+ tr.storeErrorAndCancel(err)
+ }
+ tr.wg.Done()
+ }
+ }
+}
+
+func (tr *TaskRunner) addTask(f TaskFunc) bool {
+ tr.lock.Lock()
+ defer tr.lock.Unlock()
+
+ if tr.err != nil {
+ return false
+ }
+
+ tr.wg.Add(1)
+ tr.tasks = append(tr.tasks, f)
+ return true
+}
+
+func (tr *TaskRunner) selectTask() TaskFunc {
+ tr.lock.Lock()
+ defer tr.lock.Unlock()
+
+ if len(tr.tasks) == 0 {
+ return nil
+ }
+
+ task := tr.tasks[0]
+ tr.tasks = tr.tasks[1:]
+ return task
+}
+
+func (tr *TaskRunner) storeErrorAndCancel(err error) {
+ tr.lock.Lock()
+ defer tr.lock.Unlock()
+
+ if tr.err == nil {
+ tr.err = err
+ tr.cancel()
+ }
+}
+
+func (tr *TaskRunner) emptyForCancel() {
+ tr.lock.Lock()
+ defer tr.lock.Unlock()
+
+ if tr.err == nil {
+ tr.err = tr.ctx.Err()
+ }
+
+ for {
+ if len(tr.tasks) == 0 {
+ break
+ }
+
+ tr.tasks = tr.tasks[1:]
+ tr.wg.Done()
+ }
+}
+
+// Wait waits for all tasks to be completed, or a task to raise an error,
+// or the parent context to have been canceled.
+func (tr *TaskRunner) Wait() error {
+ tr.wg.Wait()
+
+ tr.lock.Lock()
+ defer tr.lock.Unlock()
+ return tr.err
+}