summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/zed/internal/cmd/restorer.go
blob: 468064fb45e1f88b1cec08fca9cd804748373bea (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
package cmd

import (
	"context"
	"errors"
	"fmt"
	"os"
	"strings"
	"time"

	"github.com/ccoveille/go-safecast"
	"github.com/cenkalti/backoff/v4"
	"github.com/mattn/go-isatty"
	"github.com/rs/zerolog/log"
	"github.com/samber/lo"
	"github.com/schollz/progressbar/v3"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
	"github.com/authzed/spicedb/pkg/spiceerrors"

	"github.com/authzed/zed/internal/client"
	"github.com/authzed/zed/internal/console"
	"github.com/authzed/zed/pkg/backupformat"
)

// ConflictStrategy selects how a restore reacts when a committed batch is
// rejected because some of its relationships already exist.
type ConflictStrategy int

// Supported conflict strategies.
const (
	// Fail aborts the restore with an error.
	Fail ConflictStrategy = iota
	// Skip drops the conflicting batches, counting them as skipped.
	Skip
	// Touch re-writes the conflicting batches using TOUCH semantics.
	Touch
)

// Retry tuning for writeBatchesWithRetry: the initial backoff interval and the
// maximum number of retries per batch.
const (
	defaultBackoff    = 50 * time.Millisecond
	defaultMaxRetries = 10
)

// conflictStrategyMapping resolves lowercase strategy names to their
// ConflictStrategy values.
var conflictStrategyMapping = map[string]ConflictStrategy{
	"fail":  Fail,
	"skip":  Skip,
	"touch": Touch,
}

// txConflictCodes are error-message substrings that isAlreadyExistsError treats
// as an already-exists conflict. They are a fallback for datastore
// implementations on SpiceDB < 1.29.0 that do not return proper gRPC codes.
// Remove once https://github.com/authzed/spicedb/pull/1688 lands.
var txConflictCodes = []string{
	"SQLSTATE 23505",     // CockroachDB
	"Error 1062 (23000)", // MySQL
}

// retryableErrorCodes are error-message substrings that isRetryableError treats
// as retryable, for the same pre-1.29.0 fallback.
// Remove once https://github.com/authzed/spicedb/pull/1688 lands.
var retryableErrorCodes = []string{
	"retryable error",            // CockroachDB, PostgreSQL
	"try restarting transaction", // MySQL
	"Error 1205",                 // MySQL
}

// restorer restores a backup: it writes the backed-up schema, then streams the
// backed-up relationships to the target in batches, committing them in groups.
type restorer struct {
	// Inputs: the schema written before any relationships, the decoder that
	// yields the backed-up relationships, and the connection to the target.
	schema  string
	decoder *backupformat.Decoder
	client  client.Client

	// Options. Relationships for which hasRelPrefix(rel, prefixFilter) is false
	// are not restored. batchSize is the number of relationships per
	// BulkImportRelationships message; batchesPerTransaction is the number of
	// messages sent before the stream is committed. conflictStrategy and
	// disableRetryErrors drive commitStream's error handling. requestTimeout
	// bounds each WriteRelationships call made by writeBatchesWithRetry.
	prefixFilter          string
	batchSize             uint
	batchesPerTransaction uint
	conflictStrategy      ConflictStrategy
	disableRetryErrors    bool
	requestTimeout        time.Duration

	// bar reports restore progress.
	bar *progressbar.ProgressBar

	// Running stats, updated as batches are committed and reported at the end.
	filteredOutRels  uint
	writtenRels      uint
	writtenBatches   uint
	skippedRels      uint
	skippedBatches   uint
	duplicateRels    uint
	duplicateBatches uint
	totalRetries     uint
}

// newRestorer returns a restorer configured with the given backup contents and
// options. A progress bar labeled "restoring from backup" is created up front;
// every stats counter starts at zero.
func newRestorer(schema string, decoder *backupformat.Decoder, client client.Client, prefixFilter string, batchSize uint,
	batchesPerTransaction uint, conflictStrategy ConflictStrategy, disableRetryErrors bool,
	requestTimeout time.Duration,
) *restorer {
	r := &restorer{
		schema:                schema,
		decoder:               decoder,
		client:                client,
		prefixFilter:          prefixFilter,
		batchSize:             batchSize,
		batchesPerTransaction: batchesPerTransaction,
		conflictStrategy:      conflictStrategy,
		disableRetryErrors:    disableRetryErrors,
		requestTimeout:        requestTimeout,
	}
	r.bar = console.CreateProgressBar("restoring from backup")
	return r
}

// restoreFromDecoder runs the restore.
//
// It writes the backed-up schema, then reads every relationship from the
// decoder and skips those rejected by the prefix filter (counting them in
// filteredOutRels). The remaining relationships are streamed to
// BulkImportRelationships in batches of batchSize. Once batchesPerTransaction
// batches have been sent, the stream is committed through commitStream, which
// also applies the conflict/retry policy, and a new stream is opened. Any
// remainder is sent and committed at the end, and a summary is logged.
//
// It returns an error in any of these cases:
//   - the schema write fails;
//   - a stream cannot be opened or committed;
//   - ctx is done;
//   - the decoder fails to read the backup.
//
// A batchSize or batchesPerTransaction of zero is treated as one.
func (r *restorer) restoreFromDecoder(ctx context.Context) error {
	relationshipWriteStart := time.Now()
	defer func() {
		if err := r.bar.Finish(); err != nil {
			log.Warn().Err(err).Msg("error finalizing progress bar")
		}
	}()

	r.bar.Describe("restoring schema from backup")
	if _, err := r.client.WriteSchema(ctx, &v1.WriteSchemaRequest{
		Schema: r.schema,
	}); err != nil {
		return fmt.Errorf("unable to write schema: %w", err)
	}

	relationshipWriter, err := r.client.BulkImportRelationships(ctx)
	if err != nil {
		return fmt.Errorf("error creating writer stream: %w", err)
	}

	r.bar.Describe("restoring relationships from backup")
	batch := make([]*v1.Relationship, 0, r.batchSize)
	batchesToBeCommitted := make([][]*v1.Relationship, 0, r.batchesPerTransaction)
	for {
		rel, decodeErr := r.decoder.Next()
		if decodeErr != nil {
			// Surface read failures instead of treating them as end of input.
			// Otherwise a corrupt or truncated backup would be reported as a
			// successful restore.
			return fmt.Errorf("error reading relationship from backup: %w", decodeErr)
		}
		if rel == nil {
			break // a nil relationship marks the end of the backup
		}

		if err := ctx.Err(); err != nil {
			r.bar.Describe("backup restore aborted")
			return fmt.Errorf("aborted restore: %w", err)
		}

		if !hasRelPrefix(rel, r.prefixFilter) {
			r.filteredOutRels++
			continue
		}

		batch = append(batch, rel)

		// batch is reset after every send, so its length never exceeds batchSize.
		// For a positive batchSize, "<" is equivalent to a modulo test, and it
		// avoids a divide-by-zero panic when batchSize is 0 (treated as 1).
		if uint(len(batch)) < r.batchSize {
			continue
		}

		batchesToBeCommitted = append(batchesToBeCommitted, batch)
		if sendErr := relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{
			Relationships: batch,
		}); sendErr != nil {
			// In a gRPC client stream, the real cause of a failed Send is only
			// reported by CloseAndRecv. Commit (and thereby close) the stream to
			// get it; commitStream then applies the conflict/retry policy to every
			// batch sent on this stream.
			if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
				return fmt.Errorf("error committing batches: %w", err)
			}

			// The failed stream is closed and cannot be reused; open a new one.
			relationshipWriter, err = r.client.BulkImportRelationships(ctx)
			if err != nil {
				return fmt.Errorf("error creating new writer stream: %w", err)
			}

			batchesToBeCommitted = batchesToBeCommitted[:0]
			batch = batch[:0]
			continue
		}

		// The sent batch stays referenced by batchesToBeCommitted in case the
		// commit has to be retried, so start a fresh slice rather than reusing its
		// backing array. Batches may fail on send, or on commit (CloseAndRecv).
		batch = make([]*v1.Relationship, 0, r.batchSize)

		// Keep the stream open until batchesPerTransaction batches have been sent.
		// As with batchSize, "<" matches a modulo test for positive values and
		// treats 0 as 1.
		if uint(len(batchesToBeCommitted)) < r.batchesPerTransaction {
			continue
		}

		if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
			return fmt.Errorf("error committing batches: %w", err)
		}

		relationshipWriter, err = r.client.BulkImportRelationships(ctx)
		if err != nil {
			return fmt.Errorf("error creating new writer stream: %w", err)
		}

		batchesToBeCommitted = batchesToBeCommitted[:0]
	}

	// Write the last, partial batch. The stream is closed right after this, and
	// the real cause of a failed Send only surfaces from CloseAndRecv. So the
	// Send error is deliberately ignored here; commitStream reports the
	// underlying error and gets a chance to retry the batch.
	if len(batch) > 0 {
		batchesToBeCommitted = append(batchesToBeCommitted, batch)
		_ = relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{Relationships: batch})
	}

	if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
		return fmt.Errorf("error committing last set of batches: %w", err)
	}

	r.bar.Describe("completed import")
	if err := r.bar.Finish(); err != nil {
		log.Warn().Err(err).Msg("error finalizing progress bar")
	}

	totalTime := time.Since(relationshipWriteStart)
	log.Info().
		Uint("batches", r.writtenBatches).
		Uint("relationships_loaded", r.writtenRels).
		Uint("relationships_skipped", r.skippedRels).
		Uint("duplicate_relationships", r.duplicateRels).
		Uint("relationships_filtered_out", r.filteredOutRels).
		Uint("retried_errors", r.totalRetries).
		Uint64("perSecond", perSec(uint64(r.writtenRels), totalTime)).
		Stringer("duration", totalTime).
		Msg("finished restore")
	return nil
}

