diff options
Diffstat (limited to 'vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go')
| -rw-r--r-- | vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go b/vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go new file mode 100644 index 0000000..a088e35 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go @@ -0,0 +1,153 @@ +package taskrunner + +import ( + "context" + "sync" +) + +// PreloadedTaskRunner is a task runner that invokes a series of preloaded tasks, +// running until the tasks are completed, the context is canceled or an error is +// returned by one of the tasks (which cancels the context). +type PreloadedTaskRunner struct { + // ctx holds the context given to the task runner and annotated with the cancel + // function. + ctx context.Context + cancel func() + + // sem is a chan of length `concurrencyLimit` used to ensure the task runner does + // not exceed the concurrencyLimit with spawned goroutines. + sem chan struct{} + + wg sync.WaitGroup + + lock sync.Mutex + err error // GUARDED_BY(lock) + tasks []TaskFunc // GUARDED_BY(lock) +} + +func NewPreloadedTaskRunner(ctx context.Context, concurrencyLimit uint16, initialCapacity int) *PreloadedTaskRunner { + // Ensure a concurrency level of at least 1. + if concurrencyLimit <= 0 { + concurrencyLimit = 1 + } + + ctxWithCancel, cancel := context.WithCancel(ctx) + return &PreloadedTaskRunner{ + ctx: ctxWithCancel, + cancel: cancel, + sem: make(chan struct{}, concurrencyLimit), + tasks: make([]TaskFunc, 0, initialCapacity), + } +} + +// Add adds the given task function to be run. +func (tr *PreloadedTaskRunner) Add(f TaskFunc) { + tr.tasks = append(tr.tasks, f) + tr.wg.Add(1) +} + +// Start starts running the tasks in the task runner. This does *not* wait for the tasks +// to complete, but rather returns immediately. +func (tr *PreloadedTaskRunner) Start() { + for range tr.tasks { + tr.spawnIfAvailable() + } +} + +// StartAndWait starts running the tasks in the task runner and waits for them to complete. +func (tr *PreloadedTaskRunner) StartAndWait() error { + tr.Start() + tr.wg.Wait() + + tr.lock.Lock() + defer tr.lock.Unlock() + + return tr.err +} + +func (tr *PreloadedTaskRunner) spawnIfAvailable() { + // To spawn a runner, write a struct{} to the sem channel. If the task runner + // is already at the concurrency limit, then this chan write will fail, + // and nothing will be spawned. This also checks if the context has already + // been canceled, in which case nothing needs to be done. + select { + case tr.sem <- struct{}{}: + go tr.runner() + + case <-tr.ctx.Done(): + // If the context was canceled, nothing more to do. + tr.emptyForCancel() + return + + default: + return + } +} + +func (tr *PreloadedTaskRunner) runner() { + for { + select { + case <-tr.ctx.Done(): + // If the context was canceled, nothing more to do. + tr.emptyForCancel() + return + + default: + // Select a task from the list, if any. + task := tr.selectTask() + if task == nil { + return + } + + // Run the task. If an error occurs, store it and cancel any further tasks. + err := task(tr.ctx) + if err != nil { + tr.storeErrorAndCancel(err) + } + tr.wg.Done() + } + } +} + +func (tr *PreloadedTaskRunner) selectTask() TaskFunc { + tr.lock.Lock() + defer tr.lock.Unlock() + + if len(tr.tasks) == 0 { + return nil + } + + task := tr.tasks[0] + tr.tasks[0] = nil // to free the reference once the task completes. + tr.tasks = tr.tasks[1:] + return task +} + +func (tr *PreloadedTaskRunner) storeErrorAndCancel(err error) { + tr.lock.Lock() + defer tr.lock.Unlock() + + if tr.err == nil { + tr.err = err + tr.cancel() + } +} + +func (tr *PreloadedTaskRunner) emptyForCancel() { + tr.lock.Lock() + defer tr.lock.Unlock() + + if tr.err == nil { + tr.err = tr.ctx.Err() + } + + for { + if len(tr.tasks) == 0 { + break + } + + tr.tasks[0] = nil // to free the reference + tr.tasks = tr.tasks[1:] + tr.wg.Done() + } +} |
