summary refs log tree commit diff
path: root/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go
blob: eaa48123215b1d57b9fcaab42ae0e771a3b986ff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package memdb

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/hashicorp/go-memdb"

	"github.com/authzed/spicedb/internal/datastore/revisions"
	"github.com/authzed/spicedb/pkg/datastore"
)

// errWatchError is the fmt.Errorf format string used in this file to wrap
// (via %w) an underlying error raised while serving a watch.
const errWatchError = "watch error: %w"

// Watch streams the changelog entries recorded after revision ar.
//
// It returns a buffered channel of changes and an error channel with room for
// a single error. The buffer length and the write timeout come from options,
// falling back to the datastore-level defaults when the option is zero.
//
// Two conditions are rejected up front, before any goroutine is started: the
// EmitImmediatelyStrategy emission strategy, and a revision that is not a
// revisions.TimestampRevision. In both cases the updates channel is returned
// already closed and the error channel holds one error (it is not closed).
//
// Otherwise a goroutine is started which repeatedly loads new changes, writes
// them to the updates channel and then blocks until the changelog is modified
// or ctx is done. When that goroutine exits it closes both channels, after
// publishing at most one error:
//   - the error returned by loadChanges,
//   - datastore.NewWatchDisconnectedErr() when a change could not be written
//     to the updates channel within the write timeout,
//   - datastore.NewWatchCanceledErr() when waiting failed with
//     context.Canceled,
//   - any other wait error wrapped with errWatchError.
func (mdb *memdbDatastore) Watch(ctx context.Context, ar datastore.Revision, options datastore.WatchOptions) (<-chan datastore.RevisionChanges, <-chan error) {
	watchBufferLength := options.WatchBufferLength
	if watchBufferLength == 0 {
		watchBufferLength = mdb.watchBufferLength
	}

	updates := make(chan datastore.RevisionChanges, watchBufferLength)
	// Capacity 1: every code path sends at most one error before returning, so
	// the send never blocks even if nobody is reading errs.
	errs := make(chan error, 1)

	if options.EmissionStrategy == datastore.EmitImmediatelyStrategy {
		close(updates)
		errs <- errors.New("emit immediately strategy is unsupported in MemDB")
		return updates, errs
	}

	// Validate the revision type here rather than inside the goroutine: an
	// unchecked type assertion there would panic in a goroutine nobody can
	// recover from and crash the whole process.
	startRevision, ok := ar.(revisions.TimestampRevision)
	if !ok {
		close(updates)
		errs <- fmt.Errorf(errWatchError, fmt.Errorf("unsupported revision type %T", ar))
		return updates, errs
	}

	watchBufferWriteTimeout := options.WatchBufferWriteTimeout
	if watchBufferWriteTimeout == 0 {
		watchBufferWriteTimeout = mdb.watchBufferWriteTimeout
	}

	// sendChange writes a single change to the updates channel. It reports
	// false, after publishing a disconnected error, when the write could not
	// complete within watchBufferWriteTimeout.
	sendChange := func(change datastore.RevisionChanges) bool {
		// Fast path: avoid allocating a timer when the buffer has room.
		select {
		case updates <- change:
			return true

		default:
			// If we cannot immediately write, setup the timer and try again.
		}

		timer := time.NewTimer(watchBufferWriteTimeout)
		defer timer.Stop()

		select {
		case updates <- change:
			return true

		case <-timer.C:
			errs <- datastore.NewWatchDisconnectedErr()
			return false
		}
	}

	go func() {
		defer close(updates)
		defer close(errs)

		currentTxn := startRevision.TimestampNanoSec()

		for {
			var stagedUpdates []datastore.RevisionChanges
			var watchChan <-chan struct{}
			var err error
			stagedUpdates, currentTxn, watchChan, err = mdb.loadChanges(ctx, currentTxn, options)
			if err != nil {
				errs <- err
				return
			}

			// Write the staged updates to the channel
			for _, changeToWrite := range stagedUpdates {
				if !sendChange(changeToWrite) {
					return
				}
			}

			// Wait for new changes
			ws := memdb.NewWatchSet()
			ws.Add(watchChan)

			err = ws.WatchCtx(ctx)
			if err != nil {
				switch {
				case errors.Is(err, context.Canceled):
					errs <- datastore.NewWatchCanceledErr()
				default:
					errs <- fmt.Errorf(errWatchError, err)
				}
				return
			}
		}
	}()

	return updates, errs
}