// commitStream closes bulkImportClient with CloseAndRecv, which commits every
// batch sent on it, then applies the restore policy to the outcome:
//
//   - cancellation of ctx or of the commit: the cancellation error is returned;
//   - an error that is not retryable, not a conflict and not a cancellation:
//     it is returned wrapped;
//   - a retryable error with disableRetryErrors set: it is returned as-is;
//   - a conflict: handled per conflictStrategy. Skip counts the batches as
//     skipped, Touch re-writes them via writeBatchesWithRetry, and Fail (or any
//     unrecognized strategy) returns an error;
//   - any other retryable error: the batches are re-written via
//     writeBatchesWithRetry.
//
// Stats, the progress bar and a trace log line are updated on every non-error
// return. batchesToBeCommitted should hold the batches sent on the stream:
// their total size is compared with the server-reported count.
func (r *restorer) commitStream(ctx context.Context, bulkImportClient v1.ExperimentalService_BulkImportRelationshipsClient,
	batchesToBeCommitted [][]*v1.Relationship,
) error {
	var numLoaded, expectedLoaded, retries uint
	for _, b := range batchesToBeCommitted {
		expectedLoaded += uint(len(b))
	}

	resp, err := bulkImportClient.CloseAndRecv() // transaction commit happens here

	// A failed commit closes the stream, so it cannot be reused. The retry is
	// therefore done with WriteRelationships instead of BulkImportRelationships,
	// which also allows retrying with TOUCH semantics after duplicate failures.
	retryable := isRetryableError(err)
	conflict := isAlreadyExistsError(err)
	canceled, cancelErr := isCanceledError(ctx.Err(), err)
	unknown := !retryable && !conflict && !canceled && err != nil

	numBatches := uint(len(batchesToBeCommitted))

	switch {
	case canceled:
		r.bar.Describe("backup restore aborted")
		return cancelErr
	case unknown:
		r.bar.Describe("failed with unrecoverable error")
		return fmt.Errorf("error finalizing write of %d batches: %w", len(batchesToBeCommitted), err)
	case retryable && r.disableRetryErrors:
		return err
	case conflict && r.conflictStrategy == Skip:
		r.skippedRels += expectedLoaded
		r.skippedBatches += numBatches
		r.duplicateBatches += numBatches
		r.duplicateRels += expectedLoaded
		r.bar.Describe("skipping conflicting batch")
	case conflict && r.conflictStrategy == Touch:
		r.bar.Describe("touching conflicting batch")
		r.duplicateRels += expectedLoaded
		r.duplicateBatches += numBatches
		r.totalRetries++
		numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted)
		if err != nil {
			return fmt.Errorf("failed to write retried batch: %w", err)
		}

		retries++ // account for the initial attempt
		r.writtenBatches += numBatches
		r.writtenRels += numLoaded
	case conflict && r.conflictStrategy == Fail:
		r.bar.Describe("conflict detected, aborting restore")
		return errors.New("duplicate relationships found")
	case conflict:
		// Unrecognized strategy: falling through to the success path would count
		// the rejected batches as written and silently lose them.
		r.bar.Describe("conflict detected, aborting restore")
		return fmt.Errorf("duplicate relationships found: unsupported conflict strategy %d", r.conflictStrategy)
	case retryable:
		r.bar.Describe("retrying after error")
		r.totalRetries++
		numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted)
		if err != nil {
			return fmt.Errorf("failed to write retried batch: %w", err)
		}

		retries++ // account for the initial attempt
		r.writtenBatches += numBatches
		r.writtenRels += numLoaded
	default:
		r.bar.Describe("restoring relationships from backup")
		r.writtenBatches += numBatches
	}

	// A non-nil response means the transaction committed without duplicates.
	if resp != nil {
		committed, castErr := safecast.ToUint(resp.NumLoaded)
		if castErr != nil {
			return spiceerrors.MustBugf("could not cast numLoaded to uint")
		}
		r.writtenRels += committed
		if uint64(expectedLoaded) != resp.NumLoaded {
			log.Warn().Uint64("loaded", resp.NumLoaded).Uint("expected", expectedLoaded).Msg("unexpected number of relationships loaded")
		}
	}

	writtenAndSkipped, err := safecast.ToInt64(r.writtenRels + r.skippedRels)
	if err != nil {
		return fmt.Errorf("too many written and skipped rels for an int64")
	}

	if err := r.bar.Set64(writtenAndSkipped); err != nil {
		return fmt.Errorf("error incrementing progress bar: %w", err)
	}

	// Emit progress as trace logs when stderr is not a terminal.
	if !isatty.IsTerminal(os.Stderr.Fd()) {
		log.Trace().
			Uint("batches_written", r.writtenBatches).
			Uint("relationships_written", r.writtenRels).
			Uint("duplicate_batches", r.duplicateBatches).
			Uint("duplicate_relationships", r.duplicateRels).
			Uint("skipped_batches", r.skippedBatches).
			Uint("skipped_relationships", r.skippedRels).
			Uint("retries", retries).
			Msg("restore progress")
	}

	return nil
}

