1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
package grpcutil
import (
"context"
"errors"
"fmt"
"runtime"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)
// minimum returns the smaller of a and b; when they are equal the shared
// value is returned.
func minimum(a int, b int) int {
	if b < a {
		return b
	}
	return a
}
// EachFunc is the callback invoked once per batch by ConcurrentBatch.
//
// no is the zero-based batch number. start and end delimit the batch as the
// half-open index range [start, end) into the batched slice, so the batch
// holds end-start items. ctx is derived from the caller's context and is
// cancelled by ConcurrentBatch when another batch returns an error, so
// long-running callbacks should watch it. Returning a non-nil error makes
// ConcurrentBatch stop scheduling further batches and return an error.
type EachFunc func(ctx context.Context, no int, start int, end int) error
// ConcurrentBatch splits n items into the minimum number of batches of at
// most batchSize items and calls each once per batch. Batch no (zero-based)
// covers the half-open index range [start, end); every batch except possibly
// the last holds exactly batchSize items.
//
// Batches are processed concurrently by at most maxWorkers goroutines. If
// maxWorkers is 1, batches run one at a time in increasing batch order. If
// maxWorkers is 0, runtime.GOMAXPROCS(0) workers are used.
//
// If any call to each returns an error, the context passed to the other
// calls is cancelled, no further batches are scheduled, and the first such
// error is returned once all in-flight batches have finished. If ctx is
// cancelled before every batch has been scheduled, ConcurrentBatch likewise
// waits for in-flight batches and returns either the first batch error or an
// error wrapping ctx.Err().
//
// n == 0 is a no-op and returns nil before batchSize and maxWorkers are
// validated. Otherwise n < 0, batchSize < 1 or maxWorkers < 0 yield an error
// without calling each.
func ConcurrentBatch(ctx context.Context, n int, batchSize int, maxWorkers int, each EachFunc) error {
	if n < 0 {
		return errors.New("cannot batch items of length < 0")
	} else if n == 0 {
		// Batching zero items is a noop.
		return nil
	}
	if batchSize < 1 {
		return errors.New("cannot batch items with batch size < 1")
	}
	if maxWorkers < 0 {
		return errors.New("cannot batch items with workers < 0")
	} else if maxWorkers == 0 {
		maxWorkers = runtime.GOMAXPROCS(0)
	}

	sem := semaphore.NewWeighted(int64(maxWorkers))
	g, ctx := errgroup.WithContext(ctx)

	// Ceiling division without computing n+batchSize-1, which overflows (and
	// silently yields zero or negative batches) when batchSize is near
	// math.MaxInt.
	numBatches := n / batchSize
	if n%batchSize != 0 {
		numBatches++
	}

	for i := 0; i < numBatches; i++ {
		if err := sem.Acquire(ctx, 1); err != nil {
			// Acquire only fails once ctx is done: either a batch failed
			// (errgroup cancelled ctx) or the caller cancelled. Wait for the
			// in-flight batches so no goroutine outlives this call, and
			// prefer the batch's own error over the derived cancellation.
			if werr := g.Wait(); werr != nil {
				return werr
			}
			return fmt.Errorf("failed to acquire semaphore for batch number %d: %w", i, err)
		}
		batchNum := i
		g.Go(func() error {
			defer sem.Release(1)
			start := batchNum * batchSize // < n, so no overflow
			// n-start > 0 here; adding the clamped length avoids the
			// overflow that start+batchSize could hit.
			end := start + minimum(batchSize, n-start)
			return each(ctx, batchNum, start, end)
		})
	}
	return g.Wait()
}
|