// loadChanges reads every changelog entry with a revision greater than
// currentTxn and converts the ones selected by options.Content into
// datastore.RevisionChanges values.
//
// It holds the datastore read lock for the duration of the call and uses a
// single read-only memdb transaction for both the scan and the watch channel.
//
// It returns:
//   - the selected changes, in changelog iteration order,
//   - the revision of the last changelog entry visited (currentTxn itself when
//     there were none),
//   - a watch channel obtained from the changelog table in the same
//     transaction, which the caller waits on for further changes,
//   - an error if the datastore is closed or a memdb call fails; memdb
//     failures are wrapped with errWatchError.
//
// The context parameter is currently unused.
func (mdb *memdbDatastore) loadChanges(_ context.Context, currentTxn int64, options datastore.WatchOptions) ([]datastore.RevisionChanges, int64, <-chan struct{}, error) {
	mdb.RLock()
	defer mdb.RUnlock()

	if err := mdb.checkNotClosed(); err != nil {
		return nil, 0, nil, err
	}

	loadNewTxn := mdb.db.Txn(false)
	defer loadNewTxn.Abort()

	// Start strictly after the revision the caller has already seen.
	it, err := loadNewTxn.LowerBound(tableChangelog, indexRevision, currentTxn+1)
	if err != nil {
		return nil, 0, nil, fmt.Errorf(errWatchError, err)
	}

	// The requested content kinds do not change while iterating, so decode the
	// option bits once.
	watchRelationships := options.Content&datastore.WatchRelationships == datastore.WatchRelationships
	watchSchema := options.Content&datastore.WatchSchema == datastore.WatchSchema
	watchCheckpoints := options.Content&datastore.WatchCheckpoints == datastore.WatchCheckpoints

	var changes []datastore.RevisionChanges
	lastRevision := currentTxn
	for changeRaw := it.Next(); changeRaw != nil; changeRaw = it.Next() {
		change := changeRaw.(*changelog)

		if watchRelationships && len(change.changes.RelationshipChanges) > 0 {
			changes = append(changes, change.changes)
		}

		// The schema test must be grouped explicitly: && binds tighter than
		// ||, so without the grouping entries with deleted caveats or
		// namespaces would be emitted even when schema watching was not
		// requested.
		hasSchemaChanges := len(change.changes.ChangedDefinitions) > 0 ||
			len(change.changes.DeletedCaveats) > 0 ||
			len(change.changes.DeletedNamespaces) > 0
		// NOTE(review): an entry matching both the relationship and the schema
		// filter is appended twice; confirm whether consumers expect that.
		if watchSchema && hasSchemaChanges {
			changes = append(changes, change.changes)
		}

		// Emit a checkpoint whenever the revision advances past the previous
		// entry (or past currentTxn for the first entry).
		if watchCheckpoints && change.revisionNanos > lastRevision {
			changes = append(changes, datastore.RevisionChanges{
				Revision:     revisions.NewForTimestamp(change.revisionNanos),
				IsCheckpoint: true,
			})
		}

		lastRevision = change.revisionNanos
	}

	// Taken from the same read transaction as the scan above, so the channel
	// corresponds to exactly the changelog state that was just read.
	watchChan, _, err := loadNewTxn.LastWatch(tableChangelog, indexRevision)
	if err != nil {
		return nil, 0, nil, fmt.Errorf(errWatchError, err)
	}

	return changes, lastRevision, watchChan, nil
}