// writeBatchesWithRetry writes each batch with WriteRelationships using TOUCH
// semantics and without transactional guarantees. Batches are written in order,
// and each one is committed independently.
//
// A batch that fails with a retryable error is retried up to defaultMaxRetries
// times. Between attempts it sleeps with exponential backoff, starting at
// defaultBackoff and capped at 2s. The retry budget and the backoff reset after
// every successful batch.
//
// Each attempt gets its own requestTimeout deadline; no per-request deadline is
// applied when requestTimeout is not positive. Retrying stops as soon as ctx is
// done, and the backoff sleep is interrupted by ctx.
//
// It returns the number of relationships written and the number of retries
// performed. On error both counts are zero, although batches written before the
// failure remain committed.
func (r *restorer) writeBatchesWithRetry(ctx context.Context, batches [][]*v1.Relationship) (uint, uint, error) {
	backoffInterval := backoff.NewExponentialBackOff()
	backoffInterval.InitialInterval = defaultBackoff
	backoffInterval.MaxInterval = 2 * time.Second
	backoffInterval.MaxElapsedTime = 0
	backoffInterval.Reset()

	// attempt performs a single WriteRelationships call, bounded by
	// requestTimeout when it is positive. A zero timeout would otherwise create
	// an already-expired context and make every attempt fail.
	attempt := func(updates []*v1.RelationshipUpdate) error {
		reqCtx := ctx
		if r.requestTimeout > 0 {
			var cancel context.CancelFunc
			reqCtx, cancel = context.WithTimeout(ctx, r.requestTimeout)
			defer cancel()
		}
		_, err := r.client.WriteRelationships(reqCtx, &v1.WriteRelationshipsRequest{Updates: updates})
		return err
	}

	var currentRetries, totalRetries, loadedRels uint
	for _, batch := range batches {
		updates := lo.Map(batch, func(item *v1.Relationship, _ int) *v1.RelationshipUpdate {
			return &v1.RelationshipUpdate{
				Relationship: item,
				Operation:    v1.RelationshipUpdate_OPERATION_TOUCH,
			}
		})

		for {
			err := attempt(updates)

			// isRetryableError also matches context.DeadlineExceeded. Check the
			// parent ctx so an expired or canceled restore is not retried.
			if isRetryableError(err) && currentRetries < defaultMaxRetries && ctx.Err() == nil {
				// throttle the writes so we don't overwhelm the server
				bo := backoffInterval.NextBackOff()
				r.bar.Describe(fmt.Sprintf("retrying write with backoff %s after error (attempt %d/%d)", bo,
					currentRetries+1, defaultMaxRetries))

				timer := time.NewTimer(bo)
				select {
				case <-ctx.Done():
					timer.Stop()
					return 0, 0, ctx.Err()
				case <-timer.C:
				}

				currentRetries++
				r.totalRetries++
				totalRetries++
				continue
			}
			if err != nil {
				return 0, 0, err
			}

			currentRetries = 0
			backoffInterval.Reset()
			loadedRels += uint(len(batch))
			break
		}
	}

	return loadedRels, totalRetries, nil
}

// isAlreadyExistsError reports whether err means the written relationships
// already exist. That is the case for a gRPC AlreadyExists status, or, as a
// fallback for older datastores, a message containing one of txConflictCodes.
// A nil error is never a conflict.
func isAlreadyExistsError(err error) bool {
	switch {
	case err == nil:
		return false
	case isGRPCCode(err, codes.AlreadyExists):
		return true
	default:
		return isContainsErrorString(err, txConflictCodes...)
	}
}

// isRetryableError reports whether err is worth retrying. That is the case for:
//   - a gRPC Unavailable or DeadlineExceeded status;
//   - a message containing one of retryableErrorCodes;
//   - an error wrapping context.DeadlineExceeded.
//
// A nil error is not retryable.
func isRetryableError(err error) bool {
	if err == nil {
		return false
	}

	return isGRPCCode(err, codes.Unavailable, codes.DeadlineExceeded) ||
		isContainsErrorString(err, retryableErrorCodes...) ||
		errors.Is(err, context.DeadlineExceeded)
}

// isCanceledError scans errs in order, skipping nil entries, for a cancellation:
// an error wrapping context.Canceled or carrying a gRPC Canceled status. It
// returns true and the first such error, or false and nil if there is none.
func isCanceledError(errs ...error) (bool, error) {
	for _, candidate := range errs {
		if candidate == nil {
			continue
		}
		if errors.Is(candidate, context.Canceled) || isGRPCCode(candidate, codes.Canceled) {
			return true, candidate
		}
	}

	return false, nil
}

// isContainsErrorString reports whether the message of err contains any of
// errStrings as a substring. A nil err, or an empty errStrings, never matches.
func isContainsErrorString(err error, errStrings ...string) bool {
	if err == nil || len(errStrings) == 0 {
		return false
	}

	msg := err.Error()
	for i := range errStrings {
		if strings.Contains(msg, errStrings[i]) {
			return true
		}
	}
	return false
}

// isGRPCCode reports whether err carries a gRPC status whose code is one of
// wanted. A nil error, or one that status.FromError does not recognize as a
// status, never matches.
func isGRPCCode(err error, wanted ...codes.Code) bool {
	if err == nil {
		return false
	}

	s, ok := status.FromError(err)
	if !ok {
		return false
	}

	actual := s.Code()
	for _, code := range wanted {
		if code == actual {
			return true
		}
	}
	return false
}

// perSec returns the average rate of i events over duration d, in events per
// second, truncated to an integer. Fractional seconds count, so 300 events over
// 1.5s give 200/s rather than being rounded down to whole seconds.
//
// For a non-positive d the rate is undefined and i itself is returned. Rates
// too large for a uint64 saturate at the maximum uint64 value.
func perSec(i uint64, d time.Duration) uint64 {
	if d <= 0 {
		return i
	}

	const maxUint64 = ^uint64(0)

	rate := float64(i) / d.Seconds()
	// float64(maxUint64) rounds up to 2^64; converting anything at or above it
	// to uint64 would overflow, which is implementation-defined in Go.
	if rate >= float64(maxUint64) {
		return maxUint64
	}
	return uint64(rate)
}