summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/internal/datastore
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
committermo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
commit20ef0d92694465ac86b550df139e8366a0a2b4fa (patch)
tree3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/spicedb/internal/datastore
parent44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff)
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/spicedb/internal/datastore')
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go352
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go154
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go269
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go49
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/index.go28
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go15
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go42
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go214
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go188
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go17
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go961
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go31
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/url.go19
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go276
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md23
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go156
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go37
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go386
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go597
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go386
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go118
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go232
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go51
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go148
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go79
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go166
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go118
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go125
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go97
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go80
30 files changed, 5414 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go
new file mode 100644
index 0000000..291abb5
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go
@@ -0,0 +1,352 @@
+package common
+
+import (
+ "context"
+ "sort"
+
+ "golang.org/x/exp/maps"
+ "google.golang.org/protobuf/types/known/structpb"
+
+ "github.com/ccoveille/go-safecast"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/datastore"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
const (
	// nsPrefix and caveatPrefix namespace the keys of
	// changeRecord.definitionsChanged so a namespace definition and a caveat
	// definition sharing the same name cannot collide.
	nsPrefix     = "n$"
	caveatPrefix = "c$"
)
+
// Changes represents a set of datastore mutations that are kept self-consistent
// across one or more transaction revisions.
type Changes[R datastore.Revision, K comparable] struct {
	// records holds one changeRecord per revision, keyed by keyFunc(rev).
	records map[K]changeRecord[R]
	// keyFunc maps a revision to its comparable map key.
	keyFunc func(R) K
	// content selects which kinds of changes (relationships and/or schema) are tracked.
	content datastore.WatchContent
	// maxByteSize caps the total tracked size in bytes; 0 disables the cap.
	maxByteSize uint64
	// currentByteSize is the running total of tracked change sizes in bytes.
	currentByteSize int64
}

// changeRecord collects all changes observed at a single revision.
type changeRecord[R datastore.Revision] struct {
	rev                R
	relTouches         map[string]tuple.Relationship         // keyed by the relationship's string form
	relDeletes         map[string]tuple.Relationship         // keyed by the relationship's string form
	definitionsChanged map[string]datastore.SchemaDefinition // keyed by nsPrefix/caveatPrefix + name
	namespacesDeleted  map[string]struct{}
	caveatsDeleted     map[string]struct{}
	metadata           map[string]any
}
+
+// NewChanges creates a new Changes object for change tracking and de-duplication.
+func NewChanges[R datastore.Revision, K comparable](keyFunc func(R) K, content datastore.WatchContent, maxByteSize uint64) *Changes[R, K] {
+ return &Changes[R, K]{
+ records: make(map[K]changeRecord[R], 0),
+ keyFunc: keyFunc,
+ content: content,
+ maxByteSize: maxByteSize,
+ currentByteSize: 0,
+ }
+}
+
+// IsEmpty returns if the change set is empty.
+func (ch *Changes[R, K]) IsEmpty() bool {
+ return len(ch.records) == 0
+}
+
+// AddRelationshipChange adds a specific change to the complete list of tracked changes
+func (ch *Changes[R, K]) AddRelationshipChange(
+ ctx context.Context,
+ rev R,
+ rel tuple.Relationship,
+ op tuple.UpdateOperation,
+) error {
+ if ch.content&datastore.WatchRelationships != datastore.WatchRelationships {
+ return nil
+ }
+
+ record, err := ch.recordForRevision(rev)
+ if err != nil {
+ return err
+ }
+
+ key := tuple.StringWithoutCaveatOrExpiration(rel)
+
+ switch op {
+ case tuple.UpdateOperationTouch:
+ // If there was a delete for the same tuple at the same revision, drop it
+ existing, ok := record.relDeletes[key]
+ if ok {
+ delete(record.relDeletes, key)
+ if err := ch.adjustByteSize(existing, -1); err != nil {
+ return err
+ }
+ }
+
+ record.relTouches[key] = rel
+ if err := ch.adjustByteSize(rel, 1); err != nil {
+ return err
+ }
+
+ case tuple.UpdateOperationDelete:
+ _, alreadyTouched := record.relTouches[key]
+ if !alreadyTouched {
+ record.relDeletes[key] = rel
+ if err := ch.adjustByteSize(rel, 1); err != nil {
+ return err
+ }
+ }
+
+ default:
+ return spiceerrors.MustBugf("unknown change operation")
+ }
+
+ return nil
+}
+
// sized is implemented by values that can report their serialized size in
// bytes via SizeVT (generated protobuf size method).
type sized interface {
	SizeVT() int
}

// adjustByteSize applies delta copies of item's size to the running total and
// enforces the configured maximum. delta is +1 to count an item and -1 to
// discount a previously-counted one. No-op when maxByteSize is 0.
func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error {
	if ch.maxByteSize == 0 {
		return nil
	}

	size := item.SizeVT() * delta
	ch.currentByteSize += int64(size)
	if ch.currentByteSize < 0 {
		// Going negative means more was removed than was ever added: a bookkeeping bug.
		return spiceerrors.MustBugf("byte size underflow")
	}

	currentByteSize, err := safecast.ToUint64(ch.currentByteSize)
	if err != nil {
		return spiceerrors.MustBugf("could not cast currentByteSize to uint64: %v", err)
	}

	if currentByteSize > ch.maxByteSize {
		return datastore.NewMaximumChangesSizeExceededError(ch.maxByteSize)
	}

	return nil
}
+
+// SetRevisionMetadata sets the metadata for the given revision.
+func (ch *Changes[R, K]) SetRevisionMetadata(ctx context.Context, rev R, metadata map[string]any) error {
+ if len(metadata) == 0 {
+ return nil
+ }
+
+ record, err := ch.recordForRevision(rev)
+ if err != nil {
+ return err
+ }
+
+ if len(record.metadata) > 0 {
+ return spiceerrors.MustBugf("metadata already set for revision")
+ }
+
+ maps.Copy(record.metadata, metadata)
+ return nil
+}
+
+func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) {
+ k := ch.keyFunc(rev)
+ revisionChanges, ok := ch.records[k]
+ if !ok {
+ revisionChanges = changeRecord[R]{
+ rev,
+ make(map[string]tuple.Relationship),
+ make(map[string]tuple.Relationship),
+ make(map[string]datastore.SchemaDefinition),
+ make(map[string]struct{}),
+ make(map[string]struct{}),
+ make(map[string]any),
+ }
+ ch.records[k] = revisionChanges
+ }
+
+ return revisionChanges, nil
+}
+
+// AddDeletedNamespace adds a change indicating that the namespace with the name was deleted.
+func (ch *Changes[R, K]) AddDeletedNamespace(
+ _ context.Context,
+ rev R,
+ namespaceName string,
+) error {
+ if ch.content&datastore.WatchSchema != datastore.WatchSchema {
+ return nil
+ }
+
+ record, err := ch.recordForRevision(rev)
+ if err != nil {
+ return err
+ }
+
+ // if a delete happens in the same transaction as a change, we assume it was a change in the first place
+ // because that's how namespace changes are implemented in the MVCC
+ if _, ok := record.definitionsChanged[nsPrefix+namespaceName]; ok {
+ return nil
+ }
+
+ delete(record.definitionsChanged, nsPrefix+namespaceName)
+ record.namespacesDeleted[namespaceName] = struct{}{}
+ return nil
+}
+
+// AddDeletedCaveat adds a change indicating that the caveat with the name was deleted.
+func (ch *Changes[R, K]) AddDeletedCaveat(
+ _ context.Context,
+ rev R,
+ caveatName string,
+) error {
+ if ch.content&datastore.WatchSchema != datastore.WatchSchema {
+ return nil
+ }
+
+ record, err := ch.recordForRevision(rev)
+ if err != nil {
+ return err
+ }
+
+ // if a delete happens in the same transaction as a change, we assume it was a change in the first place
+ // because that's how namespace changes are implemented in the MVCC
+ if _, ok := record.definitionsChanged[caveatPrefix+caveatName]; ok {
+ return nil
+ }
+
+ delete(record.definitionsChanged, caveatPrefix+caveatName)
+ record.caveatsDeleted[caveatName] = struct{}{}
+ return nil
+}
+
+// AddChangedDefinition adds a change indicating that the schema definition (namespace or caveat)
+// was changed to the definition given.
+func (ch *Changes[R, K]) AddChangedDefinition(
+ ctx context.Context,
+ rev R,
+ def datastore.SchemaDefinition,
+) error {
+ if ch.content&datastore.WatchSchema != datastore.WatchSchema {
+ return nil
+ }
+
+ record, err := ch.recordForRevision(rev)
+ if err != nil {
+ return err
+ }
+
+ switch t := def.(type) {
+ case *core.NamespaceDefinition:
+ delete(record.namespacesDeleted, t.Name)
+
+ if existing, ok := record.definitionsChanged[nsPrefix+t.Name]; ok {
+ if err := ch.adjustByteSize(existing, -1); err != nil {
+ return err
+ }
+ }
+
+ record.definitionsChanged[nsPrefix+t.Name] = t
+
+ if err := ch.adjustByteSize(t, 1); err != nil {
+ return err
+ }
+
+ case *core.CaveatDefinition:
+ delete(record.caveatsDeleted, t.Name)
+
+ if existing, ok := record.definitionsChanged[nsPrefix+t.Name]; ok {
+ if err := ch.adjustByteSize(existing, -1); err != nil {
+ return err
+ }
+ }
+
+ record.definitionsChanged[caveatPrefix+t.Name] = t
+
+ if err := ch.adjustByteSize(t, 1); err != nil {
+ return err
+ }
+
+ default:
+ log.Ctx(ctx).Fatal().Msg("unknown schema definition kind")
+ }
+
+ return nil
+}
+
// AsRevisionChanges returns the list of changes processed so far as a datastore watch
// compatible, ordered, changelist. All tracked revisions are included; the zero
// revision passed to revisionChanges is unused because withBound is false.
func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) ([]datastore.RevisionChanges, error) {
	return ch.revisionChanges(lessThanFunc, *new(R), false)
}
+
+// FilterAndRemoveRevisionChanges filters a list of changes processed up to the bound revision from the changes list, removing them
+// and returning the filtered changes.
+func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) ([]datastore.RevisionChanges, error) {
+ changes, err := ch.revisionChanges(lessThanFunc, boundRev, true)
+ if err != nil {
+ return nil, err
+ }
+
+ ch.removeAllChangesBefore(boundRev)
+ return changes, nil
+}
+
// revisionChanges assembles the tracked records into an ordered
// []datastore.RevisionChanges. When withBound is true, only records whose
// revision is strictly less than boundRev are included; otherwise all records
// are returned. Ordering across revisions is defined by lessThanFunc over the
// revision keys.
func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) ([]datastore.RevisionChanges, error) {
	if ch.IsEmpty() {
		return nil, nil
	}

	// Collect the keys of the records that fall under the bound (if any).
	revisionsWithChanges := make([]K, 0, len(ch.records))
	for rk, cr := range ch.records {
		if !withBound || boundRev.GreaterThan(cr.rev) {
			revisionsWithChanges = append(revisionsWithChanges, rk)
		}
	}

	if len(revisionsWithChanges) == 0 {
		return nil, nil
	}

	// Sort so the changelist is emitted in revision order.
	sort.Slice(revisionsWithChanges, func(i int, j int) bool {
		return lessThanFunc(revisionsWithChanges[i], revisionsWithChanges[j])
	})

	changes := make([]datastore.RevisionChanges, len(revisionsWithChanges))
	for i, k := range revisionsWithChanges {
		revisionChangeRecord := ch.records[k]
		changes[i].Revision = revisionChangeRecord.rev
		// NOTE: within a single revision the relationship (and map-valued)
		// orderings follow Go map iteration order and are not deterministic.
		for _, rel := range revisionChangeRecord.relTouches {
			changes[i].RelationshipChanges = append(changes[i].RelationshipChanges, tuple.Touch(rel))
		}
		for _, rel := range revisionChangeRecord.relDeletes {
			changes[i].RelationshipChanges = append(changes[i].RelationshipChanges, tuple.Delete(rel))
		}
		changes[i].ChangedDefinitions = maps.Values(revisionChangeRecord.definitionsChanged)
		changes[i].DeletedNamespaces = maps.Keys(revisionChangeRecord.namespacesDeleted)
		changes[i].DeletedCaveats = maps.Keys(revisionChangeRecord.caveatsDeleted)

		if len(revisionChangeRecord.metadata) > 0 {
			metadata, err := structpb.NewStruct(revisionChangeRecord.metadata)
			if err != nil {
				return nil, spiceerrors.MustBugf("failed to convert metadata to structpb: %v", err)
			}

			changes[i].Metadata = metadata
		}
	}

	return changes, nil
}
+
+func (ch *Changes[R, K]) removeAllChangesBefore(boundRev R) {
+ for rk, cr := range ch.records {
+ if boundRev.GreaterThan(cr.rev) {
+ delete(ch.records, rk)
+ }
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go
new file mode 100644
index 0000000..af0b229
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go
@@ -0,0 +1,154 @@
+package common
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "regexp"
+ "strings"
+
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// SerializationError is returned when there's been a serialization
// error while performing a datastore operation
type SerializationError struct {
	error
}

// GRPCStatus maps the error to gRPC code Aborted with the
// SERIALIZATION_FAILURE reason, signaling callers that a retry may succeed.
func (err SerializationError) GRPCStatus() *status.Status {
	return spiceerrors.WithCodeAndDetails(
		err,
		codes.Aborted,
		spiceerrors.ForReason(
			v1.ErrorReason_ERROR_REASON_SERIALIZATION_FAILURE,
			map[string]string{},
		),
	)
}

// Unwrap exposes the wrapped error for errors.Is / errors.As.
func (err SerializationError) Unwrap() error {
	return err.error
}

// NewSerializationError creates a new SerializationError
func NewSerializationError(err error) error {
	return SerializationError{err}
}
+
// ReadOnlyTransactionError is returned when an otherwise read-write
// transaction fails on writes with an error indicating that the datastore
// is currently in a read-only mode.
type ReadOnlyTransactionError struct {
	error
}

// GRPCStatus maps the error to gRPC code Aborted with the
// SERVICE_READ_ONLY reason.
func (err ReadOnlyTransactionError) GRPCStatus() *status.Status {
	return spiceerrors.WithCodeAndDetails(
		err,
		codes.Aborted,
		spiceerrors.ForReason(
			v1.ErrorReason_ERROR_REASON_SERVICE_READ_ONLY,
			map[string]string{},
		),
	)
}

// NewReadOnlyTransactionError creates a new ReadOnlyTransactionError.
func NewReadOnlyTransactionError(err error) error {
	return ReadOnlyTransactionError{
		fmt.Errorf("could not perform write operation, as the datastore is currently in read-only mode: %w. This may indicate that the datastore has been put into maintenance mode", err),
	}
}
+
// CreateRelationshipExistsError is an error returned when attempting to CREATE an already-existing
// relationship.
type CreateRelationshipExistsError struct {
	error

	// Relationship is the relationship that caused the error. May be nil, depending on the datastore.
	Relationship *tuple.Relationship
}

// GRPCStatus implements retrieving the gRPC status for the error.
// The status is AlreadyExists; when the offending relationship is known, its
// components are attached as error-detail metadata for clients.
func (err CreateRelationshipExistsError) GRPCStatus() *status.Status {
	if err.Relationship == nil {
		return spiceerrors.WithCodeAndDetails(
			err,
			codes.AlreadyExists,
			spiceerrors.ForReason(
				v1.ErrorReason_ERROR_REASON_ATTEMPT_TO_RECREATE_RELATIONSHIP,
				map[string]string{},
			),
		)
	}

	relationship := tuple.ToV1Relationship(*err.Relationship)
	return spiceerrors.WithCodeAndDetails(
		err,
		codes.AlreadyExists,
		spiceerrors.ForReason(
			v1.ErrorReason_ERROR_REASON_ATTEMPT_TO_RECREATE_RELATIONSHIP,
			map[string]string{
				"relationship":       tuple.V1StringRelationshipWithoutCaveatOrExpiration(relationship),
				"resource_type":      relationship.Resource.ObjectType,
				"resource_object_id": relationship.Resource.ObjectId,
				"resource_relation":  relationship.Relation,
				"subject_type":       relationship.Subject.Object.ObjectType,
				"subject_object_id":  relationship.Subject.Object.ObjectId,
				"subject_relation":   relationship.Subject.OptionalRelation,
			},
		),
	)
}

// NewCreateRelationshipExistsError creates a new CreateRelationshipExistsError.
// When the relationship is known it is included in the message; otherwise a
// generic message is used.
func NewCreateRelationshipExistsError(relationship *tuple.Relationship) error {
	msg := "could not CREATE one or more relationships, as they already existed. If this is persistent, please switch to TOUCH operations or specify a precondition"
	if relationship != nil {
		msg = fmt.Sprintf("could not CREATE relationship `%s`, as it already existed. If this is persistent, please switch to TOUCH operations or specify a precondition", tuple.StringWithoutCaveatOrExpiration(*relationship))
	}

	return CreateRelationshipExistsError{
		errors.New(msg),
		relationship,
	}
}
+
var (
	// portMatchRegex and parseMatchRegex match the fragments of pgx/url parse
	// errors that can echo connection-string contents, so they can be redacted.
	// Raw string literals replace the original escaped-quote form; the compiled
	// patterns are identical.
	portMatchRegex  = regexp.MustCompile(`invalid port \"(.+)\" after host`)
	parseMatchRegex = regexp.MustCompile(`parse \"(.+)\":`)
)
+
+// RedactAndLogSensitiveConnString elides the given error, logging it only at trace
+// level (after being redacted).
+func RedactAndLogSensitiveConnString(ctx context.Context, baseErr string, err error, pgURL string) error {
+ if err == nil {
+ return errors.New(baseErr)
+ }
+
+ // See: https://github.com/jackc/pgx/issues/1271
+ filtered := err.Error()
+ filtered = strings.ReplaceAll(filtered, pgURL, "(redacted)")
+ filtered = portMatchRegex.ReplaceAllString(filtered, "(redacted)")
+ filtered = parseMatchRegex.ReplaceAllString(filtered, "(redacted)")
+ log.Ctx(ctx).Trace().Msg(baseErr + ": " + filtered)
+ return fmt.Errorf("%s. To view details of this error (that may contain sensitive information), please run with --log-level=trace", baseErr)
+}
+
// RevisionUnavailableError is returned when a revision is not available on a replica.
type RevisionUnavailableError struct {
	error
}

// NewRevisionUnavailableError creates a new RevisionUnavailableError wrapping
// the underlying cause.
func NewRevisionUnavailableError(err error) error {
	return RevisionUnavailableError{err}
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go
new file mode 100644
index 0000000..5788134
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go
@@ -0,0 +1,269 @@
+package common
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/cenkalti/backoff/v4"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/rs/zerolog"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
var (
	// gcDurationHistogram observes the wall-clock duration of each GC pass.
	gcDurationHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_duration_seconds",
		Help:      "The duration of datastore garbage collection.",
		Buckets:   []float64{0.01, 0.1, 0.5, 1, 5, 10, 25, 60, 120},
	})

	gcRelationshipsCounter = prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_relationships_total",
		Help:      "The number of stale relationships deleted by the datastore garbage collection.",
	})

	gcExpiredRelationshipsCounter = prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_expired_relationships_total",
		Help:      "The number of expired relationships deleted by the datastore garbage collection.",
	})

	gcTransactionsCounter = prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_transactions_total",
		Help:      "The number of stale transactions deleted by the datastore garbage collection.",
	})

	gcNamespacesCounter = prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_namespaces_total",
		Help:      "The number of stale namespaces deleted by the datastore garbage collection.",
	})

	// gcFailureCounterConfig is kept separate from the counter itself —
	// presumably so callers/tests can build an equivalent counter; confirm
	// against usages outside this chunk.
	gcFailureCounterConfig = prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_failure_total",
		Help:      "The number of failed runs of the datastore garbage collection.",
	}
	gcFailureCounter = prometheus.NewCounter(gcFailureCounterConfig)
)
+
+// RegisterGCMetrics registers garbage collection metrics to the default
+// registry.
+func RegisterGCMetrics() error {
+ for _, metric := range []prometheus.Collector{
+ gcDurationHistogram,
+ gcRelationshipsCounter,
+ gcTransactionsCounter,
+ gcNamespacesCounter,
+ gcFailureCounter,
+ } {
+ if err := prometheus.Register(metric); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
// GarbageCollector represents any datastore that supports external garbage
// collection.
type GarbageCollector interface {
	// HasGCRun returns true if a garbage collection run has been completed.
	HasGCRun() bool

	// MarkGCCompleted marks that a garbage collection run has been completed.
	MarkGCCompleted()

	// ResetGCCompleted resets the state of the garbage collection run.
	ResetGCCompleted()

	// LockForGCRun attempts to acquire a lock for garbage collection. This lock
	// is typically done at the datastore level, to ensure that no other nodes are
	// running garbage collection at the same time. Returns false (without error)
	// when another node already holds the lock.
	LockForGCRun(ctx context.Context) (bool, error)

	// UnlockAfterGCRun releases the lock after a garbage collection run.
	// NOTE: this method does not take a context, as the context used for the
	// reset of the GC run can be canceled/timed out and the unlock will still need to happen.
	UnlockAfterGCRun() error

	// ReadyState returns the current state of the datastore.
	ReadyState(context.Context) (datastore.ReadyState, error)

	// Now returns the current time from the datastore.
	Now(context.Context) (time.Time, error)

	// TxIDBefore returns the highest transaction ID before the provided time.
	TxIDBefore(context.Context, time.Time) (datastore.Revision, error)

	// DeleteBeforeTx deletes all data before the provided transaction ID,
	// returning the per-category deletion counts.
	DeleteBeforeTx(ctx context.Context, txID datastore.Revision) (DeletionCounts, error)

	// DeleteExpiredRels deletes all relationships that have expired, returning
	// the number deleted.
	DeleteExpiredRels(ctx context.Context) (int64, error)
}
+
// DeletionCounts tracks the amount of deletions that occurred when calling
// DeleteBeforeTx.
type DeletionCounts struct {
	Relationships int64
	Transactions  int64
	Namespaces    int64
}

// MarshalZerologObject emits the counts as structured zerolog fields.
func (g DeletionCounts) MarshalZerologObject(e *zerolog.Event) {
	e.
		Int64("relationships", g.Relationships).
		Int64("transactions", g.Transactions).
		Int64("namespaces", g.Namespaces)
}
+
// MaxGCInterval caps the backoff interval between retries after failed
// garbage collection runs.
var MaxGCInterval = 60 * time.Minute

// StartGarbageCollector loops forever until the context is canceled and
// performs garbage collection on the provided interval. A maxElapsedTime of
// zero (used here) means failures are retried indefinitely.
func StartGarbageCollector(ctx context.Context, gc GarbageCollector, interval, window, timeout time.Duration) error {
	return startGarbageCollectorWithMaxElapsedTime(ctx, gc, interval, window, 0, timeout, gcFailureCounter)
}
+
// startGarbageCollectorWithMaxElapsedTime runs the GC loop until ctx is
// canceled. After a failed run the wait until the next attempt grows
// exponentially (starting at interval, capped at max(MaxGCInterval, interval));
// after a successful run the wait resets to interval. failureCounter is
// incremented once per failed run.
func startGarbageCollectorWithMaxElapsedTime(ctx context.Context, gc GarbageCollector, interval, window, maxElapsedTime, timeout time.Duration, failureCounter prometheus.Counter) error {
	backoffInterval := backoff.NewExponentialBackOff()
	backoffInterval.InitialInterval = interval
	backoffInterval.MaxInterval = max(MaxGCInterval, interval)
	backoffInterval.MaxElapsedTime = maxElapsedTime
	backoffInterval.Reset()

	nextInterval := interval

	log.Ctx(ctx).Info().
		Dur("interval", nextInterval).
		Msg("datastore garbage collection worker started")

	for {
		select {
		case <-ctx.Done():
			log.Ctx(ctx).Info().
				Msg("shutting down datastore garbage collection worker")
			return ctx.Err()

		case <-time.After(nextInterval):
			log.Ctx(ctx).Info().
				Dur("interval", nextInterval).
				Dur("window", window).
				Dur("timeout", timeout).
				Msg("running garbage collection worker")

			err := RunGarbageCollection(gc, window, timeout)
			if err != nil {
				failureCounter.Inc()
				// Back off before the next attempt.
				nextInterval = backoffInterval.NextBackOff()
				log.Ctx(ctx).Warn().Err(err).
					Dur("next-attempt-in", nextInterval).
					Msg("error attempting to perform garbage collection")
				continue
			}

			// Success: reset the backoff and return to the steady-state interval.
			backoffInterval.Reset()
			nextInterval = interval

			log.Ctx(ctx).Debug().
				Dur("next-run-in", interval).
				Msg("datastore garbage collection scheduled for next run")
		}
	}
}
+
// RunGarbageCollection runs a single garbage collection pass for the
// datastore, bounded by timeout: data older than window (by transaction
// watermark) and expired relationships are deleted. A pass is skipped without
// error when the datastore is not ready or another node holds the GC lock.
func RunGarbageCollection(gc GarbageCollector, window, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	ctx, span := tracer.Start(ctx, "RunGarbageCollection")
	defer span.End()

	// Before attempting anything, check if the datastore is ready.
	startTime := time.Now()
	ready, err := gc.ReadyState(ctx)
	if err != nil {
		return err
	}
	if !ready.IsReady {
		log.Ctx(ctx).Warn().
			Msgf("datastore wasn't ready when attempting garbage collection: %s", ready.Message)
		return nil
	}

	// Take the cross-node GC lock; if another node holds it, skip this run.
	ok, err := gc.LockForGCRun(ctx)
	if err != nil {
		return fmt.Errorf("error locking for gc run: %w", err)
	}

	if !ok {
		log.Info().
			Msg("datastore garbage collection already in progress on another node")
		return nil
	}

	defer func() {
		// Deliberately not bound to ctx: the unlock must happen even when the
		// run's context has been canceled or timed out.
		err := gc.UnlockAfterGCRun()
		if err != nil {
			log.Error().
				Err(err).
				Msg("error unlocking after gc run")
		}
	}()

	now, err := gc.Now(ctx)
	if err != nil {
		return fmt.Errorf("error retrieving now: %w", err)
	}

	watermark, err := gc.TxIDBefore(ctx, now.Add(-1*window))
	if err != nil {
		return fmt.Errorf("error retrieving watermark: %w", err)
	}

	collected, err := gc.DeleteBeforeTx(ctx, watermark)

	expiredRelationshipsCount, eerr := gc.DeleteExpiredRels(ctx)

	// even if an error happened, garbage would have been collected. This makes sure these are reflected even if the
	// worker eventually fails or times out.
	gcRelationshipsCounter.Add(float64(collected.Relationships))
	gcTransactionsCounter.Add(float64(collected.Transactions))
	gcNamespacesCounter.Add(float64(collected.Namespaces))
	gcExpiredRelationshipsCounter.Add(float64(expiredRelationshipsCount))
	collectionDuration := time.Since(startTime)
	gcDurationHistogram.Observe(collectionDuration.Seconds())

	// Error checks are intentionally deferred until after the metrics above
	// have been recorded.
	if err != nil {
		return fmt.Errorf("error deleting in gc: %w", err)
	}

	if eerr != nil {
		return fmt.Errorf("error deleting expired relationships in gc: %w", eerr)
	}

	log.Ctx(ctx).Info().
		Stringer("highestTxID", watermark).
		Dur("duration", collectionDuration).
		Time("nowTime", now).
		Interface("collected", collected).
		Int64("expiredRelationships", expiredRelationshipsCount).
		Msg("datastore garbage collection completed successfully")

	gc.MarkGCCompleted()
	return nil
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go
new file mode 100644
index 0000000..8f34134
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go
@@ -0,0 +1,49 @@
+package common
+
+import (
+ "context"
+ "fmt"
+
+ "google.golang.org/protobuf/types/known/structpb"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// WriteRelationships is a convenience method to perform the same update operation on a set of relationships
+func WriteRelationships(ctx context.Context, ds datastore.Datastore, op tuple.UpdateOperation, rels ...tuple.Relationship) (datastore.Revision, error) {
+ updates := make([]tuple.RelationshipUpdate, 0, len(rels))
+ for _, rel := range rels {
+ ru := tuple.RelationshipUpdate{
+ Operation: op,
+ Relationship: rel,
+ }
+ updates = append(updates, ru)
+ }
+ return UpdateRelationshipsInDatastore(ctx, ds, updates...)
+}
+
// UpdateRelationshipsInDatastore is a convenience method to perform multiple relation update operations on a Datastore,
// applying all updates within a single read-write transaction.
func UpdateRelationshipsInDatastore(ctx context.Context, ds datastore.Datastore, updates ...tuple.RelationshipUpdate) (datastore.Revision, error) {
	return ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
		return rwt.WriteRelationships(ctx, updates)
	})
}
+
+// ContextualizedCaveatFrom convenience method that handles creation of a contextualized caveat
+// given the possibility of arguments with zero-values.
+func ContextualizedCaveatFrom(name string, context map[string]any) (*core.ContextualizedCaveat, error) {
+ var caveat *core.ContextualizedCaveat
+ if name != "" {
+ strct, err := structpb.NewStruct(context)
+ if err != nil {
+ return nil, fmt.Errorf("malformed caveat context: %w", err)
+ }
+ caveat = &core.ContextualizedCaveat{
+ CaveatName: name,
+ Context: strct,
+ }
+ }
+ return caveat, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go
new file mode 100644
index 0000000..1eb64d1
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go
@@ -0,0 +1,28 @@
+package common
+
+import "github.com/authzed/spicedb/pkg/datastore/queryshape"
+
// IndexDefinition is a definition of an index for a datastore.
type IndexDefinition struct {
	// Name is the unique name for the index.
	Name string

	// ColumnsSQL is the SQL fragment of the columns over which this index will apply.
	ColumnsSQL string

	// Shapes are those query shapes for which this index should be used;
	// see matchesShape.
	Shapes []queryshape.Shape

	// IsDeprecated is true if this index is deprecated and should not be used.
	IsDeprecated bool
}
+
+// matchesShape returns true if the index matches the given shape.
+func (id IndexDefinition) matchesShape(shape queryshape.Shape) bool {
+ for _, s := range id.Shapes {
+ if s == shape {
+ return true
+ }
+ }
+ return false
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go
new file mode 100644
index 0000000..6e84549
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go
@@ -0,0 +1,15 @@
+package common
+
+import (
+ "context"
+
+ log "github.com/authzed/spicedb/internal/logging"
+)
+
// LogOnError executes the function and logs the error, if any, at error level.
// Useful to avoid silently ignoring errors in defer statements.
func LogOnError(ctx context.Context, f func() error) {
	if err := f(); err != nil {
		log.Ctx(ctx).Err(err).Msg("datastore error")
	}
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go
new file mode 100644
index 0000000..304f62c
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go
@@ -0,0 +1,42 @@
+package common
+
+import (
+ "fmt"
+ "slices"
+ "strings"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// MigrationValidator checks a datastore's reported migration revision against
// the expected head migration (plus any additionally allowed revisions).
type MigrationValidator struct {
	additionalAllowedMigrations []string
	headMigration               string
}

// NewMigrationValidator creates a MigrationValidator that accepts
// headMigration as well as any revision in additionalAllowedMigrations.
func NewMigrationValidator(headMigration string, additionalAllowedMigrations []string) *MigrationValidator {
	return &MigrationValidator{
		additionalAllowedMigrations: additionalAllowedMigrations,
		headMigration:               headMigration,
	}
}
+
+// MigrationReadyState returns the readiness of the datastore for the given version.
+func (mv *MigrationValidator) MigrationReadyState(version string) datastore.ReadyState {
+ if version == mv.headMigration {
+ return datastore.ReadyState{IsReady: true}
+ }
+ if slices.Contains(mv.additionalAllowedMigrations, version) {
+ return datastore.ReadyState{IsReady: true}
+ }
+ var msgBuilder strings.Builder
+ msgBuilder.WriteString(fmt.Sprintf("datastore is not migrated: currently at revision %q, but requires %q", version, mv.headMigration))
+
+ if len(mv.additionalAllowedMigrations) > 0 {
+ msgBuilder.WriteString(fmt.Sprintf(" (additional allowed migrations: %v)", mv.additionalAllowedMigrations))
+ }
+ msgBuilder.WriteString(". Please run \"spicedb datastore migrate\".")
+ return datastore.ReadyState{
+ Message: msgBuilder.String(),
+ IsReady: false,
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go
new file mode 100644
index 0000000..dee0ad5
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go
@@ -0,0 +1,214 @@
+package common
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "time"
+
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
+ "google.golang.org/protobuf/types/known/timestamppb"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ corev1 "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+const errUnableToQueryRels = "unable to query relationships: %w"
+
+// Querier is an interface for querying the database.
+type Querier[R Rows] interface {
+ QueryFunc(ctx context.Context, f func(context.Context, R) error, sql string, args ...any) error
+}
+
+// Rows is a common interface for database rows reading.
+type Rows interface {
+ Scan(dest ...any) error
+ Next() bool
+ Err() error
+}
+
+type closeRowsWithError interface {
+ Rows
+ Close() error
+}
+
+type closeRows interface {
+ Rows
+ Close()
+}
+
+func runExplainIfNecessary[R Rows](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R], explainable datastore.Explainable) error {
+ if builder.SQLExplainCallbackForTest == nil {
+ return nil
+ }
+
+ // Determine the expected index names via the schema.
+ expectedIndexes := builder.Schema.expectedIndexesForShape(builder.queryShape)
+
+ // Run any pre-explain statements.
+ for _, statement := range explainable.PreExplainStatements() {
+ if err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error {
+ rows.Next()
+ return nil
+ }, statement); err != nil {
+ return fmt.Errorf(errUnableToQueryRels, err)
+ }
+ }
+
+ // Run the query with EXPLAIN ANALYZE.
+ sqlString, args, err := builder.SelectSQL()
+ if err != nil {
+ return fmt.Errorf(errUnableToQueryRels, err)
+ }
+
+ explainSQL, explainArgs, err := explainable.BuildExplainQuery(sqlString, args)
+ if err != nil {
+ return fmt.Errorf(errUnableToQueryRels, err)
+ }
+
+ err = tx.QueryFunc(ctx, func(ctx context.Context, rows R) error {
+ explainString := ""
+ for rows.Next() {
+ var explain string
+ if err := rows.Scan(&explain); err != nil {
+ return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err))
+ }
+ explainString += explain + "\n"
+ }
+ if explainString == "" {
+ return fmt.Errorf("received empty explain")
+ }
+
+ return builder.SQLExplainCallbackForTest(ctx, sqlString, args, builder.queryShape, explainString, expectedIndexes)
+ }, explainSQL, explainArgs...)
+ if err != nil {
+ return fmt.Errorf(errUnableToQueryRels, err)
+ }
+
+ return nil
+}
+
+// QueryRelationships queries relationships for the given query and transaction.
+func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R], explainable datastore.Explainable) (datastore.RelationshipIterator, error) {
+ span := trace.SpanFromContext(ctx)
+ sqlString, args, err := builder.SelectSQL()
+ if err != nil {
+ return nil, fmt.Errorf(errUnableToQueryRels, err)
+ }
+
+ if err := runExplainIfNecessary(ctx, builder, tx, explainable); err != nil {
+ return nil, err
+ }
+
+ var resourceObjectType string
+ var resourceObjectID string
+ var resourceRelation string
+ var subjectObjectType string
+ var subjectObjectID string
+ var subjectRelation string
+ var caveatName sql.NullString
+ var caveatCtx C
+ var expiration *time.Time
+
+ var integrityKeyID string
+ var integrityHash []byte
+ var timestamp time.Time
+
+ span.AddEvent("Selecting columns")
+ colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, &timestamp)
+ if err != nil {
+ return nil, fmt.Errorf(errUnableToQueryRels, err)
+ }
+
+ span.AddEvent("Returning iterator", trace.WithAttributes(attribute.Int("column-count", len(colsToSelect))))
+ return func(yield func(tuple.Relationship, error) bool) {
+ span.AddEvent("Issuing query to database")
+ err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error {
+ span.AddEvent("Query issued to database")
+
+ var r Rows = rows
+ if crwe, ok := r.(closeRowsWithError); ok {
+ defer LogOnError(ctx, crwe.Close)
+ } else if cr, ok := r.(closeRows); ok {
+ defer cr.Close()
+ }
+
+ relCount := 0
+ for rows.Next() {
+ if relCount == 0 {
+ span.AddEvent("First row returned")
+ }
+
+ if err := rows.Scan(colsToSelect...); err != nil {
+ return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err))
+ }
+
+ if relCount == 0 {
+ span.AddEvent("First row scanned")
+ }
+
+ var caveat *corev1.ContextualizedCaveat
+ if !builder.SkipCaveats || builder.Schema.ColumnOptimization == ColumnOptimizationOptionNone {
+ if caveatName.Valid {
+ var err error
+ caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx)
+ if err != nil {
+ return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("unable to fetch caveat context: %w", err))
+ }
+ }
+ }
+
+ var integrity *corev1.RelationshipIntegrity
+ if integrityKeyID != "" {
+ integrity = &corev1.RelationshipIntegrity{
+ KeyId: integrityKeyID,
+ Hash: integrityHash,
+ HashedAt: timestamppb.New(timestamp),
+ }
+ }
+
+ if expiration != nil {
+ // Ensure the expiration is always read in UTC, since some datastores (like CRDB)
+ // will normalize to local time.
+ t := expiration.UTC()
+ expiration = &t
+ }
+
+ relCount++
+ if !yield(tuple.Relationship{
+ RelationshipReference: tuple.RelationshipReference{
+ Resource: tuple.ObjectAndRelation{
+ ObjectType: resourceObjectType,
+ ObjectID: resourceObjectID,
+ Relation: resourceRelation,
+ },
+ Subject: tuple.ObjectAndRelation{
+ ObjectType: subjectObjectType,
+ ObjectID: subjectObjectID,
+ Relation: subjectRelation,
+ },
+ },
+ OptionalCaveat: caveat,
+ OptionalExpiration: expiration,
+ OptionalIntegrity: integrity,
+ }, nil) {
+ return nil
+ }
+ }
+
+ span.AddEvent("Relationships loaded", trace.WithAttributes(attribute.Int("relCount", relCount)))
+ if err := rows.Err(); err != nil {
+ return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("rows err: %w", err))
+ }
+
+ return nil
+ }, sqlString, args...)
+ if err != nil {
+ if !yield(tuple.Relationship{}, err) {
+ return
+ }
+ }
+ }, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go
new file mode 100644
index 0000000..6e44d0b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go
@@ -0,0 +1,188 @@
+package common
+
+import (
+ sq "github.com/Masterminds/squirrel"
+
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+const (
+ relationshipStandardColumnCount = 6 // ColNamespace, ColObjectID, ColRelation, ColUsersetNamespace, ColUsersetObjectID, ColUsersetRelation
+ relationshipCaveatColumnCount = 2 // ColCaveatName, ColCaveatContext
+ relationshipExpirationColumnCount = 1 // ColExpiration
+ relationshipIntegrityColumnCount = 3 // ColIntegrityKeyID, ColIntegrityHash, ColIntegrityTimestamp
+)
+
+// SchemaInformation holds the schema information from the SQL datastore implementation.
+//
+//go:generate go run github.com/ecordell/optgen -output zz_generated.schema_options.go . SchemaInformation
+type SchemaInformation struct {
+ RelationshipTableName string `debugmap:"visible"`
+
+ ColNamespace string `debugmap:"visible"`
+ ColObjectID string `debugmap:"visible"`
+ ColRelation string `debugmap:"visible"`
+ ColUsersetNamespace string `debugmap:"visible"`
+ ColUsersetObjectID string `debugmap:"visible"`
+ ColUsersetRelation string `debugmap:"visible"`
+
+ ColCaveatName string `debugmap:"visible"`
+ ColCaveatContext string `debugmap:"visible"`
+
+ ColExpiration string `debugmap:"visible"`
+
+ ColIntegrityKeyID string `debugmap:"visible"`
+ ColIntegrityHash string `debugmap:"visible"`
+ ColIntegrityTimestamp string `debugmap:"visible"`
+
+ // Indexes are the indexes to use for this schema.
+ Indexes []IndexDefinition `debugmap:"visible"`
+
+ // PaginationFilterType is the type of pagination filter to use for this schema.
+ PaginationFilterType PaginationFilterType `debugmap:"visible"`
+
+ // PlaceholderFormat is the format of placeholders to use for this schema.
+ PlaceholderFormat sq.PlaceholderFormat `debugmap:"visible"`
+
+ // NowFunction is the function to use to get the current time in the datastore.
+ NowFunction string `debugmap:"visible"`
+
+ // ColumnOptimization is the optimization to use for columns in the schema, if any.
+ ColumnOptimization ColumnOptimizationOption `debugmap:"visible"`
+
+ // IntegrityEnabled is a flag to indicate if the schema has integrity columns.
+ IntegrityEnabled bool `debugmap:"visible"`
+
+ // ExpirationDisabled is a flag to indicate whether expiration support is disabled.
+ ExpirationDisabled bool `debugmap:"visible"`
+
+ // SortByResourceColumnOrder is the order of the resource columns in the schema to use
+ // when sorting by resource. If unspecified, the default will be used.
+ SortByResourceColumnOrder []string `debugmap:"visible"`
+
+ // SortBySubjectColumnOrder is the order of the subject columns in the schema to use
+ // when sorting by subject. If unspecified, the default will be used.
+ SortBySubjectColumnOrder []string `debugmap:"visible"`
+}
+
+// expectedIndexesForShape returns the expected index names for a given query shape.
+func (si SchemaInformation) expectedIndexesForShape(shape queryshape.Shape) options.SQLIndexInformation {
+ expectedIndexes := options.SQLIndexInformation{}
+ for _, index := range si.Indexes {
+ if index.matchesShape(shape) {
+ expectedIndexes.ExpectedIndexNames = append(expectedIndexes.ExpectedIndexNames, index.Name)
+ }
+ }
+ return expectedIndexes
+}
+
+func (si SchemaInformation) debugValidate() {
+ spiceerrors.DebugAssert(func() bool {
+ si.mustValidate()
+ return true
+ }, "SchemaInformation failed to validate")
+}
+
+func (si SchemaInformation) sortByResourceColumnOrderColumns() []string {
+ if len(si.SortByResourceColumnOrder) > 0 {
+ return si.SortByResourceColumnOrder
+ }
+
+ return []string{
+ si.ColNamespace,
+ si.ColObjectID,
+ si.ColRelation,
+ si.ColUsersetNamespace,
+ si.ColUsersetObjectID,
+ si.ColUsersetRelation,
+ }
+}
+
+func (si SchemaInformation) sortBySubjectColumnOrderColumns() []string {
+ if len(si.SortBySubjectColumnOrder) > 0 {
+ return si.SortBySubjectColumnOrder
+ }
+
+ return []string{
+ si.ColUsersetNamespace,
+ si.ColUsersetObjectID,
+ si.ColUsersetRelation,
+ si.ColNamespace,
+ si.ColObjectID,
+ si.ColRelation,
+ }
+}
+
+func (si SchemaInformation) mustValidate() {
+ if si.RelationshipTableName == "" {
+ panic("RelationshipTableName is required")
+ }
+
+ if si.ColNamespace == "" {
+ panic("ColNamespace is required")
+ }
+
+ if si.ColObjectID == "" {
+ panic("ColObjectID is required")
+ }
+
+ if si.ColRelation == "" {
+ panic("ColRelation is required")
+ }
+
+ if si.ColUsersetNamespace == "" {
+ panic("ColUsersetNamespace is required")
+ }
+
+ if si.ColUsersetObjectID == "" {
+ panic("ColUsersetObjectID is required")
+ }
+
+ if si.ColUsersetRelation == "" {
+ panic("ColUsersetRelation is required")
+ }
+
+ if si.ColCaveatName == "" {
+ panic("ColCaveatName is required")
+ }
+
+ if si.ColCaveatContext == "" {
+ panic("ColCaveatContext is required")
+ }
+
+ if si.ColExpiration == "" {
+ panic("ColExpiration is required")
+ }
+
+ if si.IntegrityEnabled {
+ if si.ColIntegrityKeyID == "" {
+ panic("ColIntegrityKeyID is required")
+ }
+
+ if si.ColIntegrityHash == "" {
+ panic("ColIntegrityHash is required")
+ }
+
+ if si.ColIntegrityTimestamp == "" {
+ panic("ColIntegrityTimestamp is required")
+ }
+ }
+
+ if si.NowFunction == "" {
+ panic("NowFunction is required")
+ }
+
+ if si.ColumnOptimization == ColumnOptimizationOptionUnknown {
+ panic("ColumnOptimization is required")
+ }
+
+ if si.PaginationFilterType == PaginationFilterTypeUnknown {
+ panic("PaginationFilterType is required")
+ }
+
+ if si.PlaceholderFormat == nil {
+ panic("PlaceholderFormat is required")
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go
new file mode 100644
index 0000000..4972700
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go
@@ -0,0 +1,17 @@
+package common
+
+import (
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// NewSliceRelationshipIterator creates a datastore.RelationshipIterator instance from a materialized slice of tuples.
+func NewSliceRelationshipIterator(rels []tuple.Relationship) datastore.RelationshipIterator {
+ return func(yield func(tuple.Relationship, error) bool) {
+ for _, rel := range rels {
+ if !yield(rel, nil) {
+ break
+ }
+ }
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go
new file mode 100644
index 0000000..ba9c4f6
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go
@@ -0,0 +1,961 @@
+package common
+
+import (
+ "context"
+ "fmt"
+ "maps"
+ "math"
+ "strings"
+ "time"
+
+ sq "github.com/Masterminds/squirrel"
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/jzelinskie/stringz"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+var (
+ // CaveatNameKey is a tracing attribute representing a caveat name
+ CaveatNameKey = attribute.Key("authzed.com/spicedb/sql/caveatName")
+
+ // ObjNamespaceNameKey is a tracing attribute representing the resource
+ // object type.
+ ObjNamespaceNameKey = attribute.Key("authzed.com/spicedb/sql/objNamespaceName")
+
+ // ObjRelationNameKey is a tracing attribute representing the resource
+ // relation.
+ ObjRelationNameKey = attribute.Key("authzed.com/spicedb/sql/objRelationName")
+
+ // ObjIDKey is a tracing attribute representing the resource object ID.
+ ObjIDKey = attribute.Key("authzed.com/spicedb/sql/objId")
+
+ // SubNamespaceNameKey is a tracing attribute representing the subject object
+ // type.
+ SubNamespaceNameKey = attribute.Key("authzed.com/spicedb/sql/subNamespaceName")
+
+ // SubRelationNameKey is a tracing attribute representing the subject
+ // relation.
+ SubRelationNameKey = attribute.Key("authzed.com/spicedb/sql/subRelationName")
+
+ // SubObjectIDKey is a tracing attribute representing the the subject object
+ // ID.
+ SubObjectIDKey = attribute.Key("authzed.com/spicedb/sql/subObjectId")
+
+ tracer = otel.Tracer("spicedb/internal/datastore/common")
+)
+
+// PaginationFilterType is an enumerator for pagination filter types.
+type PaginationFilterType uint8
+
+const (
+ PaginationFilterTypeUnknown PaginationFilterType = iota
+
+ // TupleComparison uses a comparison with a compound key,
+ // e.g. (namespace, object_id, relation) > ('ns', '123', 'viewer')
+ // which is not compatible with all datastores.
+ TupleComparison = 1
+
+ // ExpandedLogicComparison comparison uses a nested tree of ANDs and ORs to properly
+ // filter out already received relationships. Useful for databases that do not support
+ // tuple comparison, or do not execute it efficiently
+ ExpandedLogicComparison = 2
+)
+
+// ColumnOptimizationOption is an enumerator for column optimization options.
+type ColumnOptimizationOption int
+
+const (
+ ColumnOptimizationOptionUnknown ColumnOptimizationOption = iota
+
+ // ColumnOptimizationOptionNone is the default option, which does not optimize the static columns.
+ ColumnOptimizationOptionNone
+
+ // ColumnOptimizationOptionStaticValues is an option that optimizes columns for static values.
+ ColumnOptimizationOptionStaticValues
+)
+
+type columnTracker struct {
+ SingleValue *string
+}
+
+type columnTrackerMap map[string]columnTracker
+
+func (ctm columnTrackerMap) hasStaticValue(columnName string) bool {
+ if r, ok := ctm[columnName]; ok && r.SingleValue != nil {
+ return true
+ }
+ return false
+}
+
+// SchemaQueryFilterer wraps a SchemaInformation and SelectBuilder to give an opinionated
+// way to build query objects.
+type SchemaQueryFilterer struct {
+ schema SchemaInformation
+ queryBuilder sq.SelectBuilder
+ filteringColumnTracker columnTrackerMap
+ filterMaximumIDCount uint16
+ isCustomQuery bool
+ extraFields []string
+ fromSuffix string
+ fromTable string
+ indexingHint IndexingHint
+}
+
+// IndexingHint is an interface that can be implemented to provide a hint for the SQL query.
+type IndexingHint interface {
+ // SQLPrefix returns the SQL prefix to be used for the indexing hint, if any.
+ SQLPrefix() (string, error)
+
+ // FromTable returns the table name to be used for the indexing hint, if any.
+ FromTable(existingTableName string) (string, error)
+
+ // FromSQLSuffix returns the suffix to be used for the indexing hint, if any.
+ FromSQLSuffix() (string, error)
+}
+
+// NewSchemaQueryFiltererForRelationshipsSelect creates a new SchemaQueryFilterer object for selecting
+// relationships. This method will automatically filter the columns retrieved from the database, only
+// selecting the columns that are not already specified with a single static value in the query.
+func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filterMaximumIDCount uint16, extraFields ...string) SchemaQueryFilterer {
+ schema.debugValidate()
+
+ if filterMaximumIDCount == 0 {
+ filterMaximumIDCount = 100
+ log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100")
+ }
+
+ queryBuilder := sq.StatementBuilder.PlaceholderFormat(schema.PlaceholderFormat).Select()
+ return SchemaQueryFilterer{
+ schema: schema,
+ queryBuilder: queryBuilder,
+ filteringColumnTracker: map[string]columnTracker{},
+ filterMaximumIDCount: filterMaximumIDCount,
+ isCustomQuery: false,
+ extraFields: extraFields,
+ fromTable: "",
+ }
+}
+
+// NewSchemaQueryFiltererWithStartingQuery creates a new SchemaQueryFilterer object for selecting
+// relationships, with a custom starting query. Unlike NewSchemaQueryFiltererForRelationshipsSelect,
+// this method will not auto-filter the columns retrieved from the database.
+func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer {
+ schema.debugValidate()
+
+ if filterMaximumIDCount == 0 {
+ filterMaximumIDCount = 100
+ log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100")
+ }
+
+ return SchemaQueryFilterer{
+ schema: schema,
+ queryBuilder: startingQuery,
+ filteringColumnTracker: map[string]columnTracker{},
+ filterMaximumIDCount: filterMaximumIDCount,
+ isCustomQuery: true,
+ extraFields: nil,
+ fromTable: "",
+ }
+}
+
+// WithAdditionalFilter returns the SchemaQueryFilterer with an additional filter applied to the query.
+func (sqf SchemaQueryFilterer) WithAdditionalFilter(filter func(original sq.SelectBuilder) sq.SelectBuilder) SchemaQueryFilterer {
+ sqf.queryBuilder = filter(sqf.queryBuilder)
+ return sqf
+}
+
+// WithFromTable returns the SchemaQueryFilterer with a custom FROM table.
+func (sqf SchemaQueryFilterer) WithFromTable(fromTable string) SchemaQueryFilterer {
+ sqf.fromTable = fromTable
+ return sqf
+}
+
+// WithFromSuffix returns the SchemaQueryFilterer with a suffix added to the FROM clause.
+func (sqf SchemaQueryFilterer) WithFromSuffix(fromSuffix string) SchemaQueryFilterer {
+ sqf.fromSuffix = fromSuffix
+ return sqf
+}
+
+// WithIndexingHint returns the SchemaQueryFilterer with an indexing hint applied to the query.
+func (sqf SchemaQueryFilterer) WithIndexingHint(indexingHint IndexingHint) SchemaQueryFilterer {
+ sqf.indexingHint = indexingHint
+ return sqf
+}
+
+func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder {
+ spiceerrors.DebugAssert(func() bool {
+ return sqf.isCustomQuery
+ }, "UnderlyingQueryBuilder should only be called on custom queries")
+ return sqf.queryBuilderWithMaybeExpirationFilter(false)
+}
+
+// queryBuilderWithMaybeExpirationFilter returns the query builder with the expiration filter applied, when necessary.
+// Note that this adds the clause to the existing builder.
+func (sqf SchemaQueryFilterer) queryBuilderWithMaybeExpirationFilter(skipExpiration bool) sq.SelectBuilder {
+ if sqf.schema.ExpirationDisabled || skipExpiration {
+ return sqf.queryBuilder
+ }
+
+ // Filter out any expired relationships.
+ return sqf.queryBuilder.Where(sq.Or{
+ sq.Eq{sqf.schema.ColExpiration: nil},
+ sq.Expr(sqf.schema.ColExpiration + " > " + sqf.schema.NowFunction + "()"),
+ })
+}
+
+func (sqf SchemaQueryFilterer) TupleOrder(order options.SortOrder) SchemaQueryFilterer {
+ switch order {
+ case options.ByResource:
+ sqf.queryBuilder = sqf.queryBuilder.OrderBy(sqf.schema.sortByResourceColumnOrderColumns()...)
+
+ case options.BySubject:
+ sqf.queryBuilder = sqf.queryBuilder.OrderBy(sqf.schema.sortBySubjectColumnOrderColumns()...)
+ }
+
+ return sqf
+}
+
+type nameAndValue struct {
+ name string
+ value string
+}
+
+func columnsAndValuesForSort(
+ order options.SortOrder,
+ schema SchemaInformation,
+ cursor options.Cursor,
+) ([]nameAndValue, error) {
+ var columnNames []string
+
+ switch order {
+ case options.ByResource:
+ columnNames = schema.sortByResourceColumnOrderColumns()
+
+ case options.BySubject:
+ columnNames = schema.sortBySubjectColumnOrderColumns()
+
+ default:
+ return nil, spiceerrors.MustBugf("invalid sort order %q", order)
+ }
+
+ nameAndValues := make([]nameAndValue, 0, len(columnNames))
+ for _, columnName := range columnNames {
+ switch columnName {
+ case schema.ColNamespace:
+ nameAndValues = append(nameAndValues, nameAndValue{
+ name: columnName,
+ value: cursor.Resource.ObjectType,
+ })
+
+ case schema.ColObjectID:
+ nameAndValues = append(nameAndValues, nameAndValue{
+ name: columnName,
+ value: cursor.Resource.ObjectID,
+ })
+
+ case schema.ColRelation:
+ nameAndValues = append(nameAndValues, nameAndValue{
+ name: columnName,
+ value: cursor.Resource.Relation,
+ })
+
+ case schema.ColUsersetNamespace:
+ nameAndValues = append(nameAndValues, nameAndValue{
+ name: columnName,
+ value: cursor.Subject.ObjectType,
+ })
+
+ case schema.ColUsersetObjectID:
+ nameAndValues = append(nameAndValues, nameAndValue{
+ name: columnName,
+ value: cursor.Subject.ObjectID,
+ })
+
+ case schema.ColUsersetRelation:
+ nameAndValues = append(nameAndValues, nameAndValue{
+ name: columnName,
+ value: cursor.Subject.Relation,
+ })
+
+ default:
+ return nil, spiceerrors.MustBugf("invalid column name %q", columnName)
+ }
+ }
+
+ return nameAndValues, nil
+}
+
+func (sqf SchemaQueryFilterer) MustAfter(cursor options.Cursor, order options.SortOrder) SchemaQueryFilterer {
+ updated, err := sqf.After(cursor, order)
+ if err != nil {
+ panic(err)
+ }
+ return updated
+}
+
+func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOrder) (SchemaQueryFilterer, error) {
+ spiceerrors.DebugAssertNotNil(cursor, "cursor cannot be nil")
+
+ // NOTE: The ordering of these columns can affect query performance, be aware when changing.
+ columnsAndValues, err := columnsAndValuesForSort(order, sqf.schema, cursor)
+ if err != nil {
+ return sqf, err
+ }
+
+ switch sqf.schema.PaginationFilterType {
+ case TupleComparison:
+ // For performance reasons, remove any column names that have static values in the query.
+ columnNames := make([]string, 0, len(columnsAndValues))
+ valueSlots := make([]any, 0, len(columnsAndValues))
+ comparisonSlotCount := 0
+
+ for _, cav := range columnsAndValues {
+ if !sqf.filteringColumnTracker.hasStaticValue(cav.name) {
+ columnNames = append(columnNames, cav.name)
+ valueSlots = append(valueSlots, cav.value)
+ comparisonSlotCount++
+ }
+ }
+
+ if comparisonSlotCount > 0 {
+ comparisonTuple := "(" + strings.Join(columnNames, ",") + ") > (" + strings.Repeat(",?", comparisonSlotCount)[1:] + ")"
+ sqf.queryBuilder = sqf.queryBuilder.Where(
+ comparisonTuple,
+ valueSlots...,
+ )
+ }
+
+ case ExpandedLogicComparison:
+ // For performance reasons, remove any column names that have static values in the query.
+ orClause := sq.Or{}
+
+ for index, cav := range columnsAndValues {
+ if !sqf.filteringColumnTracker.hasStaticValue(cav.name) {
+ andClause := sq.And{}
+ for _, previous := range columnsAndValues[0:index] {
+ if !sqf.filteringColumnTracker.hasStaticValue(previous.name) {
+ andClause = append(andClause, sq.Eq{previous.name: previous.value})
+ }
+ }
+
+ andClause = append(andClause, sq.Gt{cav.name: cav.value})
+ orClause = append(orClause, andClause)
+ }
+ }
+
+ if len(orClause) > 0 {
+ sqf.queryBuilder = sqf.queryBuilder.Where(orClause)
+ }
+ }
+
+ return sqf, nil
+}
+
+// FilterToResourceType returns a new SchemaQueryFilterer that is limited to resources of the
+// specified type.
+func (sqf SchemaQueryFilterer) FilterToResourceType(resourceType string) SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColNamespace: resourceType})
+ sqf.recordColumnValue(sqf.schema.ColNamespace, resourceType)
+ return sqf
+}
+
+func (sqf SchemaQueryFilterer) recordColumnValue(colName string, colValue string) {
+ existing, ok := sqf.filteringColumnTracker[colName]
+ if ok {
+ if existing.SingleValue != nil && *existing.SingleValue != colValue {
+ sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil}
+ }
+ } else {
+ sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: &colValue}
+ }
+}
+
+func (sqf SchemaQueryFilterer) recordVaryingColumnValue(colName string) {
+ sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil}
+}
+
+// FilterToResourceID returns a new SchemaQueryFilterer that is limited to resources with the
+// specified ID.
+func (sqf SchemaQueryFilterer) FilterToResourceID(objectID string) SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColObjectID: objectID})
+ sqf.recordColumnValue(sqf.schema.ColObjectID, objectID)
+ return sqf
+}
+
+func (sqf SchemaQueryFilterer) MustFilterToResourceIDs(resourceIds []string) SchemaQueryFilterer {
+ updated, err := sqf.FilterToResourceIDs(resourceIds)
+ if err != nil {
+ panic(err)
+ }
+ return updated
+}
+
+// FilterWithResourceIDPrefix returns new SchemaQueryFilterer that is limited to resources whose ID
+// starts with the specified prefix.
+func (sqf SchemaQueryFilterer) FilterWithResourceIDPrefix(prefix string) (SchemaQueryFilterer, error) {
+ if strings.Contains(prefix, "%") {
+ return sqf, spiceerrors.MustBugf("prefix cannot contain the percent sign")
+ }
+ if prefix == "" {
+ return sqf, spiceerrors.MustBugf("prefix cannot be empty")
+ }
+
+ prefix = strings.ReplaceAll(prefix, `\`, `\\`)
+ prefix = strings.ReplaceAll(prefix, "_", `\_`)
+
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Like{sqf.schema.ColObjectID: prefix + "%"})
+
+ // NOTE: we do *not* record the use of the resource ID column here, because it is not used
+ // statically and thus is necessary for sorting operations.
+ return sqf, nil
+}
+
+func (sqf SchemaQueryFilterer) MustFilterWithResourceIDPrefix(prefix string) SchemaQueryFilterer {
+ updated, err := sqf.FilterWithResourceIDPrefix(prefix)
+ if err != nil {
+ panic(err)
+ }
+ return updated
+}
+
+// FilterToResourceIDs returns a new SchemaQueryFilterer that is limited to resources with any of the
+// specified IDs.
+func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (SchemaQueryFilterer, error) {
+ spiceerrors.DebugAssert(func() bool {
+ return len(resourceIds) <= int(sqf.filterMaximumIDCount)
+ }, "cannot have more than %d resource IDs in a single filter", sqf.filterMaximumIDCount)
+
+ var builder strings.Builder
+ builder.WriteString(sqf.schema.ColObjectID)
+ builder.WriteString(" IN (")
+ args := make([]any, 0, len(resourceIds))
+
+ for _, resourceID := range resourceIds {
+ if len(resourceID) == 0 {
+ return sqf, spiceerrors.MustBugf("got empty resource ID")
+ }
+
+ args = append(args, resourceID)
+ sqf.recordColumnValue(sqf.schema.ColObjectID, resourceID)
+ }
+
+ builder.WriteString("?")
+ if len(resourceIds) > 1 {
+ builder.WriteString(strings.Repeat(",?", len(resourceIds)-1))
+ }
+ builder.WriteString(")")
+
+ sqf.queryBuilder = sqf.queryBuilder.Where(builder.String(), args...)
+ return sqf, nil
+}
+
+// FilterToRelation returns a new SchemaQueryFilterer that is limited to resources with the
+// specified relation.
+func (sqf SchemaQueryFilterer) FilterToRelation(relation string) SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColRelation: relation})
+ sqf.recordColumnValue(sqf.schema.ColRelation, relation)
+ return sqf
+}
+
+// MustFilterWithRelationshipsFilter returns a new SchemaQueryFilterer that is limited to resources with
+// resources that match the specified filter.
+func (sqf SchemaQueryFilterer) MustFilterWithRelationshipsFilter(filter datastore.RelationshipsFilter) SchemaQueryFilterer {
+ updated, err := sqf.FilterWithRelationshipsFilter(filter)
+ if err != nil {
+ panic(err)
+ }
+ return updated
+}
+
+func (sqf SchemaQueryFilterer) FilterWithRelationshipsFilter(filter datastore.RelationshipsFilter) (SchemaQueryFilterer, error) {
+ csqf := sqf
+
+ if filter.OptionalResourceType != "" {
+ csqf = csqf.FilterToResourceType(filter.OptionalResourceType)
+ }
+
+ if filter.OptionalResourceRelation != "" {
+ csqf = csqf.FilterToRelation(filter.OptionalResourceRelation)
+ }
+
+ if len(filter.OptionalResourceIds) > 0 && filter.OptionalResourceIDPrefix != "" {
+ return csqf, spiceerrors.MustBugf("cannot filter by both resource IDs and ID prefix")
+ }
+
+ if len(filter.OptionalResourceIds) > 0 {
+ usqf, err := csqf.FilterToResourceIDs(filter.OptionalResourceIds)
+ if err != nil {
+ return csqf, err
+ }
+ csqf = usqf
+ }
+
+ if len(filter.OptionalResourceIDPrefix) > 0 {
+ usqf, err := csqf.FilterWithResourceIDPrefix(filter.OptionalResourceIDPrefix)
+ if err != nil {
+ return csqf, err
+ }
+ csqf = usqf
+ }
+
+ if len(filter.OptionalSubjectsSelectors) > 0 {
+ usqf, err := csqf.FilterWithSubjectsSelectors(filter.OptionalSubjectsSelectors...)
+ if err != nil {
+ return csqf, err
+ }
+ csqf = usqf
+ }
+
+ switch filter.OptionalCaveatNameFilter.Option {
+ case datastore.CaveatFilterOptionHasMatchingCaveat:
+ spiceerrors.DebugAssert(func() bool {
+ return filter.OptionalCaveatNameFilter.CaveatName != ""
+ }, "caveat name must be set when using HasMatchingCaveat")
+ csqf = csqf.FilterWithCaveatName(filter.OptionalCaveatNameFilter.CaveatName)
+
+ case datastore.CaveatFilterOptionNoCaveat:
+ csqf = csqf.FilterWithNoCaveat()
+
+ case datastore.CaveatFilterOptionNone:
+ // No action needed, as this is the default behavior.
+
+ default:
+ return csqf, spiceerrors.MustBugf("unknown caveat filter option: %v", filter.OptionalCaveatNameFilter.Option)
+ }
+
+ if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionHasExpiration {
+ csqf.queryBuilder = csqf.queryBuilder.Where(sq.NotEq{csqf.schema.ColExpiration: nil})
+ spiceerrors.DebugAssert(func() bool { return !sqf.schema.ExpirationDisabled }, "expiration filter requested but schema does not support expiration")
+ } else if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionNoExpiration {
+ csqf.queryBuilder = csqf.queryBuilder.Where(sq.Eq{csqf.schema.ColExpiration: nil})
+ }
+
+ return csqf, nil
+}
+
+// MustFilterWithSubjectsSelectors returns a new SchemaQueryFilterer that is limited to resources with
+// subjects that match the specified selector(s).
+func (sqf SchemaQueryFilterer) MustFilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) SchemaQueryFilterer {
+ usqf, err := sqf.FilterWithSubjectsSelectors(selectors...)
+ if err != nil {
+ panic(err)
+ }
+ return usqf
+}
+
// FilterWithSubjectsSelectors returns a new SchemaQueryFilterer that is limited to resources with
// subjects that match the specified selector(s). Each selector becomes a single AND-ed clause,
// and all selectors are OR-ed together. Returns an error if any selector contains an empty
// subject ID.
func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) (SchemaQueryFilterer, error) {
	selectorsOrClause := sq.Or{}

	// If there is more than a single filter, record all the subjects as varying, as the subjects returned
	// can differ for each branch.
	// TODO(jschorr): Optimize this further where applicable.
	if len(selectors) > 1 {
		sqf.recordVaryingColumnValue(sqf.schema.ColUsersetNamespace)
		sqf.recordVaryingColumnValue(sqf.schema.ColUsersetObjectID)
		sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation)
	}

	for _, selector := range selectors {
		selectorClause := sq.And{}

		// Optional subject type restriction.
		if len(selector.OptionalSubjectType) > 0 {
			selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetNamespace: selector.OptionalSubjectType})
			sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, selector.OptionalSubjectType)
		}

		// Optional subject ID restriction, rendered as a hand-built `col IN (?,?,...)`
		// expression with one placeholder per ID.
		if len(selector.OptionalSubjectIds) > 0 {
			spiceerrors.DebugAssert(func() bool {
				return len(selector.OptionalSubjectIds) <= int(sqf.filterMaximumIDCount)
			}, "cannot have more than %d subject IDs in a single filter", sqf.filterMaximumIDCount)

			var builder strings.Builder
			builder.WriteString(sqf.schema.ColUsersetObjectID)
			builder.WriteString(" IN (")
			args := make([]any, 0, len(selector.OptionalSubjectIds))

			for _, subjectID := range selector.OptionalSubjectIds {
				if len(subjectID) == 0 {
					return sqf, spiceerrors.MustBugf("got empty subject ID")
				}

				args = append(args, subjectID)
				sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, subjectID)
			}

			// First placeholder, then ",?" repeated for each remaining ID.
			builder.WriteString("?")
			if len(selector.OptionalSubjectIds) > 1 {
				builder.WriteString(strings.Repeat(",?", len(selector.OptionalSubjectIds)-1))
			}

			builder.WriteString(")")
			selectorClause = append(selectorClause, sq.Expr(builder.String(), args...))
		}

		// Optional relation restriction.
		if !selector.RelationFilter.IsEmpty() {
			if selector.RelationFilter.OnlyNonEllipsisRelations {
				// Any relation other than the ellipsis matches, so the column value varies.
				selectorClause = append(selectorClause, sq.NotEq{sqf.schema.ColUsersetRelation: datastore.Ellipsis})
				sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation)
			} else {
				relations := make([]string, 0, 2)
				if selector.RelationFilter.IncludeEllipsisRelation {
					relations = append(relations, datastore.Ellipsis)
				}

				if selector.RelationFilter.NonEllipsisRelation != "" {
					relations = append(relations, selector.RelationFilter.NonEllipsisRelation)
				}

				if len(relations) == 1 {
					// A single relation can be recorded as a static column value.
					relName := relations[0]
					selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetRelation: relName})
					sqf.recordColumnValue(sqf.schema.ColUsersetRelation, relName)
				} else {
					orClause := sq.Or{}
					for _, relationName := range relations {
						dsRelationName := stringz.DefaultEmpty(relationName, datastore.Ellipsis)
						orClause = append(orClause, sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName})
						sqf.recordColumnValue(sqf.schema.ColUsersetRelation, dsRelationName)
					}

					selectorClause = append(selectorClause, orClause)
				}
			}
		}

		selectorsOrClause = append(selectorsOrClause, selectorClause)
	}

	sqf.queryBuilder = sqf.queryBuilder.Where(selectorsOrClause)
	return sqf, nil
}
+
+// FilterToSubjectFilter returns a new SchemaQueryFilterer that is limited to resources with
+// subjects that match the specified filter.
+func (sqf SchemaQueryFilterer) FilterToSubjectFilter(filter *v1.SubjectFilter) SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetNamespace: filter.SubjectType})
+ sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, filter.SubjectType)
+
+ if filter.OptionalSubjectId != "" {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetObjectID: filter.OptionalSubjectId})
+ sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, filter.OptionalSubjectId)
+ }
+
+ if filter.OptionalRelation != nil {
+ dsRelationName := stringz.DefaultEmpty(filter.OptionalRelation.Relation, datastore.Ellipsis)
+
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName})
+ sqf.recordColumnValue(sqf.schema.ColUsersetRelation, datastore.Ellipsis)
+ }
+
+ return sqf
+}
+
+// FilterWithCaveatName returns a new SchemaQueryFilterer that is limited to resources with the
+// specified caveat name.
+func (sqf SchemaQueryFilterer) FilterWithCaveatName(caveatName string) SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColCaveatName: caveatName})
+ sqf.recordColumnValue(sqf.schema.ColCaveatName, caveatName)
+ return sqf
+}
+
+// FilterWithNoCaveat returns a new SchemaQueryFilterer that is limited to resources with no caveat.
+func (sqf SchemaQueryFilterer) FilterWithNoCaveat() SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(
+ sq.Or{
+ sq.Eq{sqf.schema.ColCaveatName: nil},
+ sq.Eq{sqf.schema.ColCaveatName: ""},
+ })
+ sqf.recordVaryingColumnValue(sqf.schema.ColCaveatName)
+ return sqf
+}
+
// limit returns a new SchemaQueryFilterer which is limited to the specified number of results.
func (sqf SchemaQueryFilterer) limit(limit uint64) SchemaQueryFilterer {
	sqf.queryBuilder = sqf.queryBuilder.Limit(limit)
	return sqf
}
+
// QueryRelationshipsExecutor is a relationships query runner shared by SQL implementations of the datastore.
type QueryRelationshipsExecutor struct {
	// Executor is the datastore-specific function invoked with the fully-rendered query builder.
	Executor ExecuteReadRelsQueryFunc
}

// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query.
type ExecuteReadRelsQueryFunc func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error)
+
+// ExecuteQuery executes the query.
+func (exc QueryRelationshipsExecutor) ExecuteQuery(
+ ctx context.Context,
+ query SchemaQueryFilterer,
+ opts ...options.QueryOptionsOption,
+) (datastore.RelationshipIterator, error) {
+ if query.isCustomQuery {
+ return nil, spiceerrors.MustBugf("ExecuteQuery should not be called on custom queries")
+ }
+
+ queryOpts := options.NewQueryOptionsWithOptions(opts...)
+
+ // Add sort order.
+ query = query.TupleOrder(queryOpts.Sort)
+
+ // Add cursor.
+ if queryOpts.After != nil {
+ if queryOpts.Sort == options.Unsorted {
+ return nil, datastore.ErrCursorsWithoutSorting
+ }
+
+ q, err := query.After(queryOpts.After, queryOpts.Sort)
+ if err != nil {
+ return nil, err
+ }
+ query = q
+ }
+
+ // Add limit.
+ var limit uint64
+ // NOTE: we use a uint here because it lines up with the
+ // assignments in this function, but we set it to MaxInt64
+ // because that's the biggest value that postgres and friends
+ // treat as valid.
+ limit = math.MaxInt64
+ if queryOpts.Limit != nil {
+ limit = *queryOpts.Limit
+ }
+
+ if limit < math.MaxInt64 {
+ query = query.limit(limit)
+ }
+
+ // Add FROM clause.
+ from := query.schema.RelationshipTableName
+ if query.fromTable != "" {
+ from = query.fromTable
+ }
+
+ // Add index hints, if any.
+ if query.indexingHint != nil {
+ // Check for a SQL prefix (pg_hint_plan).
+ sqlPrefix, err := query.indexingHint.SQLPrefix()
+ if err != nil {
+ return nil, fmt.Errorf("error getting SQL prefix for indexing hint: %w", err)
+ }
+
+ if sqlPrefix != "" {
+ query.queryBuilder = query.queryBuilder.Prefix(sqlPrefix)
+ }
+
+ // Check for an adjusting FROM table name (CRDB).
+ fromTableName, err := query.indexingHint.FromTable(from)
+ if err != nil {
+ return nil, fmt.Errorf("error getting FROM table name for indexing hint: %w", err)
+ }
+ from = fromTableName
+
+ // Check for a SQL suffix (MySQL, Spanner).
+ fromSuffix, err := query.indexingHint.FromSQLSuffix()
+ if err != nil {
+ return nil, fmt.Errorf("error getting SQL suffix for indexing hint: %w", err)
+ }
+
+ if fromSuffix != "" {
+ from += " " + fromSuffix
+ }
+ }
+
+ if query.fromSuffix != "" {
+ from += " " + query.fromSuffix
+ }
+
+ query.queryBuilder = query.queryBuilder.From(from)
+
+ builder := RelationshipsQueryBuilder{
+ Schema: query.schema,
+ SkipCaveats: queryOpts.SkipCaveats,
+ SkipExpiration: queryOpts.SkipExpiration,
+ SQLCheckAssertionForTest: queryOpts.SQLCheckAssertionForTest,
+ SQLExplainCallbackForTest: queryOpts.SQLExplainCallbackForTest,
+ filteringValues: query.filteringColumnTracker,
+ queryShape: queryOpts.QueryShape,
+ baseQueryBuilder: query,
+ }
+
+ return exc.Executor(ctx, builder)
+}
+
// RelationshipsQueryBuilder is a builder for producing the SQL and arguments necessary for reading
// relationships.
type RelationshipsQueryBuilder struct {
	// Schema describes the relationship table and column layout to query.
	Schema SchemaInformation
	// SkipCaveats elides caveat columns from the SELECT (unless column optimization is disabled).
	SkipCaveats bool
	// SkipExpiration elides the expiration column from the SELECT.
	SkipExpiration bool

	// filteringValues tracks, per column, whether filtering pinned it to a single static value.
	filteringValues columnTrackerMap
	// baseQueryBuilder is the fully-filtered query from which the SELECT is rendered.
	baseQueryBuilder SchemaQueryFilterer
	// SQLCheckAssertionForTest, if set, is invoked with the rendered SQL. For test use only.
	SQLCheckAssertionForTest options.SQLCheckAssertionForTest
	// SQLExplainCallbackForTest is a test-only callback; presumably invoked with query
	// explain output by the executor — confirm against the datastore implementations.
	SQLExplainCallbackForTest options.SQLExplainCallbackForTest
	// queryShape identifies the shape of the query being executed.
	queryShape queryshape.Shape
}
+
+// withCaveats returns true if caveats should be included in the query.
+func (b RelationshipsQueryBuilder) withCaveats() bool {
+ return !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone
+}
+
+// withExpiration returns true if expiration should be included in the query.
+func (b RelationshipsQueryBuilder) withExpiration() bool {
+ return !b.SkipExpiration && !b.Schema.ExpirationDisabled
+}
+
// integrityEnabled returns true if integrity columns (key ID, hash, timestamp) should be
// included in the query, as dictated by the schema.
func (b RelationshipsQueryBuilder) integrityEnabled() bool {
	return b.Schema.IntegrityEnabled
}
+
+// columnCount returns the number of columns that will be selected in the query.
+func (b RelationshipsQueryBuilder) columnCount() int {
+ columnCount := relationshipStandardColumnCount
+ if b.withCaveats() {
+ columnCount += relationshipCaveatColumnCount
+ }
+ if b.withExpiration() {
+ columnCount += relationshipExpirationColumnCount
+ }
+ if b.integrityEnabled() {
+ columnCount += relationshipIntegrityColumnCount
+ }
+ return columnCount
+}
+
// SelectSQL returns the SQL and arguments necessary for reading relationships.
func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) {
	// Set the column names to select; checkColumn elides any column that filtering has
	// pinned to a single static value (when column optimization is enabled).
	columnNamesToSelect := make([]string, 0, b.columnCount())

	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColNamespace)
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColObjectID)
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColRelation)
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetNamespace)
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetObjectID)
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetRelation)

	if b.withCaveats() {
		columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColCaveatName, b.Schema.ColCaveatContext)
	}

	if b.withExpiration() {
		columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration)
	}

	if b.integrityEnabled() {
		columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColIntegrityKeyID, b.Schema.ColIntegrityHash, b.Schema.ColIntegrityTimestamp)
	}

	// If every column was elided, select a constant so the statement stays valid.
	if len(columnNamesToSelect) == 0 {
		columnNamesToSelect = append(columnNamesToSelect, "1")
	}

	sqlBuilder := b.baseQueryBuilder.queryBuilderWithMaybeExpirationFilter(b.SkipExpiration)
	sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...)

	sql, args, err := sqlBuilder.ToSql()
	if err != nil {
		return "", nil, err
	}

	if b.SQLCheckAssertionForTest != nil {
		b.SQLCheckAssertionForTest(sql)
	}

	return sql, args, nil
}
+
+// FilteringValuesForTesting returns the filtering values. For test use only.
+func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]columnTracker {
+ return maps.Clone(b.filteringValues)
+}
+
+func (b RelationshipsQueryBuilder) checkColumn(columns []string, colName string) []string {
+ if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone {
+ return append(columns, colName)
+ }
+
+ if !b.filteringValues.hasStaticValue(colName) {
+ return append(columns, colName)
+ }
+
+ return columns
+}
+
+func (b RelationshipsQueryBuilder) staticValueOrAddColumnForSelect(colsToSelect []any, colName string, field *string) []any {
+ if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone {
+ // If column optimization is disabled, always add the column to the list of columns to select.
+ colsToSelect = append(colsToSelect, field)
+ return colsToSelect
+ }
+
+ // If the value is static, set the field to it and return.
+ if found, ok := b.filteringValues[colName]; ok && found.SingleValue != nil {
+ *field = *found.SingleValue
+ return colsToSelect
+ }
+
+ // Otherwise, add the column to the list of columns to select, as the value is not static.
+ colsToSelect = append(colsToSelect, field)
+ return colsToSelect
+}
+
+// ColumnsToSelect returns the columns to select for a given query. The columns provided are
+// the references to the slots in which the values for each relationship will be placed.
+func ColumnsToSelect[CN any, CC any, EC any](
+ b RelationshipsQueryBuilder,
+ resourceObjectType *string,
+ resourceObjectID *string,
+ resourceRelation *string,
+ subjectObjectType *string,
+ subjectObjectID *string,
+ subjectRelation *string,
+ caveatName *CN,
+ caveatCtx *CC,
+ expiration EC,
+
+ integrityKeyID *string,
+ integrityHash *[]byte,
+ timestamp *time.Time,
+) ([]any, error) {
+ colsToSelect := make([]any, 0, b.columnCount())
+
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColNamespace, resourceObjectType)
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColObjectID, resourceObjectID)
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColRelation, resourceRelation)
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetNamespace, subjectObjectType)
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetObjectID, subjectObjectID)
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetRelation, subjectRelation)
+
+ if b.withCaveats() {
+ colsToSelect = append(colsToSelect, caveatName, caveatCtx)
+ }
+
+ if b.withExpiration() {
+ colsToSelect = append(colsToSelect, expiration)
+ }
+
+ if b.Schema.IntegrityEnabled {
+ colsToSelect = append(colsToSelect, integrityKeyID, integrityHash, timestamp)
+ }
+
+ if len(colsToSelect) == 0 {
+ var unused int64
+ colsToSelect = append(colsToSelect, &unused)
+ }
+
+ return colsToSelect, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go
new file mode 100644
index 0000000..fa23efc
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go
@@ -0,0 +1,31 @@
+package common
+
+import (
+ "context"
+ "errors"
+ "strings"
+)
+
+// IsCancellationError determines if an error returned by pgx has been caused by context cancellation.
+func IsCancellationError(err error) bool {
+ if errors.Is(err, context.Canceled) ||
+ errors.Is(err, context.DeadlineExceeded) ||
+ err.Error() == "conn closed" { // conns are sometimes closed async upon cancellation
+ return true
+ }
+ return false
+}
+
+// IsResettableError returns whether the given error is a resettable error.
+func IsResettableError(err error) bool {
+ // detect when an error is likely due to a node taken out of service
+ if strings.Contains(err.Error(), "broken pipe") ||
+ strings.Contains(err.Error(), "unexpected EOF") ||
+ strings.Contains(err.Error(), "conn closed") ||
+ strings.Contains(err.Error(), "connection refused") ||
+ strings.Contains(err.Error(), "connection reset by peer") {
+ return true
+ }
+
+ return false
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go
new file mode 100644
index 0000000..be665ed
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go
@@ -0,0 +1,19 @@
+package common
+
import (
	"errors"
	"fmt"
	"net/url"
)
+
// MetricsIDFromURL extracts the metrics ID from a given datastore URL: the host plus path
// portion, which identifies the target datastore without exposing credentials or query
// parameters. Returns an error for an empty or unparseable URL.
func MetricsIDFromURL(dsURL string) (string, error) {
	if dsURL == "" {
		return "", errors.New("datastore URL is empty")
	}

	u, err := url.Parse(dsURL)
	if err != nil {
		// Wrap the parse error so callers can see why the URL was rejected
		// (previously the cause was silently discarded).
		return "", fmt.Errorf("could not parse datastore URL: %w", err)
	}
	return u.Host + u.Path, nil
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go
new file mode 100644
index 0000000..2caa57a
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go
@@ -0,0 +1,276 @@
+// Code generated by github.com/ecordell/optgen. DO NOT EDIT.
+package common
+
+import (
+ squirrel "github.com/Masterminds/squirrel"
+ defaults "github.com/creasty/defaults"
+ helpers "github.com/ecordell/optgen/helpers"
+)
+
+type SchemaInformationOption func(s *SchemaInformation)
+
+// NewSchemaInformationWithOptions creates a new SchemaInformation with the passed in options set
+func NewSchemaInformationWithOptions(opts ...SchemaInformationOption) *SchemaInformation {
+ s := &SchemaInformation{}
+ for _, o := range opts {
+ o(s)
+ }
+ return s
+}
+
+// NewSchemaInformationWithOptionsAndDefaults creates a new SchemaInformation with the passed in options set starting from the defaults
+func NewSchemaInformationWithOptionsAndDefaults(opts ...SchemaInformationOption) *SchemaInformation {
+ s := &SchemaInformation{}
+ defaults.MustSet(s)
+ for _, o := range opts {
+ o(s)
+ }
+ return s
+}
+
+// ToOption returns a new SchemaInformationOption that sets the values from the passed in SchemaInformation
+func (s *SchemaInformation) ToOption() SchemaInformationOption {
+ return func(to *SchemaInformation) {
+ to.RelationshipTableName = s.RelationshipTableName
+ to.ColNamespace = s.ColNamespace
+ to.ColObjectID = s.ColObjectID
+ to.ColRelation = s.ColRelation
+ to.ColUsersetNamespace = s.ColUsersetNamespace
+ to.ColUsersetObjectID = s.ColUsersetObjectID
+ to.ColUsersetRelation = s.ColUsersetRelation
+ to.ColCaveatName = s.ColCaveatName
+ to.ColCaveatContext = s.ColCaveatContext
+ to.ColExpiration = s.ColExpiration
+ to.ColIntegrityKeyID = s.ColIntegrityKeyID
+ to.ColIntegrityHash = s.ColIntegrityHash
+ to.ColIntegrityTimestamp = s.ColIntegrityTimestamp
+ to.Indexes = s.Indexes
+ to.PaginationFilterType = s.PaginationFilterType
+ to.PlaceholderFormat = s.PlaceholderFormat
+ to.NowFunction = s.NowFunction
+ to.ColumnOptimization = s.ColumnOptimization
+ to.IntegrityEnabled = s.IntegrityEnabled
+ to.ExpirationDisabled = s.ExpirationDisabled
+ to.SortByResourceColumnOrder = s.SortByResourceColumnOrder
+ to.SortBySubjectColumnOrder = s.SortBySubjectColumnOrder
+ }
+}
+
+// DebugMap returns a map form of SchemaInformation for debugging
+func (s SchemaInformation) DebugMap() map[string]any {
+ debugMap := map[string]any{}
+ debugMap["RelationshipTableName"] = helpers.DebugValue(s.RelationshipTableName, false)
+ debugMap["ColNamespace"] = helpers.DebugValue(s.ColNamespace, false)
+ debugMap["ColObjectID"] = helpers.DebugValue(s.ColObjectID, false)
+ debugMap["ColRelation"] = helpers.DebugValue(s.ColRelation, false)
+ debugMap["ColUsersetNamespace"] = helpers.DebugValue(s.ColUsersetNamespace, false)
+ debugMap["ColUsersetObjectID"] = helpers.DebugValue(s.ColUsersetObjectID, false)
+ debugMap["ColUsersetRelation"] = helpers.DebugValue(s.ColUsersetRelation, false)
+ debugMap["ColCaveatName"] = helpers.DebugValue(s.ColCaveatName, false)
+ debugMap["ColCaveatContext"] = helpers.DebugValue(s.ColCaveatContext, false)
+ debugMap["ColExpiration"] = helpers.DebugValue(s.ColExpiration, false)
+ debugMap["ColIntegrityKeyID"] = helpers.DebugValue(s.ColIntegrityKeyID, false)
+ debugMap["ColIntegrityHash"] = helpers.DebugValue(s.ColIntegrityHash, false)
+ debugMap["ColIntegrityTimestamp"] = helpers.DebugValue(s.ColIntegrityTimestamp, false)
+ debugMap["Indexes"] = helpers.DebugValue(s.Indexes, false)
+ debugMap["PaginationFilterType"] = helpers.DebugValue(s.PaginationFilterType, false)
+ debugMap["PlaceholderFormat"] = helpers.DebugValue(s.PlaceholderFormat, false)
+ debugMap["NowFunction"] = helpers.DebugValue(s.NowFunction, false)
+ debugMap["ColumnOptimization"] = helpers.DebugValue(s.ColumnOptimization, false)
+ debugMap["IntegrityEnabled"] = helpers.DebugValue(s.IntegrityEnabled, false)
+ debugMap["ExpirationDisabled"] = helpers.DebugValue(s.ExpirationDisabled, false)
+ debugMap["SortByResourceColumnOrder"] = helpers.DebugValue(s.SortByResourceColumnOrder, false)
+ debugMap["SortBySubjectColumnOrder"] = helpers.DebugValue(s.SortBySubjectColumnOrder, false)
+ return debugMap
+}
+
+// SchemaInformationWithOptions configures an existing SchemaInformation with the passed in options set
+func SchemaInformationWithOptions(s *SchemaInformation, opts ...SchemaInformationOption) *SchemaInformation {
+ for _, o := range opts {
+ o(s)
+ }
+ return s
+}
+
+// WithOptions configures the receiver SchemaInformation with the passed in options set
+func (s *SchemaInformation) WithOptions(opts ...SchemaInformationOption) *SchemaInformation {
+ for _, o := range opts {
+ o(s)
+ }
+ return s
+}
+
+// WithRelationshipTableName returns an option that can set RelationshipTableName on a SchemaInformation
+func WithRelationshipTableName(relationshipTableName string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.RelationshipTableName = relationshipTableName
+ }
+}
+
+// WithColNamespace returns an option that can set ColNamespace on a SchemaInformation
+func WithColNamespace(colNamespace string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColNamespace = colNamespace
+ }
+}
+
+// WithColObjectID returns an option that can set ColObjectID on a SchemaInformation
+func WithColObjectID(colObjectID string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColObjectID = colObjectID
+ }
+}
+
+// WithColRelation returns an option that can set ColRelation on a SchemaInformation
+func WithColRelation(colRelation string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColRelation = colRelation
+ }
+}
+
+// WithColUsersetNamespace returns an option that can set ColUsersetNamespace on a SchemaInformation
+func WithColUsersetNamespace(colUsersetNamespace string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColUsersetNamespace = colUsersetNamespace
+ }
+}
+
+// WithColUsersetObjectID returns an option that can set ColUsersetObjectID on a SchemaInformation
+func WithColUsersetObjectID(colUsersetObjectID string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColUsersetObjectID = colUsersetObjectID
+ }
+}
+
+// WithColUsersetRelation returns an option that can set ColUsersetRelation on a SchemaInformation
+func WithColUsersetRelation(colUsersetRelation string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColUsersetRelation = colUsersetRelation
+ }
+}
+
+// WithColCaveatName returns an option that can set ColCaveatName on a SchemaInformation
+func WithColCaveatName(colCaveatName string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColCaveatName = colCaveatName
+ }
+}
+
+// WithColCaveatContext returns an option that can set ColCaveatContext on a SchemaInformation
+func WithColCaveatContext(colCaveatContext string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColCaveatContext = colCaveatContext
+ }
+}
+
+// WithColExpiration returns an option that can set ColExpiration on a SchemaInformation
+func WithColExpiration(colExpiration string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColExpiration = colExpiration
+ }
+}
+
+// WithColIntegrityKeyID returns an option that can set ColIntegrityKeyID on a SchemaInformation
+func WithColIntegrityKeyID(colIntegrityKeyID string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColIntegrityKeyID = colIntegrityKeyID
+ }
+}
+
+// WithColIntegrityHash returns an option that can set ColIntegrityHash on a SchemaInformation
+func WithColIntegrityHash(colIntegrityHash string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColIntegrityHash = colIntegrityHash
+ }
+}
+
+// WithColIntegrityTimestamp returns an option that can set ColIntegrityTimestamp on a SchemaInformation
+func WithColIntegrityTimestamp(colIntegrityTimestamp string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColIntegrityTimestamp = colIntegrityTimestamp
+ }
+}
+
+// WithIndexes returns an option that can append Indexess to SchemaInformation.Indexes
+func WithIndexes(indexes IndexDefinition) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.Indexes = append(s.Indexes, indexes)
+ }
+}
+
+// SetIndexes returns an option that can set Indexes on a SchemaInformation
+func SetIndexes(indexes []IndexDefinition) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.Indexes = indexes
+ }
+}
+
+// WithPaginationFilterType returns an option that can set PaginationFilterType on a SchemaInformation
+func WithPaginationFilterType(paginationFilterType PaginationFilterType) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.PaginationFilterType = paginationFilterType
+ }
+}
+
+// WithPlaceholderFormat returns an option that can set PlaceholderFormat on a SchemaInformation
+func WithPlaceholderFormat(placeholderFormat squirrel.PlaceholderFormat) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.PlaceholderFormat = placeholderFormat
+ }
+}
+
+// WithNowFunction returns an option that can set NowFunction on a SchemaInformation
+func WithNowFunction(nowFunction string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.NowFunction = nowFunction
+ }
+}
+
+// WithColumnOptimization returns an option that can set ColumnOptimization on a SchemaInformation
+func WithColumnOptimization(columnOptimization ColumnOptimizationOption) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ColumnOptimization = columnOptimization
+ }
+}
+
+// WithIntegrityEnabled returns an option that can set IntegrityEnabled on a SchemaInformation
+func WithIntegrityEnabled(integrityEnabled bool) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.IntegrityEnabled = integrityEnabled
+ }
+}
+
+// WithExpirationDisabled returns an option that can set ExpirationDisabled on a SchemaInformation
+func WithExpirationDisabled(expirationDisabled bool) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.ExpirationDisabled = expirationDisabled
+ }
+}
+
+// WithSortByResourceColumnOrder returns an option that can append SortByResourceColumnOrders to SchemaInformation.SortByResourceColumnOrder
+func WithSortByResourceColumnOrder(sortByResourceColumnOrder string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.SortByResourceColumnOrder = append(s.SortByResourceColumnOrder, sortByResourceColumnOrder)
+ }
+}
+
+// SetSortByResourceColumnOrder returns an option that can set SortByResourceColumnOrder on a SchemaInformation
+func SetSortByResourceColumnOrder(sortByResourceColumnOrder []string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.SortByResourceColumnOrder = sortByResourceColumnOrder
+ }
+}
+
+// WithSortBySubjectColumnOrder returns an option that can append SortBySubjectColumnOrders to SchemaInformation.SortBySubjectColumnOrder
+func WithSortBySubjectColumnOrder(sortBySubjectColumnOrder string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.SortBySubjectColumnOrder = append(s.SortBySubjectColumnOrder, sortBySubjectColumnOrder)
+ }
+}
+
+// SetSortBySubjectColumnOrder returns an option that can set SortBySubjectColumnOrder on a SchemaInformation
+func SetSortBySubjectColumnOrder(sortBySubjectColumnOrder []string) SchemaInformationOption {
+ return func(s *SchemaInformation) {
+ s.SortBySubjectColumnOrder = sortBySubjectColumnOrder
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md
new file mode 100644
index 0000000..de32e34
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md
@@ -0,0 +1,23 @@
+# MemDB Datastore Implementation
+
+The `memdb` datastore implementation is based on Hashicorp's [go-memdb library](https://github.com/hashicorp/go-memdb).
+Its implementation most closely mimics that of `spanner`, or `crdb`, where there is a single immutable datastore that supports querying at any point in time.
+The `memdb` datastore is used for validating and rapidly iterating on concepts from consumers of other datastores.
+It is 100% compliant with the datastore acceptance test suite and it should be possible to use it in place of any other datastore for development purposes.
+Differences between the `memdb` datastore and other implementations that manifest themselves as differences visible to the caller should be reported as bugs.
+
+**The memdb datastore can NOT be used in a production setting!**
+
+## Implementation Caveats
+
+### No Garbage Collection
+
+This implementation of the datastore has no garbage collection, meaning that memory usage will grow monotonically with mutations.
+
+### No Durable Storage
+
+The `memdb` datastore, as its name implies, stores information entirely in memory, and therefore will lose all data when the host process terminates.
+
+### Cannot be used for multi-node dispatch
+
+If you attempt to run SpiceDB with multi-node dispatch enabled while using the memory datastore, each node will operate on its own independent copy of the data, and the nodes will return inconsistent results.
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go
new file mode 100644
index 0000000..2b4baca
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go
@@ -0,0 +1,156 @@
+package memdb
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/hashicorp/go-memdb"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+)
+
+const tableCaveats = "caveats"
+
+type caveat struct { // row type for the "caveats" table; one entry per stored caveat definition.
+	name       string             // primary key (looked up via indexID).
+	definition []byte             // core.CaveatDefinition serialized with MarshalVT.
+	revision   datastore.Revision // revision at which this caveat was last written.
+}
+
+func (c *caveat) Unwrap() (*core.CaveatDefinition, error) { // Unwrap deserializes the stored bytes back into a CaveatDefinition.
+	definition := core.CaveatDefinition{}
+	err := definition.UnmarshalVT(c.definition)
+	return &definition, err // on error the definition may be partially populated; callers must check err first.
+}
+
+func (r *memdbReader) ReadCaveatByName(_ context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { // ReadCaveatByName returns the named caveat and the revision at which it was written.
+	r.mustLock() // panics on concurrent use of the same reader; see mustLock.
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, datastore.NoRevision, err
+	}
+	return r.readUnwrappedCaveatByName(tx, name)
+}
+
+func (r *memdbReader) readCaveatByName(tx *memdb.Txn, name string) (*caveat, datastore.Revision, error) { // readCaveatByName fetches the raw caveat row; caller must already hold the lock.
+	found, err := tx.First(tableCaveats, indexID, name)
+	if err != nil {
+		return nil, datastore.NoRevision, err
+	}
+	if found == nil { // no row: surface the datastore-standard not-found error.
+		return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name)
+	}
+	cvt := found.(*caveat) // safe: only *caveat values are inserted into tableCaveats.
+	return cvt, cvt.revision, nil
+}
+
+func (r *memdbReader) readUnwrappedCaveatByName(tx *memdb.Txn, name string) (*core.CaveatDefinition, datastore.Revision, error) { // like readCaveatByName but deserializes the definition; caller must hold the lock.
+	c, rev, err := r.readCaveatByName(tx, name)
+	if err != nil {
+		return nil, datastore.NoRevision, err
+	}
+	unwrapped, err := c.Unwrap()
+	if err != nil {
+		return nil, datastore.NoRevision, err
+	}
+	return unwrapped, rev, nil
+}
+
+func (r *memdbReader) ListAllCaveats(_ context.Context) ([]datastore.RevisionedCaveat, error) { // ListAllCaveats returns every stored caveat with its last-written revision.
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	var caveats []datastore.RevisionedCaveat
+	it, err := tx.LowerBound(tableCaveats, indexID) // iterate the whole table in index (name) order.
+	if err != nil {
+		return nil, err
+	}
+
+	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
+		rawCaveat := foundRaw.(*caveat)
+		definition, err := rawCaveat.Unwrap() // deserialize the stored bytes for each row.
+		if err != nil {
+			return nil, err
+		}
+		caveats = append(caveats, datastore.RevisionedCaveat{
+			Definition:          definition,
+			LastWrittenRevision: rawCaveat.revision,
+		})
+	}
+
+	return caveats, nil
+}
+
+func (r *memdbReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { // LookupCaveatsWithNames returns only caveats whose names appear in caveatNames.
+	allCaveats, err := r.ListAllCaveats(ctx) // full scan then filter; acceptable for this non-production datastore.
+	if err != nil {
+		return nil, err
+	}
+
+	allowedCaveatNames := mapz.NewSet[string]()
+	allowedCaveatNames.Extend(caveatNames)
+
+	toReturn := make([]datastore.RevisionedCaveat, 0, len(caveatNames))
+	for _, caveat := range allCaveats {
+		if allowedCaveatNames.Has(caveat.Definition.Name) {
+			toReturn = append(toReturn, caveat)
+		}
+	}
+	return toReturn, nil // NOTE: an empty caveatNames slice yields an empty result, not all caveats.
+}
+
+func (rwt *memdbReadWriteTx) WriteCaveats(_ context.Context, caveats []*core.CaveatDefinition) error { // WriteCaveats upserts the given caveat definitions at this transaction's revision.
+	rwt.mustLock()
+	defer rwt.Unlock()
+	tx, err := rwt.txSource()
+	if err != nil {
+		return err
+	}
+	return rwt.writeCaveat(tx, caveats)
+}
+
+func (rwt *memdbReadWriteTx) writeCaveat(tx *memdb.Txn, caveats []*core.CaveatDefinition) error { // writeCaveat inserts each definition, rejecting duplicate names within a single call; caller holds the lock.
+	caveatNames := mapz.NewSet[string]()
+	for _, coreCaveat := range caveats {
+		if !caveatNames.Add(coreCaveat.Name) { // Add returns false if the name was already present in this batch.
+			return fmt.Errorf("duplicate caveat %s", coreCaveat.Name)
+		}
+		marshalled, err := coreCaveat.MarshalVT()
+		if err != nil {
+			return err
+		}
+		c := caveat{
+			name:       coreCaveat.Name,
+			definition: marshalled,
+			revision:   rwt.newRevision, // all caveats in this txn share the txn's revision.
+		}
+		if err := tx.Insert(tableCaveats, &c); err != nil { // Insert replaces any existing row with the same name (upsert).
+			return err
+		}
+	}
+	return nil
+}
+
+func (rwt *memdbReadWriteTx) DeleteCaveats(_ context.Context, names []string) error { // DeleteCaveats removes the named caveats from the current write transaction.
+	rwt.mustLock()
+	defer rwt.Unlock()
+	tx, err := rwt.txSource()
+	if err != nil {
+		return err
+	}
+	for _, name := range names {
+		if err := tx.Delete(tableCaveats, caveat{name: name}); err != nil { // only the primary-key field is needed to identify the row.
+			return err
+		}
+	}
+	return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go
new file mode 100644
index 0000000..0ef4b8b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go
@@ -0,0 +1,37 @@
+package memdb
+
+import (
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+// SerializationMaxRetriesReachedError occurs when a write request has reached its maximum number
+// of retries due to serialization errors.
+type SerializationMaxRetriesReachedError struct {
+	error // embeds the base error so Error()/Unwrap behave as the wrapped error.
+}
+
+// NewSerializationMaxRetriesReachedErr constructs a new max retries reached error wrapping baseErr.
+func NewSerializationMaxRetriesReachedErr(baseErr error) error {
+	return SerializationMaxRetriesReachedError{
+		error: baseErr,
+	}
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error, mapping it to DeadlineExceeded
+// so gRPC clients see a retriable/timeout-style failure rather than an internal error.
+func (err SerializationMaxRetriesReachedError) GRPCStatus() *status.Status {
+	return spiceerrors.WithCodeAndDetails(
+		err,
+		codes.DeadlineExceeded,
+		spiceerrors.ForReason(
+			v1.ErrorReason_ERROR_REASON_UNSPECIFIED,
+			map[string]string{
+				"details": "too many updates were made to the in-memory datastore at once; this datastore has limited write throughput capability",
+			},
+		),
+	)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go
new file mode 100644
index 0000000..61eba84
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go
@@ -0,0 +1,386 @@
+package memdb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/authzed/spicedb/internal/datastore/common"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+
+ "github.com/google/uuid"
+ "github.com/hashicorp/go-memdb"
+
+ "github.com/authzed/spicedb/internal/datastore/revisions"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ corev1 "github.com/authzed/spicedb/pkg/proto/core/v1"
+)
+
+const (
+	Engine                   = "memory" // engine identifier used to select this datastore in configuration.
+	defaultWatchBufferLength = 128      // used when the caller passes a watchBufferLength of 0.
+	numAttempts              = 10       // max ReadWriteTx retries on serialization conflicts.
+)
+
+var (
+	ErrMemDBIsClosed = errors.New("datastore is closed") // returned by operations after Close has been called.
+	ErrSerialization = errors.New("serialization error") // write/write conflict; ReadWriteTx retries on this.
+)
+
+// DisableGC is a convenient constant for setting the garbage collection
+// interval high enough that it will never run.
+const DisableGC = time.Duration(math.MaxInt64)
+
+// NewMemdbDatastore creates a new Datastore compliant datastore backed by memdb.
+//
+// If the watchBufferLength value of 0 is set then a default value of 128 will be used.
+func NewMemdbDatastore(
+	watchBufferLength uint16,
+	revisionQuantization,
+	gcWindow time.Duration,
+) (datastore.Datastore, error) {
+	if revisionQuantization > gcWindow { // quantized revisions must stay inside the GC window to remain resolvable.
+		return nil, errors.New("gc window must be larger than quantization interval")
+	}
+
+	if revisionQuantization <= 1 { // clamp to a minimum of 1ns to avoid zero/negative quantization.
+		revisionQuantization = 1
+	}
+
+	db, err := memdb.NewMemDB(schema) // schema is declared elsewhere in this package.
+	if err != nil {
+		return nil, err
+	}
+
+	if watchBufferLength == 0 {
+		watchBufferLength = defaultWatchBufferLength
+	}
+
+	uniqueID := uuid.NewString()
+	return &memdbDatastore{
+		CommonDecoder: revisions.CommonDecoder{
+			Kind: revisions.Timestamp, // revisions are timestamp-based for this datastore.
+		},
+		db: db,
+		revisions: []snapshot{ // seed with an initial snapshot so the datastore is immediately ready.
+			{
+				revision: nowRevision(),
+				db:       db,
+			},
+		},
+
+		negativeGCWindow:        gcWindow.Nanoseconds() * -1, // stored negated for direct addition to timestamps.
+		quantizationPeriod:      revisionQuantization.Nanoseconds(),
+		watchBufferLength:       watchBufferLength,
+		watchBufferWriteTimeout: 100 * time.Millisecond,
+		uniqueID:                uniqueID,
+	}, nil
+}
+
+type memdbDatastore struct { // memdbDatastore is the in-memory, non-durable datastore implementation.
+	sync.RWMutex          // guards the fields marked GUARDED_BY below.
+	revisions.CommonDecoder
+
+	// NOTE: call checkNotClosed before using
+	db             *memdb.MemDB // GUARDED_BY(RWMutex); nil once Close has been called.
+	revisions      []snapshot   // GUARDED_BY(RWMutex); append-only history of committed snapshots, ordered by revision.
+	activeWriteTxn *memdb.Txn   // GUARDED_BY(RWMutex); non-nil while a write transaction is in flight (single-writer).
+
+	negativeGCWindow        int64 // GC window in nanoseconds, negated.
+	quantizationPeriod      int64 // revision quantization in nanoseconds.
+	watchBufferLength       uint16
+	watchBufferWriteTimeout time.Duration
+	uniqueID                string // random per-instance identifier.
+}
+
+type snapshot struct { // snapshot pairs a committed revision with the memdb state at that revision.
+	revision revisions.TimestampRevision
+	db       *memdb.MemDB
+}
+
+func (mdb *memdbDatastore) MetricsID() (string, error) { // MetricsID returns the fixed identifier used for metrics labeling.
+	return "memdb", nil
+}
+
+func (mdb *memdbDatastore) SnapshotReader(dr datastore.Revision) datastore.Reader { // SnapshotReader returns a read-only view of the datastore at (or just after) revision dr.
+	mdb.RLock()
+	defer mdb.RUnlock()
+
+	if err := mdb.checkNotClosed(); err != nil { // errors are deferred: returned from the reader's methods via initErr.
+		return &memdbReader{nil, nil, err, time.Now()}
+	}
+
+	if len(mdb.revisions) == 0 {
+		return &memdbReader{nil, nil, fmt.Errorf("memdb datastore is not ready"), time.Now()}
+	}
+
+	if err := mdb.checkRevisionLocalCallerMustLock(dr); err != nil { // validates dr is within the GC window; declared elsewhere in this package.
+		return &memdbReader{nil, nil, err, time.Now()}
+	}
+
+	revIndex := sort.Search(len(mdb.revisions), func(i int) bool { // binary search for the first snapshot at or after dr (revisions are ordered).
+		return mdb.revisions[i].revision.GreaterThan(dr) || mdb.revisions[i].revision.Equal(dr)
+	})
+
+	// handle the case when there is no revision snapshot newer than the requested revision
+	if revIndex == len(mdb.revisions) {
+		revIndex = len(mdb.revisions) - 1 // fall back to the latest snapshot.
+	}
+
+	rev := mdb.revisions[revIndex]
+	if rev.db == nil {
+		return &memdbReader{nil, nil, fmt.Errorf("memdb datastore is already closed"), time.Now()}
+	}
+
+	roTxn := rev.db.Txn(false) // read-only txn over the immutable snapshot; safe to share.
+	txSrc := func() (*memdb.Txn, error) {
+		return roTxn, nil
+	}
+
+	return &memdbReader{noopTryLocker{}, txSrc, nil, time.Now()} // no locking needed: snapshot reads are immutable.
+}
+
+func (mdb *memdbDatastore) SupportsIntegrity() bool { // SupportsIntegrity reports that this datastore can store relationship integrity data.
+	return true
+}
+
+func (mdb *memdbDatastore) ReadWriteTx(
+	ctx context.Context,
+	f datastore.TxUserFunc,
+	opts ...options.RWTOptionsOption,
+) (datastore.Revision, error) { // ReadWriteTx runs f in a write transaction, retrying on serialization conflicts, and returns the committed revision.
+	config := options.NewRWTOptionsWithOptions(opts...)
+	txNumAttempts := numAttempts
+	if config.DisableRetries {
+		txNumAttempts = 1
+	}
+
+	for i := 0; i < txNumAttempts; i++ {
+		var tx *memdb.Txn
+		createTxOnce := sync.Once{}
+		txSrc := func() (*memdb.Txn, error) { // lazily creates the single write txn the first time f touches the datastore.
+			var err error
+			createTxOnce.Do(func() {
+				mdb.Lock()
+				defer mdb.Unlock()
+
+				if mdb.activeWriteTxn != nil { // single-writer: a concurrent write txn forces a serialization retry.
+					err = ErrSerialization
+					return
+				}
+
+				if err = mdb.checkNotClosed(); err != nil {
+					return
+				}
+
+				tx = mdb.db.Txn(true)
+				tx.TrackChanges() // record mutations so they can be replayed into the changelog below.
+				mdb.activeWriteTxn = tx
+			})
+
+			return tx, err
+		}
+
+		newRevision := mdb.newRevisionID() // declared elsewhere in this package.
+		rwt := &memdbReadWriteTx{memdbReader{&sync.Mutex{}, txSrc, nil, time.Now()}, newRevision}
+		if err := f(ctx, rwt); err != nil {
+			mdb.Lock()
+			if tx != nil {
+				tx.Abort()
+				mdb.activeWriteTxn = nil
+			}
+
+			// If the error was a serialization error, retry the transaction
+			if errors.Is(err, ErrSerialization) {
+				mdb.Unlock()
+
+				// If we don't sleep here, we run out of retries instantaneously
+				time.Sleep(1 * time.Millisecond)
+				continue
+			}
+			defer mdb.Unlock() // safe despite the loop: this branch always returns below.
+
+			// We *must* return the inner error unmodified in case it's not an error type
+			// that supports unwrapping (e.g. gRPC errors)
+			return datastore.NoRevision, err
+		}
+
+		mdb.Lock()
+		defer mdb.Unlock() // safe despite the loop: every path past this point returns.
+
+		tracked := common.NewChanges(revisions.TimestampIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0)
+		if tx != nil { // tx is nil if f never touched the datastore; nothing to commit then.
+			if config.Metadata != nil && len(config.Metadata.GetFields()) > 0 {
+				if err := tracked.SetRevisionMetadata(ctx, newRevision, config.Metadata.AsMap()); err != nil {
+					return datastore.NoRevision, err
+				}
+			}
+
+			for _, change := range tx.Changes() { // replay tracked memdb mutations into the watch changelog.
+				switch change.Table {
+				case tableRelationship:
+					if change.After != nil { // insert or update: record as a TOUCH.
+						rt, err := change.After.(*relationship).Relationship()
+						if err != nil {
+							return datastore.NoRevision, err
+						}
+
+						if err := tracked.AddRelationshipChange(ctx, newRevision, rt, tuple.UpdateOperationTouch); err != nil {
+							return datastore.NoRevision, err
+						}
+					} else if change.After == nil && change.Before != nil { // deletion.
+						rt, err := change.Before.(*relationship).Relationship()
+						if err != nil {
+							return datastore.NoRevision, err
+						}
+
+						if err := tracked.AddRelationshipChange(ctx, newRevision, rt, tuple.UpdateOperationDelete); err != nil {
+							return datastore.NoRevision, err
+						}
+					} else {
+						return datastore.NoRevision, spiceerrors.MustBugf("unexpected relationship change")
+					}
+				case tableNamespace:
+					if change.After != nil {
+						loaded := &corev1.NamespaceDefinition{}
+						if err := loaded.UnmarshalVT(change.After.(*namespace).configBytes); err != nil {
+							return datastore.NoRevision, err
+						}
+
+						err := tracked.AddChangedDefinition(ctx, newRevision, loaded)
+						if err != nil {
+							return datastore.NoRevision, err
+						}
+					} else if change.After == nil && change.Before != nil {
+						err := tracked.AddDeletedNamespace(ctx, newRevision, change.Before.(*namespace).name)
+						if err != nil {
+							return datastore.NoRevision, err
+						}
+					} else {
+						return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change")
+					}
+				case tableCaveats:
+					if change.After != nil {
+						loaded := &corev1.CaveatDefinition{}
+						if err := loaded.UnmarshalVT(change.After.(*caveat).definition); err != nil {
+							return datastore.NoRevision, err
+						}
+
+						err := tracked.AddChangedDefinition(ctx, newRevision, loaded)
+						if err != nil {
+							return datastore.NoRevision, err
+						}
+					} else if change.After == nil && change.Before != nil {
+						err := tracked.AddDeletedCaveat(ctx, newRevision, change.Before.(*caveat).name)
+						if err != nil {
+							return datastore.NoRevision, err
+						}
+					} else {
+						return datastore.NoRevision, spiceerrors.MustBugf("unexpected caveat change") // fixed: previously said "namespace", a copy-paste from the case above.
+					}
+				}
+			}
+
+			var rc datastore.RevisionChanges
+			changes, err := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc)
+			if err != nil {
+				return datastore.NoRevision, err
+			}
+
+			if len(changes) > 1 { // all changes share newRevision, so at most one revision change can exist.
+				return datastore.NoRevision, spiceerrors.MustBugf("unexpected MemDB transaction with multiple revision changes")
+			} else if len(changes) == 1 {
+				rc = changes[0]
+			}
+
+			change := &changelog{ // persist the revision's changes for Watch consumers.
+				revisionNanos: newRevision.TimestampNanoSec(),
+				changes:       rc,
+			}
+			if err := tx.Insert(tableChangelog, change); err != nil {
+				return datastore.NoRevision, fmt.Errorf("error writing changelog: %w", err)
+			}
+
+			tx.Commit()
+		}
+		mdb.activeWriteTxn = nil
+
+		if err := mdb.checkNotClosed(); err != nil {
+			return datastore.NoRevision, err
+		}
+
+		// Create a snapshot and add it to the revisions slice
+		snap := mdb.db.Snapshot()
+		mdb.revisions = append(mdb.revisions, snapshot{newRevision, snap})
+		return newRevision, nil
+	}
+
+	return datastore.NoRevision, NewSerializationMaxRetriesReachedErr(errors.New("serialization max retries exceeded; please reduce your parallel writes"))
+}
+
+func (mdb *memdbDatastore) ReadyState(_ context.Context) (datastore.ReadyState, error) { // ReadyState reports readiness: ready once at least one revision snapshot exists.
+	mdb.RLock()
+	defer mdb.RUnlock()
+
+	return datastore.ReadyState{
+		Message: "missing expected initial revision", // only meaningful when IsReady is false.
+		IsReady: len(mdb.revisions) > 0,
+	}, nil
+}
+
+func (mdb *memdbDatastore) OfflineFeatures() (*datastore.Features, error) { // OfflineFeatures returns the statically-known feature support of this datastore.
+	return &datastore.Features{
+		Watch: datastore.Feature{
+			Status: datastore.FeatureSupported,
+		},
+		IntegrityData: datastore.Feature{
+			Status: datastore.FeatureSupported,
+		},
+		ContinuousCheckpointing: datastore.Feature{
+			Status: datastore.FeatureUnsupported,
+		},
+		WatchEmitsImmediately: datastore.Feature{
+			Status: datastore.FeatureUnsupported,
+		},
+	}, nil
+}
+
+func (mdb *memdbDatastore) Features(_ context.Context) (*datastore.Features, error) { // Features delegates to OfflineFeatures; support is static for memdb.
+	return mdb.OfflineFeatures()
+}
+
+func (mdb *memdbDatastore) Close() error { // Close releases the datastore; subsequent operations fail with ErrMemDBIsClosed.
+	mdb.Lock()
+	defer mdb.Unlock()
+
+	if db := mdb.db; db != nil {
+		mdb.revisions = []snapshot{ // keep one final snapshot so in-flight readers at the latest revision still resolve.
+			{
+				revision: nowRevision(), // declared elsewhere in this package.
+				db:       db,
+			},
+		}
+	} else {
+		mdb.revisions = []snapshot{}
+	}
+
+	mdb.db = nil // marks the datastore closed; see checkNotClosed.
+
+	return nil
+}
+
+// This code assumes that the RWMutex has been acquired.
+func (mdb *memdbDatastore) checkNotClosed() error { // returns ErrMemDBIsClosed after Close has nil'd the db.
+	if mdb.db == nil {
+		return ErrMemDBIsClosed
+	}
+	return nil
+}
+
+var _ datastore.Datastore = &memdbDatastore{} // compile-time interface satisfaction check.
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go
new file mode 100644
index 0000000..fdd224a
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go
@@ -0,0 +1,597 @@
+package memdb
+
+import (
+ "context"
+ "fmt"
+ "slices"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/hashicorp/go-memdb"
+
+ "github.com/authzed/spicedb/internal/datastore/common"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+type txFactory func() (*memdb.Txn, error) // txFactory lazily supplies the memdb transaction backing a reader.
+
+type memdbReader struct { // memdbReader implements datastore.Reader over a memdb transaction.
+	TryLocker           // guards against concurrent use; a noopTryLocker for immutable snapshot readers.
+	txSource  txFactory // source of the underlying transaction; may fail (e.g. serialization conflict).
+	initErr   error     // if non-nil, the reader was constructed in a failed state and methods return this error.
+	now       time.Time // creation time; used to filter out already-expired relationships.
+}
+
+func (r *memdbReader) CountRelationships(ctx context.Context, name string) (int, error) { // CountRelationships counts relationships matching the registered counter's filter.
+	counters, err := r.LookupCounters(ctx)
+	if err != nil {
+		return 0, err
+	}
+
+	var found *core.RelationshipFilter
+	for _, counter := range counters { // linear scan for the counter registered under this name.
+		if counter.Name == name {
+			found = counter.Filter
+			break
+		}
+	}
+
+	if found == nil {
+		return 0, datastore.NewCounterNotRegisteredErr(name)
+	}
+
+	coreFilter, err := datastore.RelationshipsFilterFromCoreFilter(found)
+	if err != nil {
+		return 0, err
+	}
+
+	iter, err := r.QueryRelationships(ctx, coreFilter)
+	if err != nil {
+		return 0, err
+	}
+
+	count := 0
+	for _, err := range iter { // counts by full iteration; no stored aggregate is consulted.
+		if err != nil {
+			return 0, err
+		}
+
+		count++
+	}
+	return count, nil
+}
+
+func (r *memdbReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { // LookupCounters returns all registered relationship counters.
+	if r.initErr != nil { // reader was constructed in a failed state (e.g. closed datastore).
+		return nil, r.initErr
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	var counters []datastore.RelationshipCounter
+
+	it, err := tx.LowerBound(tableCounters, indexID) // full-table iteration in index order.
+	if err != nil {
+		return nil, err
+	}
+
+	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
+		found := foundRaw.(*counter)
+
+		loaded := &core.RelationshipFilter{}
+		if err := loaded.UnmarshalVT(found.filterBytes); err != nil { // filters are stored serialized.
+			return nil, err
+		}
+
+		counters = append(counters, datastore.RelationshipCounter{
+			Name:               found.name,
+			Filter:             loaded,
+			Count:              found.count,
+			ComputedAtRevision: found.updated,
+		})
+	}
+
+	return counters, nil
+}
+
+// QueryRelationships reads relationships starting from the resource side.
+func (r *memdbReader) QueryRelationships(
+	_ context.Context,
+	filter datastore.RelationshipsFilter,
+	opts ...options.QueryOptionsOption,
+) (datastore.RelationshipIterator, error) {
+	if r.initErr != nil { // reader was constructed in a failed state.
+		return nil, r.initErr
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	queryOpts := options.NewQueryOptionsWithOptions(opts...)
+
+	bestIterator, err := iteratorForFilter(tx, filter) // pick the narrowest memdb index for the filter.
+	if err != nil {
+		return nil, err
+	}
+
+	if queryOpts.After != nil && queryOpts.Sort == options.Unsorted { // cursors require a defined ordering.
+		return nil, datastore.ErrCursorsWithoutSorting
+	}
+
+	matchingRelationshipsFilterFunc := filterFuncForFilters( // residual predicate applied on top of the index scan.
+		filter.OptionalResourceType,
+		filter.OptionalResourceIds,
+		filter.OptionalResourceIDPrefix,
+		filter.OptionalResourceRelation,
+		filter.OptionalSubjectsSelectors,
+		filter.OptionalCaveatNameFilter,
+		filter.OptionalExpirationOption,
+		makeCursorFilterFn(queryOpts.After, queryOpts.Sort),
+	)
+	filteredIterator := memdb.NewFilterIterator(bestIterator, matchingRelationshipsFilterFunc)
+
+	switch queryOpts.Sort {
+	case options.Unsorted:
+		fallthrough // index order is resource-major, so unsorted shares the ByResource path.
+
+	case options.ByResource:
+		iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration)
+		return iter, nil
+
+	case options.BySubject:
+		return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration)
+
+	default:
+		return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.Sort)
+	}
+}
+
+// ReverseQueryRelationships reads relationships starting from the subject.
+func (r *memdbReader) ReverseQueryRelationships(
+	_ context.Context,
+	subjectsFilter datastore.SubjectsFilter,
+	opts ...options.ReverseQueryOptionsOption,
+) (datastore.RelationshipIterator, error) {
+	if r.initErr != nil { // reader was constructed in a failed state.
+		return nil, r.initErr
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	queryOpts := options.NewReverseQueryOptionsWithOptions(opts...)
+
+	iterator, err := tx.Get( // scan by subject namespace; remaining criteria applied by the filter func below.
+		tableRelationship,
+		indexSubjectNamespace,
+		subjectsFilter.SubjectType,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	filterObjectType, filterRelation := "", ""
+	if queryOpts.ResRelation != nil { // optional narrowing by resource namespace/relation.
+		filterObjectType = queryOpts.ResRelation.Namespace
+		filterRelation = queryOpts.ResRelation.Relation
+	}
+
+	matchingRelationshipsFilterFunc := filterFuncForFilters(
+		filterObjectType,
+		nil,
+		"",
+		filterRelation,
+		[]datastore.SubjectsSelector{subjectsFilter.AsSelector()},
+		datastore.CaveatNameFilter{},      // no caveat filtering on the reverse path.
+		datastore.ExpirationFilterOptionNone, // no expiration filtering on the reverse path.
+		makeCursorFilterFn(queryOpts.AfterForReverse, queryOpts.SortForReverse),
+	)
+	filteredIterator := memdb.NewFilterIterator(iterator, matchingRelationshipsFilterFunc)
+
+	switch queryOpts.SortForReverse {
+	case options.Unsorted:
+		fallthrough
+
+	case options.ByResource:
+		iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false)
+		return iter, nil
+
+	case options.BySubject:
+		return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false)
+
+	default:
+		return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.SortForReverse)
+	}
+}
+
+// ReadNamespaceByName reads a namespace definition and version and returns it, and the revision at
+// which it was created or last written, if found.
+func (r *memdbReader) ReadNamespaceByName(_ context.Context, nsName string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) {
+	if r.initErr != nil { // reader was constructed in a failed state.
+		return nil, datastore.NoRevision, r.initErr
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, datastore.NoRevision, err
+	}
+
+	foundRaw, err := tx.First(tableNamespace, indexID, nsName)
+	if err != nil {
+		return nil, datastore.NoRevision, err
+	}
+
+	if foundRaw == nil {
+		return nil, datastore.NoRevision, datastore.NewNamespaceNotFoundErr(nsName)
+	}
+
+	found := foundRaw.(*namespace)
+
+	loaded := &core.NamespaceDefinition{}
+	if err := loaded.UnmarshalVT(found.configBytes); err != nil { // definitions are stored serialized.
+		return nil, datastore.NoRevision, err
+	}
+
+	return loaded, found.updated, nil
+}
+
+// ListAllNamespaces lists all namespaces defined, with the revision at which each was last written.
+func (r *memdbReader) ListAllNamespaces(_ context.Context) ([]datastore.RevisionedNamespace, error) {
+	if r.initErr != nil { // reader was constructed in a failed state.
+		return nil, r.initErr
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	var nsDefs []datastore.RevisionedNamespace
+
+	it, err := tx.LowerBound(tableNamespace, indexID) // full-table iteration in name order.
+	if err != nil {
+		return nil, err
+	}
+
+	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
+		found := foundRaw.(*namespace)
+
+		loaded := &core.NamespaceDefinition{}
+		if err := loaded.UnmarshalVT(found.configBytes); err != nil {
+			return nil, err
+		}
+
+		nsDefs = append(nsDefs, datastore.RevisionedNamespace{
+			Definition:          loaded,
+			LastWrittenRevision: found.updated,
+		})
+	}
+
+	return nsDefs, nil
+}
+
+func (r *memdbReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { // LookupNamespacesWithNames returns only the namespaces whose names appear in nsNames.
+	if r.initErr != nil { // reader was constructed in a failed state.
+		return nil, r.initErr
+	}
+
+	if len(nsNames) == 0 { // nothing requested: short-circuit without touching the datastore.
+		return nil, nil
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	it, err := tx.LowerBound(tableNamespace, indexID) // full scan then filter by name set.
+	if err != nil {
+		return nil, err
+	}
+
+	nsNameMap := make(map[string]struct{}, len(nsNames)) // membership set for O(1) lookups.
+	for _, nsName := range nsNames {
+		nsNameMap[nsName] = struct{}{}
+	}
+
+	nsDefs := make([]datastore.RevisionedNamespace, 0, len(nsNames))
+
+	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
+		found := foundRaw.(*namespace)
+
+		loaded := &core.NamespaceDefinition{}
+		if err := loaded.UnmarshalVT(found.configBytes); err != nil {
+			return nil, err
+		}
+
+		if _, ok := nsNameMap[loaded.Name]; ok {
+			nsDefs = append(nsDefs, datastore.RevisionedNamespace{
+				Definition:          loaded,
+				LastWrittenRevision: found.updated,
+			})
+		}
+	}
+
+	return nsDefs, nil
+}
+
+func (r *memdbReader) mustLock() { // mustLock acquires the reader's lock, panicking (programmer error) if it is already held.
+	if !r.TryLock() {
+		panic("detected concurrent use of ReadWriteTransaction")
+	}
+}
+
+func iteratorForFilter(txn *memdb.Txn, filter datastore.RelationshipsFilter) (memdb.ResultIterator, error) { // iteratorForFilter selects the narrowest usable memdb index for the given filter.
+	// "_prefix" is a specialized index suffix used by github.com/hashicorp/go-memdb to match on
+	// a prefix of a string.
+	// See: https://github.com/hashicorp/go-memdb/blob/9940d4a14258e3b887bfb4bc6ebc28f65461a01c/txn.go#L531
+	index := indexNamespace + "_prefix"
+
+	var args []any
+	if filter.OptionalResourceType != "" {
+		args = append(args, filter.OptionalResourceType)
+		index = indexNamespace // exact namespace match available.
+	} else {
+		args = append(args, "") // empty prefix: scans the whole table.
+	}
+
+	if filter.OptionalResourceType != "" && filter.OptionalResourceRelation != "" {
+		args = append(args, filter.OptionalResourceRelation)
+		index = indexNamespaceAndRelation // narrowest option: namespace + relation compound index.
+	}
+
+	if len(args) == 0 { // unreachable given the append above; defensive invariant check.
+		return nil, spiceerrors.MustBugf("cannot specify an empty filter")
+	}
+
+	iter, err := txn.Get(tableRelationship, index, args...)
+	if err != nil {
+		return nil, fmt.Errorf("unable to get iterator for filter: %w", err)
+	}
+
+	return iter, err
+}
+
+func filterFuncForFilters( // builds a memdb.FilterFunc; returning true EXCLUDES the relationship from results.
+	optionalResourceType string,
+	optionalResourceIds []string,
+	optionalResourceIDPrefix string,
+	optionalRelation string,
+	optionalSubjectsSelectors []datastore.SubjectsSelector,
+	optionalCaveatFilter datastore.CaveatNameFilter,
+	optionalExpirationFilter datastore.ExpirationFilterOption,
+	cursorFilter func(*relationship) bool,
+) memdb.FilterFunc {
+	return func(tupleRaw interface{}) bool {
+		tuple := tupleRaw.(*relationship)
+
+		switch { // each case drops relationships failing one resource-side criterion.
+		case optionalResourceType != "" && optionalResourceType != tuple.namespace:
+			return true
+		case len(optionalResourceIds) > 0 && !slices.Contains(optionalResourceIds, tuple.resourceID):
+			return true
+		case optionalResourceIDPrefix != "" && !strings.HasPrefix(tuple.resourceID, optionalResourceIDPrefix):
+			return true
+		case optionalRelation != "" && optionalRelation != tuple.relation:
+			return true
+		case optionalCaveatFilter.Option == datastore.CaveatFilterOptionHasMatchingCaveat && (tuple.caveat == nil || tuple.caveat.caveatName != optionalCaveatFilter.CaveatName):
+			return true
+		case optionalCaveatFilter.Option == datastore.CaveatFilterOptionNoCaveat && (tuple.caveat != nil && tuple.caveat.caveatName != ""):
+			return true
+		case optionalExpirationFilter == datastore.ExpirationFilterOptionHasExpiration && tuple.expiration == nil:
+			return true
+		case optionalExpirationFilter == datastore.ExpirationFilterOptionNoExpiration && tuple.expiration != nil:
+			return true
+		}
+
+		applySubjectSelector := func(selector datastore.SubjectsSelector) bool { // true if the relationship's subject matches this selector.
+			switch {
+			case len(selector.OptionalSubjectType) > 0 && selector.OptionalSubjectType != tuple.subjectNamespace:
+				return false
+			case len(selector.OptionalSubjectIds) > 0 && !slices.Contains(selector.OptionalSubjectIds, tuple.subjectObjectID):
+				return false
+			}
+
+			if selector.RelationFilter.OnlyNonEllipsisRelations {
+				return tuple.subjectRelation != datastore.Ellipsis
+			}
+
+			relations := make([]string, 0, 2) // collect the allowed subject relations for this selector.
+			if selector.RelationFilter.IncludeEllipsisRelation {
+				relations = append(relations, datastore.Ellipsis)
+			}
+
+			if selector.RelationFilter.NonEllipsisRelation != "" {
+				relations = append(relations, selector.RelationFilter.NonEllipsisRelation)
+			}
+
+			return len(relations) == 0 || slices.Contains(relations, tuple.subjectRelation) // empty set means no relation constraint.
+		}
+
+		if len(optionalSubjectsSelectors) > 0 { // selectors are OR'd: any match keeps the relationship.
+			hasMatchingSelector := false
+			for _, selector := range optionalSubjectsSelectors {
+				if applySubjectSelector(selector) {
+					hasMatchingSelector = true
+					break
+				}
+			}
+
+			if !hasMatchingSelector {
+				return true
+			}
+		}
+
+		return cursorFilter(tuple) // finally, drop anything at or before the pagination cursor.
+	}
+}
+
+func makeCursorFilterFn(after options.Cursor, order options.SortOrder) func(tpl *relationship) bool { // returns a predicate that is true (= exclude) for relationships at or before the cursor in the given sort order.
+	if after != nil {
+		switch order {
+		case options.ByResource:
+			return func(tpl *relationship) bool { // exclude if (resource, subject) <= cursor's (resource, subject).
+				return less(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) ||
+					(eq(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) &&
+						(less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) ||
+							eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject)))
+			}
+		case options.BySubject:
+			return func(tpl *relationship) bool { // exclude if (subject, resource) <= cursor's (subject, resource).
+				return less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) ||
+					(eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) &&
+						(less(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) ||
+							eq(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource)))
+			}
+		}
+	}
+	return noopCursorFilter // no cursor (or unsorted): exclude nothing.
+}
+
+func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) (datastore.RelationshipIterator, error) { // materializes all results, sorts them subject-major, and applies the limit.
+	results := make([]tuple.Relationship, 0)
+
+	// Coalesce all of the results into memory
+	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
+		rt, err := foundRaw.(*relationship).Relationship()
+		if err != nil {
+			return nil, err
+		}
+
+		if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) { // silently drop already-expired relationships.
+			continue
+		}
+
+		if skipCaveats && rt.OptionalCaveat != nil { // caller asserted no caveats would appear; finding one is a bug.
+			return nil, spiceerrors.MustBugf("unexpected caveat in result for relationship: %v", rt)
+		}
+
+		if skipExpiration && rt.OptionalExpiration != nil { // same for expirations.
+			return nil, spiceerrors.MustBugf("unexpected expiration in result for relationship: %v", rt)
+		}
+
+		results = append(results, rt)
+	}
+
+	// Sort them by subject
+	sort.Slice(results, func(i, j int) bool { // subject-major, resource-minor ordering.
+		lhsRes := results[i].Resource
+		lhsSub := results[i].Subject
+		rhsRes := results[j].Resource
+		rhsSub := results[j].Subject
+		return less(lhsSub.ObjectType, lhsSub.ObjectID, lhsSub.Relation, rhsSub) ||
+			(eq(lhsSub.ObjectType, lhsSub.ObjectID, lhsSub.Relation, rhsSub) &&
+				(less(lhsRes.ObjectType, lhsRes.ObjectID, lhsRes.Relation, rhsRes)))
+	})
+
+	// Limit them if requested
+	if limit != nil && uint64(len(results)) > *limit {
+		results = results[0:*limit]
+	}
+
+	return common.NewSliceRelationshipIterator(results), nil
+}
+
+func noopCursorFilter(_ *relationship) bool { // excludes nothing; used when there is no pagination cursor.
+	return false
+}
+
+func less(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelation) bool { // lexicographic ordering on (namespace, objectID, relation).
+	return lhsNamespace < rhs.ObjectType ||
+		(lhsNamespace == rhs.ObjectType && lhsObjectID < rhs.ObjectID) ||
+		(lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation < rhs.Relation)
+}
+
+func eq(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelation) bool { // equality on (namespace, objectID, relation).
+	return lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation == rhs.Relation
+}
+
+func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) datastore.RelationshipIterator { // wraps a memdb iterator as a lazy, limit-aware relationship sequence.
+	var count uint64 // number of relationships successfully yielded so far.
+	return func(yield func(tuple.Relationship, error) bool) {
+		for {
+			foundRaw := it.Next()
+			if foundRaw == nil { // iterator exhausted.
+				return
+			}
+
+			if limit != nil && count >= *limit { // limit counts only yielded relationships, not skipped ones.
+				return
+			}
+
+			rt, err := foundRaw.(*relationship).Relationship()
+			if err != nil {
+				if !yield(tuple.Relationship{}, err) { // surface conversion errors to the consumer, then keep going.
+					return
+				}
+				continue
+			}
+
+			if skipCaveats && rt.OptionalCaveat != nil { // caller asserted no caveats; yield the error and stop.
+				yield(rt, fmt.Errorf("unexpected caveat in result for relationship: %v", rt))
+				return
+			}
+
+			if skipExpiration && rt.OptionalExpiration != nil {
+				yield(rt, fmt.Errorf("unexpected expiration in result for relationship: %v", rt))
+				return
+			}
+
+			if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) { // silently drop already-expired relationships.
+				continue
+			}
+
+			if !yield(rt, err) { // err is nil here; consumer may stop iteration by returning false.
+				return
+			}
+			count++
+		}
+	}
+}
+
+var _ datastore.Reader = &memdbReader{} // compile-time interface satisfaction check.
+
+type TryLocker interface { // TryLocker is the minimal lock surface used by memdbReader to detect concurrent misuse.
+	TryLock() bool // returns false if the lock is already held.
+	Unlock()
+}
+
+type noopTryLocker struct{} // no-op implementation for immutable snapshot readers, which need no locking.
+
+func (ntl noopTryLocker) TryLock() bool {
+	return true // always succeeds: snapshot reads are safe concurrently.
+}
+
+func (ntl noopTryLocker) Unlock() {}
+
+var _ TryLocker = noopTryLocker{} // compile-time interface satisfaction check.
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go
new file mode 100644
index 0000000..8929e84
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go
@@ -0,0 +1,386 @@
+package memdb
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/hashicorp/go-memdb"
+ "github.com/jzelinskie/stringz"
+
+ "github.com/authzed/spicedb/internal/datastore/common"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// memdbReadWriteTx is a read-write transaction over the memdb datastore. It
// embeds memdbReader for all read operations and records the revision that
// will be assigned to writes made through it.
type memdbReadWriteTx struct {
	memdbReader
	newRevision datastore.Revision
}
+
// WriteRelationships applies the given relationship mutations
// (create/touch/delete) to the datastore under the transaction's lock.
func (rwt *memdbReadWriteTx) WriteRelationships(_ context.Context, mutations []tuple.RelationshipUpdate) error {
	rwt.mustLock()
	defer rwt.Unlock()

	tx, err := rwt.txSource()
	if err != nil {
		return err
	}

	return rwt.write(tx, mutations...)
}
+
+func (rwt *memdbReadWriteTx) toIntegrity(mutation tuple.RelationshipUpdate) *relationshipIntegrity {
+ var ri *relationshipIntegrity
+ if mutation.Relationship.OptionalIntegrity != nil {
+ ri = &relationshipIntegrity{
+ keyID: mutation.Relationship.OptionalIntegrity.KeyId,
+ hash: mutation.Relationship.OptionalIntegrity.Hash,
+ timestamp: mutation.Relationship.OptionalIntegrity.HashedAt.AsTime(),
+ }
+ }
+ return ri
+}
+
// write applies the given relationship mutations to the open memdb
// transaction, dispatching on the mutation operation:
// CREATE fails if the relationship exists, TOUCH upserts (skipping identical
// rows), DELETE is a no-op for missing rows.
// Caller must already hold the concurrent access lock!
func (rwt *memdbReadWriteTx) write(tx *memdb.Txn, mutations ...tuple.RelationshipUpdate) error {
	// Apply the mutations
	for _, mutation := range mutations {
		// Flatten the mutation's relationship into the stored row shape.
		rel := &relationship{
			mutation.Relationship.Resource.ObjectType,
			mutation.Relationship.Resource.ObjectID,
			mutation.Relationship.Resource.Relation,
			mutation.Relationship.Subject.ObjectType,
			mutation.Relationship.Subject.ObjectID,
			mutation.Relationship.Subject.Relation,
			rwt.toCaveatReference(mutation),
			rwt.toIntegrity(mutation),
			mutation.Relationship.OptionalExpiration,
		}

		// Look up any existing row with the same full (resource, subject) key.
		found, err := tx.First(
			tableRelationship,
			indexID,
			rel.namespace,
			rel.resourceID,
			rel.relation,
			rel.subjectNamespace,
			rel.subjectObjectID,
			rel.subjectRelation,
		)
		if err != nil {
			return fmt.Errorf("error loading existing relationship: %w", err)
		}

		var existing *relationship
		if found != nil {
			existing = found.(*relationship)
		}

		switch mutation.Operation {
		case tuple.UpdateOperationCreate:
			// CREATE must fail when the relationship already exists.
			if existing != nil {
				rt, err := existing.Relationship()
				if err != nil {
					return err
				}
				return common.NewCreateRelationshipExistsError(&rt)
			}
			if err := tx.Insert(tableRelationship, rel); err != nil {
				return fmt.Errorf("error inserting relationship: %w", err)
			}

		case tuple.UpdateOperationTouch:
			// TOUCH is an upsert; skip the write entirely when the stored
			// relationship is already identical (string-equal) to the new one.
			if existing != nil {
				rt, err := existing.Relationship()
				if err != nil {
					return err
				}
				if tuple.MustString(rt) == tuple.MustString(mutation.Relationship) {
					continue
				}
			}

			if err := tx.Insert(tableRelationship, rel); err != nil {
				return fmt.Errorf("error inserting relationship: %w", err)
			}

		case tuple.UpdateOperationDelete:
			// DELETE of a missing relationship is silently ignored.
			if existing != nil {
				if err := tx.Delete(tableRelationship, existing); err != nil {
					return fmt.Errorf("error deleting relationship: %w", err)
				}
			}
		default:
			return spiceerrors.MustBugf("unknown tuple mutation operation type: %v", mutation.Operation)
		}
	}

	return nil
}
+
+func (rwt *memdbReadWriteTx) toCaveatReference(mutation tuple.RelationshipUpdate) *contextualizedCaveat {
+ var cr *contextualizedCaveat
+ if mutation.Relationship.OptionalCaveat != nil {
+ cr = &contextualizedCaveat{
+ caveatName: mutation.Relationship.OptionalCaveat.CaveatName,
+ context: mutation.Relationship.OptionalCaveat.Context.AsMap(),
+ }
+ }
+ return cr
+}
+
// DeleteRelationships removes all relationships matching the given filter,
// optionally bounded by a delete limit. It returns the number deleted and
// whether the limit was reached.
func (rwt *memdbReadWriteTx) DeleteRelationships(_ context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) {
	rwt.mustLock()
	defer rwt.Unlock()

	tx, err := rwt.txSource()
	if err != nil {
		return 0, false, err
	}

	delOpts := options.NewDeleteOptionsWithOptionsAndDefaults(opts...)
	// A delLimit of 0 means "no limit" in deleteWithLock.
	var delLimit uint64
	if delOpts.DeleteLimit != nil && *delOpts.DeleteLimit > 0 {
		delLimit = *delOpts.DeleteLimit
	}

	return rwt.deleteWithLock(tx, filter, delLimit)
}
+
// deleteWithLock deletes up to `limit` relationships matching the filter
// (0 means unlimited), returning the count deleted and whether the limit was
// hit. Deletes are routed through write() so they flow into the changelog.
// caller must already hold the concurrent access lock
func (rwt *memdbReadWriteTx) deleteWithLock(tx *memdb.Txn, filter *v1.RelationshipFilter, limit uint64) (uint64, bool, error) {
	// Create an iterator to find the relevant tuples
	dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter)
	if err != nil {
		return 0, false, err
	}

	// Start from the best matching index, then refine row-by-row.
	bestIter, err := iteratorForFilter(tx, dsFilter)
	if err != nil {
		return 0, false, err
	}
	filteredIter := memdb.NewFilterIterator(bestIter, relationshipFilterFilterFunc(filter))

	// Collect the tuples into a slice of mutations for the changelog
	var mutations []tuple.RelationshipUpdate
	var counter uint64

	metLimit := false
	for row := filteredIter.Next(); row != nil; row = filteredIter.Next() {
		rt, err := row.(*relationship).Relationship()
		if err != nil {
			return 0, false, err
		}
		mutations = append(mutations, tuple.Delete(rt))
		counter++

		if limit > 0 && counter == limit {
			metLimit = true
			break
		}
	}

	return counter, metLimit, rwt.write(tx, mutations...)
}
+
// RegisterCounter creates a named relationship counter for the given filter,
// failing if a counter with that name is already registered.
func (rwt *memdbReadWriteTx) RegisterCounter(ctx context.Context, name string, filter *core.RelationshipFilter) error {
	rwt.mustLock()
	defer rwt.Unlock()

	tx, err := rwt.txSource()
	if err != nil {
		return err
	}

	// Counter names must be unique.
	foundRaw, err := tx.First(tableCounters, indexID, name)
	if err != nil {
		return err
	}

	if foundRaw != nil {
		return datastore.NewCounterAlreadyRegisteredErr(name, filter)
	}

	// The filter is stored in serialized proto form.
	filterBytes, err := filter.MarshalVT()
	if err != nil {
		return err
	}

	// Insert the counter; it starts at zero and has never been computed
	// (hence NoRevision).
	counter := &counter{
		name,
		filterBytes,
		0,
		datastore.NoRevision,
	}

	return tx.Insert(tableCounters, counter)
}
+
+func (rwt *memdbReadWriteTx) UnregisterCounter(ctx context.Context, name string) error {
+ rwt.mustLock()
+ defer rwt.Unlock()
+
+ tx, err := rwt.txSource()
+ if err != nil {
+ return err
+ }
+
+ // Check if the counter exists
+ foundRaw, err := tx.First(tableCounters, indexID, name)
+ if err != nil {
+ return err
+ }
+
+ if foundRaw == nil {
+ return datastore.NewCounterNotRegisteredErr(name)
+ }
+
+ return tx.Delete(tableCounters, foundRaw)
+}
+
// StoreCounterValue records a computed count (and the revision it was
// computed at) for the named counter, which must already be registered.
func (rwt *memdbReadWriteTx) StoreCounterValue(ctx context.Context, name string, value int, computedAtRevision datastore.Revision) error {
	rwt.mustLock()
	defer rwt.Unlock()

	tx, err := rwt.txSource()
	if err != nil {
		return err
	}

	// Check if the counter exists
	foundRaw, err := tx.First(tableCounters, indexID, name)
	if err != nil {
		return err
	}

	if foundRaw == nil {
		return datastore.NewCounterNotRegisteredErr(name)
	}

	// NOTE(review): this mutates the object already stored in memdb in place
	// before re-inserting it; go-memdb generally expects stored objects to be
	// treated as immutable — confirm this is safe under the datastore's
	// locking discipline.
	counter := foundRaw.(*counter)
	counter.count = value
	counter.updated = computedAtRevision

	return tx.Insert(tableCounters, counter)
}
+
+func (rwt *memdbReadWriteTx) WriteNamespaces(_ context.Context, newConfigs ...*core.NamespaceDefinition) error {
+ rwt.mustLock()
+ defer rwt.Unlock()
+
+ tx, err := rwt.txSource()
+ if err != nil {
+ return err
+ }
+
+ for _, newConfig := range newConfigs {
+ serialized, err := newConfig.MarshalVT()
+ if err != nil {
+ return err
+ }
+
+ newConfigEntry := &namespace{newConfig.Name, serialized, rwt.newRevision}
+
+ err = tx.Insert(tableNamespace, newConfigEntry)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
// DeleteNamespaces removes the given namespace definitions along with every
// relationship stored under them. It fails if any named namespace is missing.
func (rwt *memdbReadWriteTx) DeleteNamespaces(_ context.Context, nsNames ...string) error {
	rwt.mustLock()
	defer rwt.Unlock()

	tx, err := rwt.txSource()
	if err != nil {
		return err
	}

	for _, nsName := range nsNames {
		foundRaw, err := tx.First(tableNamespace, indexID, nsName)
		if err != nil {
			return err
		}

		if foundRaw == nil {
			// NOTE(review): callers may expect a typed "namespace not found"
			// error rather than this generic one — confirm.
			return fmt.Errorf("namespace not found")
		}

		if err := tx.Delete(tableNamespace, foundRaw); err != nil {
			return err
		}

		// Delete the relationships from the namespace
		if _, _, err := rwt.deleteWithLock(tx, &v1.RelationshipFilter{
			ResourceType: nsName,
		}, 0); err != nil {
			return fmt.Errorf("unable to delete relationships from deleted namespace: %w", err)
		}
	}

	return nil
}
+
// BulkLoad streams relationships from the given source into the datastore as
// CREATE operations, returning the number of relationships written.
func (rwt *memdbReadWriteTx) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) {
	var numCopied uint64
	var next *tuple.Relationship
	var err error

	// A single-element slice is reused for every write to avoid a per-item
	// allocation.
	updates := []tuple.RelationshipUpdate{{
		Operation: tuple.UpdateOperationCreate,
	}}

	for next, err = iter.Next(ctx); next != nil && err == nil; next, err = iter.Next(ctx) {
		updates[0].Relationship = *next
		if err := rwt.WriteRelationships(ctx, updates); err != nil {
			return 0, err
		}
		numCopied++
	}

	// err carries any iteration error from the source (nil on clean EOF).
	return numCopied, err
}
+
+func relationshipFilterFilterFunc(filter *v1.RelationshipFilter) func(interface{}) bool {
+ return func(tupleRaw interface{}) bool {
+ tuple := tupleRaw.(*relationship)
+
+ // If it doesn't match one of the resource filters, filter it.
+ switch {
+ case filter.ResourceType != "" && filter.ResourceType != tuple.namespace:
+ return true
+ case filter.OptionalResourceId != "" && filter.OptionalResourceId != tuple.resourceID:
+ return true
+ case filter.OptionalResourceIdPrefix != "" && !strings.HasPrefix(tuple.resourceID, filter.OptionalResourceIdPrefix):
+ return true
+ case filter.OptionalRelation != "" && filter.OptionalRelation != tuple.relation:
+ return true
+ }
+
+ // If it doesn't match one of the subject filters, filter it.
+ if subjectFilter := filter.OptionalSubjectFilter; subjectFilter != nil {
+ switch {
+ case subjectFilter.SubjectType != tuple.subjectNamespace:
+ return true
+ case subjectFilter.OptionalSubjectId != "" && subjectFilter.OptionalSubjectId != tuple.subjectObjectID:
+ return true
+ case subjectFilter.OptionalRelation != nil &&
+ stringz.DefaultEmpty(subjectFilter.OptionalRelation.Relation, datastore.Ellipsis) != tuple.subjectRelation:
+ return true
+ }
+ }
+
+ return false
+ }
+}
+
+var _ datastore.ReadWriteTransaction = &memdbReadWriteTx{}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go
new file mode 100644
index 0000000..be79771
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go
@@ -0,0 +1,118 @@
+package memdb
+
+import (
+ "context"
+ "time"
+
+ "github.com/authzed/spicedb/internal/datastore/revisions"
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// ParseRevisionString parses a revision string as a timestamp-kind revision.
var ParseRevisionString = revisions.RevisionParser(revisions.Timestamp)

// nowRevision returns a timestamp revision for the current wall-clock time in UTC.
func nowRevision() revisions.TimestampRevision {
	return revisions.NewForTime(time.Now().UTC())
}
+
// newRevisionID creates a new revision from the current time, guaranteeing
// it differs from the most recently recorded revision even on platforms with
// coarse clock precision.
func (mdb *memdbDatastore) newRevisionID() revisions.TimestampRevision {
	mdb.Lock()
	defer mdb.Unlock()

	existing := mdb.revisions[len(mdb.revisions)-1].revision
	created := nowRevision()

	// NOTE: The time.Now().UTC() only appears to have *microsecond* level
	// precision on macOS Monterey in Go 1.19.1. This means that HeadRevision
	// and the result of a ReadWriteTx could return the *same* transaction ID
	// if both are executed in sequence without any other forms of delay on
	// macOS. We therefore check if the created transaction ID matches that
	// previously created and, if not, add to it.
	//
	// See: https://github.com/golang/go/issues/22037 which appeared to fix
	// this in Go 1.9.2, but there appears to have been a reversion with either
	// the new version of macOS or Go.
	if created.Equal(existing) {
		return revisions.NewForTimestamp(created.TimestampNanoSec() + 1)
	}

	return created
}
+
// HeadRevision returns the most recently committed revision, failing if the
// datastore has been closed.
func (mdb *memdbDatastore) HeadRevision(_ context.Context) (datastore.Revision, error) {
	mdb.RLock()
	defer mdb.RUnlock()
	if err := mdb.checkNotClosed(); err != nil {
		return nil, err
	}

	return mdb.headRevisionNoLock(), nil
}
+
+func (mdb *memdbDatastore) SquashRevisionsForTesting() {
+ mdb.revisions = []snapshot{
+ {
+ revision: nowRevision(),
+ db: mdb.db,
+ },
+ }
+}
+
// headRevisionNoLock returns the most recent revision. Callers must hold at
// least a read lock.
func (mdb *memdbDatastore) headRevisionNoLock() revisions.TimestampRevision {
	return mdb.revisions[len(mdb.revisions)-1].revision
}
+
// OptimizedRevision returns the current time quantized (rounded down) to the
// datastore's quantization period, allowing nearby requests to share a
// revision.
func (mdb *memdbDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, error) {
	mdb.RLock()
	defer mdb.RUnlock()
	if err := mdb.checkNotClosed(); err != nil {
		return nil, err
	}

	now := nowRevision()
	// NOTE(review): assumes quantizationPeriod is non-zero (a zero value
	// would panic on the modulo) — presumably guaranteed at construction.
	return revisions.NewForTimestamp(now.TimestampNanoSec() - now.TimestampNanoSec()%mdb.quantizationPeriod), nil
}
+
// CheckRevision validates that the given revision may still be read from,
// failing if the datastore has been closed.
func (mdb *memdbDatastore) CheckRevision(_ context.Context, dr datastore.Revision) error {
	mdb.RLock()
	defer mdb.RUnlock()
	if err := mdb.checkNotClosed(); err != nil {
		return err
	}

	return mdb.checkRevisionLocalCallerMustLock(dr)
}
+
// checkRevisionLocalCallerMustLock validates that the given revision is
// usable for reads: not older than the GC window, and not beyond anything
// this datastore could have produced. The caller must hold at least a read
// lock.
func (mdb *memdbDatastore) checkRevisionLocalCallerMustLock(dr datastore.Revision) error {
	now := nowRevision()

	// Ensure the revision has not fallen outside of the GC window. If it has, it is considered
	// invalid.
	if mdb.revisionOutsideGCWindow(now, dr) {
		return datastore.NewInvalidRevisionErr(dr, datastore.RevisionStale)
	}

	// If the revision <= now and later than the GC window, it is assumed to be valid, even if
	// HEAD revision is behind it.
	if dr.GreaterThan(now) {
		// If the revision is in the "future", then check to ensure that it is <= of HEAD to handle
		// the microsecond granularity on macos (see comment above in newRevisionID)
		headRevision := mdb.headRevisionNoLock()
		if dr.LessThan(headRevision) || dr.Equal(headRevision) {
			return nil
		}

		return datastore.NewInvalidRevisionErr(dr, datastore.CouldNotDetermineRevision)
	}

	return nil
}
+
// revisionOutsideGCWindow returns true if the given revision is older than
// the GC window and therefore no longer valid to read from.
func (mdb *memdbDatastore) revisionOutsideGCWindow(now revisions.TimestampRevision, revisionRaw datastore.Revision) bool {
	// make an exception for head revision - it will be acceptable even if outside GC Window
	if revisionRaw.Equal(mdb.headRevisionNoLock()) {
		return false
	}

	// negativeGCWindow is presumably the GC window stored as a negative
	// nanosecond count, so adding it to `now` yields the oldest acceptable
	// revision — confirm against the field's initialization.
	oldest := revisions.NewForTimestamp(now.TimestampNanoSec() + mdb.negativeGCWindow)
	return revisionRaw.LessThan(oldest)
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go
new file mode 100644
index 0000000..7905d48
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go
@@ -0,0 +1,232 @@
+package memdb
+
+import (
+ "time"
+
+ "github.com/hashicorp/go-memdb"
+ "github.com/rs/zerolog"
+ "google.golang.org/protobuf/types/known/structpb"
+ "google.golang.org/protobuf/types/known/timestamppb"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// Table and index names for the in-memory database schema defined below.
const (
	// tableNamespace stores serialized namespace definitions keyed by name.
	tableNamespace = "namespace"

	// tableRelationship stores individual relationships plus the indexes
	// used to query them.
	tableRelationship = "relationship"
	indexID = "id"
	indexNamespace = "namespace"
	indexNamespaceAndRelation = "namespaceAndRelation"
	indexSubjectNamespace = "subjectNamespace"

	// tableCounters stores registered relationship counters.
	tableCounters = "counters"

	// tableChangelog is the ordered log of committed changes, consumed by Watch.
	tableChangelog = "changelog"
	indexRevision = "id"
)
+
// namespace is the stored form of a namespace definition: its name, the
// serialized proto bytes, and the revision at which it was last written.
type namespace struct {
	name string
	configBytes []byte
	updated datastore.Revision
}

// MarshalZerologObject logs the namespace's name and revision.
func (ns namespace) MarshalZerologObject(e *zerolog.Event) {
	e.Stringer("rev", ns.updated).Str("name", ns.name)
}

// counter is the stored form of a registered relationship counter: its name,
// the serialized filter it counts, the last computed count, and the revision
// that count was computed at.
type counter struct {
	name string
	filterBytes []byte
	count int
	updated datastore.Revision
}
+
// relationship is the stored row form of a single relationship:
// resource (namespace, resourceID, relation) -> subject (subjectNamespace,
// subjectObjectID, subjectRelation), plus optional caveat, integrity and
// expiration metadata.
type relationship struct {
	namespace string
	resourceID string
	relation string
	subjectNamespace string
	subjectObjectID string
	subjectRelation string
	caveat *contextualizedCaveat
	integrity *relationshipIntegrity
	expiration *time.Time
}

// relationshipIntegrity stores the integrity signature attached to a
// relationship: the signing key ID, the hash, and when it was computed.
type relationshipIntegrity struct {
	keyID string
	hash []byte
	timestamp time.Time
}

// MarshalZerologObject logs the integrity metadata fields.
func (ri relationshipIntegrity) MarshalZerologObject(e *zerolog.Event) {
	e.Str("keyID", ri.keyID).Bytes("hash", ri.hash).Time("timestamp", ri.timestamp)
}

// RelationshipIntegrity converts the stored integrity data into its core
// proto representation.
func (ri relationshipIntegrity) RelationshipIntegrity() *core.RelationshipIntegrity {
	return &core.RelationshipIntegrity{
		KeyId: ri.keyID,
		Hash: ri.hash,
		HashedAt: timestamppb.New(ri.timestamp),
	}
}
+
+type contextualizedCaveat struct {
+ caveatName string
+ context map[string]any
+}
+
+func (cr *contextualizedCaveat) ContextualizedCaveat() (*core.ContextualizedCaveat, error) {
+ if cr == nil {
+ return nil, nil
+ }
+ v, err := structpb.NewStruct(cr.context)
+ if err != nil {
+ return nil, err
+ }
+ return &core.ContextualizedCaveat{
+ CaveatName: cr.caveatName,
+ Context: v,
+ }, nil
+}
+
+func (r relationship) String() string {
+ caveat := ""
+ if r.caveat != nil {
+ caveat = "[" + r.caveat.caveatName + "]"
+ }
+
+ expiration := ""
+ if r.expiration != nil {
+ expiration = "[expiration:" + r.expiration.Format(time.RFC3339Nano) + "]"
+ }
+
+ return r.namespace + ":" + r.resourceID + "#" + r.relation + "@" + r.subjectNamespace + ":" + r.subjectObjectID + "#" + r.subjectRelation + caveat + expiration
+}
+
// MarshalZerologObject logs the relationship in its string form.
func (r relationship) MarshalZerologObject(e *zerolog.Event) {
	e.Str("rel", r.String())
}
+
// Relationship converts the stored row into a tuple.Relationship, decoding
// the optional caveat context (which can fail) and integrity metadata.
func (r relationship) Relationship() (tuple.Relationship, error) {
	// Safe on a nil caveat: ContextualizedCaveat handles a nil receiver.
	cr, err := r.caveat.ContextualizedCaveat()
	if err != nil {
		return tuple.Relationship{}, err
	}

	var ig *core.RelationshipIntegrity
	if r.integrity != nil {
		ig = r.integrity.RelationshipIntegrity()
	}

	return tuple.Relationship{
		RelationshipReference: tuple.RelationshipReference{
			Resource: tuple.ObjectAndRelation{
				ObjectType: r.namespace,
				ObjectID: r.resourceID,
				Relation: r.relation,
			},
			Subject: tuple.ObjectAndRelation{
				ObjectType: r.subjectNamespace,
				ObjectID: r.subjectObjectID,
				Relation: r.subjectRelation,
			},
		},
		OptionalCaveat: cr,
		OptionalIntegrity: ig,
		OptionalExpiration: r.expiration,
	}, nil
}
+
// changelog is a single committed change set, keyed by its revision's
// nanosecond timestamp so Watch can iterate entries in commit order.
type changelog struct {
	revisionNanos int64
	changes datastore.RevisionChanges
}
+
// schema defines the full go-memdb layout: namespaces, the changelog,
// relationships (with a unique compound primary index plus secondary query
// indexes), caveats, and counters.
var schema = &memdb.DBSchema{
	Tables: map[string]*memdb.TableSchema{
		// Namespace definitions, unique by name.
		tableNamespace: {
			Name: tableNamespace,
			Indexes: map[string]*memdb.IndexSchema{
				indexID: {
					Name: indexID,
					Unique: true,
					Indexer: &memdb.StringFieldIndex{Field: "name"},
				},
			},
		},
		// Committed change sets, ordered by revision nanos for Watch.
		tableChangelog: {
			Name: tableChangelog,
			Indexes: map[string]*memdb.IndexSchema{
				indexRevision: {
					Name: indexRevision,
					Unique: true,
					Indexer: &memdb.IntFieldIndex{Field: "revisionNanos"},
				},
			},
		},
		tableRelationship: {
			Name: tableRelationship,
			Indexes: map[string]*memdb.IndexSchema{
				// Primary key: the full six-part (resource, subject) tuple.
				indexID: {
					Name: indexID,
					Unique: true,
					Indexer: &memdb.CompoundIndex{
						Indexes: []memdb.Indexer{
							&memdb.StringFieldIndex{Field: "namespace"},
							&memdb.StringFieldIndex{Field: "resourceID"},
							&memdb.StringFieldIndex{Field: "relation"},
							&memdb.StringFieldIndex{Field: "subjectNamespace"},
							&memdb.StringFieldIndex{Field: "subjectObjectID"},
							&memdb.StringFieldIndex{Field: "subjectRelation"},
						},
					},
				},
				// Secondary indexes for the common lookup shapes.
				indexNamespace: {
					Name: indexNamespace,
					Unique: false,
					Indexer: &memdb.StringFieldIndex{Field: "namespace"},
				},
				indexNamespaceAndRelation: {
					Name: indexNamespaceAndRelation,
					Unique: false,
					Indexer: &memdb.CompoundIndex{
						Indexes: []memdb.Indexer{
							&memdb.StringFieldIndex{Field: "namespace"},
							&memdb.StringFieldIndex{Field: "relation"},
						},
					},
				},
				indexSubjectNamespace: {
					Name: indexSubjectNamespace,
					Unique: false,
					Indexer: &memdb.StringFieldIndex{Field: "subjectNamespace"},
				},
			},
		},
		// Caveat definitions, unique by name (tableCaveats is declared elsewhere in this package).
		tableCaveats: {
			Name: tableCaveats,
			Indexes: map[string]*memdb.IndexSchema{
				indexID: {
					Name: indexID,
					Unique: true,
					Indexer: &memdb.StringFieldIndex{Field: "name"},
				},
			},
		},
		// Registered relationship counters, unique by name.
		tableCounters: {
			Name: tableCounters,
			Indexes: map[string]*memdb.IndexSchema{
				indexID: {
					Name: indexID,
					Unique: true,
					Indexer: &memdb.StringFieldIndex{Field: "name"},
				},
			},
		},
	},
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go
new file mode 100644
index 0000000..33665a1
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go
@@ -0,0 +1,51 @@
+package memdb
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// Statistics reports the datastore's unique ID, its exact relationship count
// (memdb can afford a full scan), and per-object-type statistics computed
// from the namespaces visible at the head revision.
func (mdb *memdbDatastore) Statistics(ctx context.Context) (datastore.Stats, error) {
	head, err := mdb.HeadRevision(ctx)
	if err != nil {
		return datastore.Stats{}, fmt.Errorf("unable to compute head revision: %w", err)
	}

	count, err := mdb.countRelationships(ctx)
	if err != nil {
		return datastore.Stats{}, fmt.Errorf("unable to count relationships: %w", err)
	}

	objTypes, err := mdb.SnapshotReader(head).ListAllNamespaces(ctx)
	if err != nil {
		return datastore.Stats{}, fmt.Errorf("unable to list object types: %w", err)
	}

	return datastore.Stats{
		UniqueID: mdb.uniqueID,
		EstimatedRelationshipCount: count,
		ObjectTypeStatistics: datastore.ComputeObjectTypeStats(objTypes),
	}, nil
}
+
// countRelationships scans the entire relationship table and returns the
// exact number of stored relationships (O(n) in table size).
func (mdb *memdbDatastore) countRelationships(_ context.Context) (uint64, error) {
	mdb.RLock()
	defer mdb.RUnlock()

	// Read-only snapshot transaction; Abort just releases it.
	txn := mdb.db.Txn(false)
	defer txn.Abort()

	// NOTE(review): LowerBound is called with no bound arguments, presumably
	// iterating from the start of the index — confirm against go-memdb docs.
	it, err := txn.LowerBound(tableRelationship, indexID)
	if err != nil {
		return 0, err
	}

	var count uint64
	for row := it.Next(); row != nil; row = it.Next() {
		count++
	}

	return count, nil
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go
new file mode 100644
index 0000000..eaa4812
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go
@@ -0,0 +1,148 @@
+package memdb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/hashicorp/go-memdb"
+
+ "github.com/authzed/spicedb/internal/datastore/revisions"
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// errWatchError wraps errors surfaced from the watch loop.
const errWatchError = "watch error: %w"

// Watch emits change sets committed after the given revision on the returned
// updates channel. Errors — including disconnection of a consumer too slow to
// keep up with the buffer — are delivered on the second channel, after which
// both channels are closed.
func (mdb *memdbDatastore) Watch(ctx context.Context, ar datastore.Revision, options datastore.WatchOptions) (<-chan datastore.RevisionChanges, <-chan error) {
	watchBufferLength := options.WatchBufferLength
	if watchBufferLength == 0 {
		watchBufferLength = mdb.watchBufferLength
	}

	updates := make(chan datastore.RevisionChanges, watchBufferLength)
	errs := make(chan error, 1)

	if options.EmissionStrategy == datastore.EmitImmediatelyStrategy {
		close(updates)
		errs <- errors.New("emit immediately strategy is unsupported in MemDB")
		return updates, errs
	}

	watchBufferWriteTimeout := options.WatchBufferWriteTimeout
	if watchBufferWriteTimeout == 0 {
		watchBufferWriteTimeout = mdb.watchBufferWriteTimeout
	}

	// sendChange delivers one change: first without blocking, then blocking
	// up to the write timeout. A consumer still too slow after the timeout
	// terminates the watch with a disconnection error.
	sendChange := func(change datastore.RevisionChanges) bool {
		select {
		case updates <- change:
			return true

		default:
			// If we cannot immediately write, setup the timer and try again.
		}

		timer := time.NewTimer(watchBufferWriteTimeout)
		defer timer.Stop()

		select {
		case updates <- change:
			return true

		case <-timer.C:
			errs <- datastore.NewWatchDisconnectedErr()
			return false
		}
	}

	go func() {
		defer close(updates)
		defer close(errs)

		// Resume the changelog scan just after the caller's revision.
		currentTxn := ar.(revisions.TimestampRevision).TimestampNanoSec()

		for {
			var stagedUpdates []datastore.RevisionChanges
			var watchChan <-chan struct{}
			var err error
			stagedUpdates, currentTxn, watchChan, err = mdb.loadChanges(ctx, currentTxn, options)
			if err != nil {
				errs <- err
				return
			}

			// Write the staged updates to the channel
			for _, changeToWrite := range stagedUpdates {
				if !sendChange(changeToWrite) {
					return
				}
			}

			// Wait for new changes
			ws := memdb.NewWatchSet()
			ws.Add(watchChan)

			err = ws.WatchCtx(ctx)
			if err != nil {
				switch {
				case errors.Is(err, context.Canceled):
					errs <- datastore.NewWatchCanceledErr()
				default:
					errs <- fmt.Errorf(errWatchError, err)
				}
				return
			}
		}
	}()

	return updates, errs
}
+
+func (mdb *memdbDatastore) loadChanges(_ context.Context, currentTxn int64, options datastore.WatchOptions) ([]datastore.RevisionChanges, int64, <-chan struct{}, error) {
+ mdb.RLock()
+ defer mdb.RUnlock()
+
+ if err := mdb.checkNotClosed(); err != nil {
+ return nil, 0, nil, err
+ }
+
+ loadNewTxn := mdb.db.Txn(false)
+ defer loadNewTxn.Abort()
+
+ it, err := loadNewTxn.LowerBound(tableChangelog, indexRevision, currentTxn+1)
+ if err != nil {
+ return nil, 0, nil, fmt.Errorf(errWatchError, err)
+ }
+
+ var changes []datastore.RevisionChanges
+ lastRevision := currentTxn
+ for changeRaw := it.Next(); changeRaw != nil; changeRaw = it.Next() {
+ change := changeRaw.(*changelog)
+
+ if options.Content&datastore.WatchRelationships == datastore.WatchRelationships && len(change.changes.RelationshipChanges) > 0 {
+ changes = append(changes, change.changes)
+ }
+
+ if options.Content&datastore.WatchSchema == datastore.WatchSchema &&
+ len(change.changes.ChangedDefinitions) > 0 || len(change.changes.DeletedCaveats) > 0 || len(change.changes.DeletedNamespaces) > 0 {
+ changes = append(changes, change.changes)
+ }
+
+ if options.Content&datastore.WatchCheckpoints == datastore.WatchCheckpoints && change.revisionNanos > lastRevision {
+ changes = append(changes, datastore.RevisionChanges{
+ Revision: revisions.NewForTimestamp(change.revisionNanos),
+ IsCheckpoint: true,
+ })
+ }
+
+ lastRevision = change.revisionNanos
+ }
+
+ watchChan, _, err := loadNewTxn.LastWatch(tableChangelog, indexRevision)
+ if err != nil {
+ return nil, 0, nil, fmt.Errorf(errWatchError, err)
+ }
+
+ return changes, lastRevision, watchChan, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go
new file mode 100644
index 0000000..7092728
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go
@@ -0,0 +1,79 @@
+package revisions
+
+import (
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
// RevisionKind is an enum of the different kinds of revisions that can be used.
type RevisionKind string

// FIX: give every constant an explicit RevisionKind type. Previously only
// Timestamp was typed; TransactionID and HybridLogicalClock were untyped
// string constants, which is inconsistent and loses type checking at use
// sites.
const (
	// Timestamp is a revision that is a timestamp.
	Timestamp RevisionKind = "timestamp"

	// TransactionID is a revision that is a transaction ID.
	TransactionID RevisionKind = "txid"

	// HybridLogicalClock is a revision that is a hybrid logical clock.
	HybridLogicalClock RevisionKind = "hlc"
)
+
// ParsingFunc is a function that can parse a string into a revision.
type ParsingFunc func(revisionStr string) (rev datastore.Revision, err error)

// RevisionParser returns a ParsingFunc for the given RevisionKind. An
// unrecognized kind yields a parser that always reports a bug.
func RevisionParser(kind RevisionKind) ParsingFunc {
	switch kind {
	case TransactionID:
		return parseTransactionIDRevisionString

	case Timestamp:
		return parseTimestampRevisionString

	case HybridLogicalClock:
		return parseHLCRevisionString

	default:
		// Defer the failure to call time so construction never errors.
		return func(revisionStr string) (rev datastore.Revision, err error) {
			return nil, spiceerrors.MustBugf("unknown revision kind: %v", kind)
		}
	}
}
+
// CommonDecoder is a revision decoder that can decode revisions of a given kind.
type CommonDecoder struct {
	Kind RevisionKind
}

// RevisionFromString parses the given string according to the decoder's
// configured kind; an unknown kind is reported as a bug.
func (cd CommonDecoder) RevisionFromString(s string) (datastore.Revision, error) {
	switch cd.Kind {
	case TransactionID:
		return parseTransactionIDRevisionString(s)

	case Timestamp:
		return parseTimestampRevisionString(s)

	case HybridLogicalClock:
		return parseHLCRevisionString(s)

	default:
		return nil, spiceerrors.MustBugf("unknown revision kind in decoder: %v", cd.Kind)
	}
}
+
// WithInexactFloat64 is an interface that can be implemented by a revision to
// provide an inexact float64 representation of the revision.
type WithInexactFloat64 interface {
	// InexactFloat64 returns a float64 that is an inexact representation of the
	// revision.
	InexactFloat64() float64
}

// WithTimestampRevision is an interface that can be implemented by a revision to
// provide a timestamp and construct new revisions from one.
type WithTimestampRevision interface {
	datastore.Revision
	TimestampNanoSec() int64
	ConstructForTimestamp(timestampNanoSec int64) WithTimestampRevision
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go
new file mode 100644
index 0000000..e4f7fc6
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go
@@ -0,0 +1,166 @@
+package revisions
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/shopspring/decimal"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
// zeroHLC is the zero-valued HLC revision, used in place of NoRevision.
var zeroHLC = HLCRevision{}

// NOTE: This *must* match the length defined in CRDB or the implementation below will break.
const logicalClockLength = 10

// logicalClockOffset is added to every stored logical clock value and
// subtracted back out when rendering (see String/InexactFloat64).
// NOTE(review): math.Pow10(logicalClockLength+1) == 1e11 exceeds the range of
// uint32, and the result of an out-of-range float-to-integer conversion is
// implementation-defined per the Go spec — confirm the intended offset value.
var logicalClockOffset = uint32(math.Pow10(logicalClockLength + 1))

// HLCRevision is a revision that is a hybrid logical clock, stored as two integers.
// The first integer is the timestamp in nanoseconds, and the second integer is the
// logical clock defined as 11 digits, with the first digit being ignored to ensure
// precision of the given logical clock.
type HLCRevision struct {
	time int64
	logicalclock uint32
}
+
// parseHLCRevisionString parses a string into a hybrid logical clock revision.
// A plain integer is treated as a timestamp with a zero logical clock; a
// "timestamp.logicalclock" pair has its logical clock right-padded to
// logicalClockLength digits before parsing.
func parseHLCRevisionString(revisionStr string) (datastore.Revision, error) {
	pieces := strings.Split(revisionStr, ".")
	if len(pieces) == 1 {
		// If there is no decimal point, assume the revision is a timestamp.
		timestamp, err := strconv.ParseInt(pieces[0], 10, 64)
		if err != nil {
			return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr)
		}
		return HLCRevision{timestamp, logicalClockOffset}, nil
	}

	if len(pieces) != 2 {
		return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr)
	}

	timestamp, err := strconv.ParseInt(pieces[0], 10, 64)
	if err != nil {
		return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr)
	}

	// The logical clock must fit in the fixed digit budget shared with CRDB.
	if len(pieces[1]) > logicalClockLength {
		return datastore.NoRevision, spiceerrors.MustBugf("invalid revision string due to unexpected logical clock size (%d): %q", len(pieces[1]), revisionStr)
	}

	// Right-pad so that e.g. ".5" and ".5000000000" compare consistently.
	paddedLogicalClockStr := pieces[1] + strings.Repeat("0", logicalClockLength-len(pieces[1]))
	logicalclock, err := strconv.ParseUint(paddedLogicalClockStr, 10, 64)
	if err != nil {
		return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr)
	}

	if logicalclock > math.MaxUint32 {
		return datastore.NoRevision, spiceerrors.MustBugf("received logical lock that exceeds MaxUint32 (%d > %d): revision %q", logicalclock, math.MaxUint32, revisionStr)
	}

	uintLogicalClock, err := safecast.ToUint32(logicalclock)
	if err != nil {
		return datastore.NoRevision, spiceerrors.MustBugf("could not cast logicalclock to uint32: %v", err)
	}

	// Stored with the offset applied; String() reverses this.
	return HLCRevision{timestamp, uintLogicalClock + logicalClockOffset}, nil
}
+
// HLCRevisionFromString parses a string into a hybrid logical clock revision.
func HLCRevisionFromString(revisionStr string) (HLCRevision, error) {
	rev, err := parseHLCRevisionString(revisionStr)
	if err != nil {
		return zeroHLC, err
	}

	// parseHLCRevisionString always returns an HLCRevision on success.
	return rev.(HLCRevision), nil
}

// NewForHLC creates a new revision for the given hybrid logical clock,
// parsing it from the decimal's string form.
func NewForHLC(decimal decimal.Decimal) (HLCRevision, error) {
	rev, err := HLCRevisionFromString(decimal.String())
	if err != nil {
		return zeroHLC, fmt.Errorf("invalid HLC decimal: %v (%s) => %w", decimal, decimal.String(), err)
	}

	return rev, nil
}
+
// NewHLCForTime creates a new revision for the given time, with a zero
// logical clock component.
func NewHLCForTime(time time.Time) HLCRevision {
	return HLCRevision{time.UnixNano(), logicalClockOffset}
}

// ByteSortable reports that HLC revisions sort correctly as bytes.
func (hlc HLCRevision) ByteSortable() bool {
	return true
}
+
// Equal reports whether both the timestamp and logical clock match. A
// NoRevision argument is treated as the zero HLC.
func (hlc HLCRevision) Equal(rhs datastore.Revision) bool {
	if rhs == datastore.NoRevision {
		rhs = zeroHLC
	}

	rhsHLC := rhs.(HLCRevision)
	return hlc.time == rhsHLC.time && hlc.logicalclock == rhsHLC.logicalclock
}

// GreaterThan orders by timestamp first, then logical clock. A NoRevision
// argument is treated as the zero HLC.
func (hlc HLCRevision) GreaterThan(rhs datastore.Revision) bool {
	if rhs == datastore.NoRevision {
		rhs = zeroHLC
	}

	rhsHLC := rhs.(HLCRevision)
	return hlc.time > rhsHLC.time || (hlc.time == rhsHLC.time && hlc.logicalclock > rhsHLC.logicalclock)
}

// LessThan orders by timestamp first, then logical clock. A NoRevision
// argument is treated as the zero HLC.
func (hlc HLCRevision) LessThan(rhs datastore.Revision) bool {
	if rhs == datastore.NoRevision {
		rhs = zeroHLC
	}

	rhsHLC := rhs.(HLCRevision)
	return hlc.time < rhsHLC.time || (hlc.time == rhsHLC.time && hlc.logicalclock < rhsHLC.logicalclock)
}
+
// String renders the revision as "timestamp.logicalclock", removing the
// stored offset and zero-padding the logical clock to its fixed width.
func (hlc HLCRevision) String() string {
	logicalClockString := strconv.FormatInt(int64(hlc.logicalclock)-int64(logicalClockOffset), 10)
	return strconv.FormatInt(hlc.time, 10) + "." + strings.Repeat("0", logicalClockLength-len(logicalClockString)) + logicalClockString
}

// TimestampNanoSec returns the timestamp portion in nanoseconds.
func (hlc HLCRevision) TimestampNanoSec() int64 {
	return hlc.time
}

// InexactFloat64 returns the revision as timestamp plus the logical clock as
// a fractional part; precision loss is expected at this magnitude.
func (hlc HLCRevision) InexactFloat64() float64 {
	return float64(hlc.time) + float64(hlc.logicalclock-logicalClockOffset)/math.Pow10(logicalClockLength)
}

// ConstructForTimestamp builds a revision for the given timestamp with a
// zero logical clock.
func (hlc HLCRevision) ConstructForTimestamp(timestamp int64) WithTimestampRevision {
	return HLCRevision{timestamp, logicalClockOffset}
}

// AsDecimal converts the revision to its decimal form via its string
// representation.
func (hlc HLCRevision) AsDecimal() (decimal.Decimal, error) {
	return decimal.NewFromString(hlc.String())
}
+
+var (
+ _ datastore.Revision = HLCRevision{}
+ _ WithTimestampRevision = HLCRevision{}
+)
+
+// HLCKeyFunc is used to convert a simple HLC for use in maps.
+func HLCKeyFunc(r HLCRevision) HLCRevision {
+ return r
+}
+
+// HLCKeyLessThanFunc is used to compare keys created by the HLCKeyFunc.
+func HLCKeyLessThanFunc(lhs, rhs HLCRevision) bool {
+ return lhs.time < rhs.time || (lhs.time == rhs.time && lhs.logicalclock < rhs.logicalclock)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go
new file mode 100644
index 0000000..3a5a919
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go
@@ -0,0 +1,118 @@
+package revisions
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "sync"
+ "time"
+
+ "github.com/benbjohnson/clock"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/trace"
+ "golang.org/x/sync/singleflight"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
+// tracer is the package-level OpenTelemetry tracer used for revision spans.
+var tracer = otel.Tracer("spicedb/internal/datastore/common/revisions")
+
+// OptimizedRevisionFunction instructs the datastore to compute its own current
+// optimized revision given the specific quantization, and return for how long
+// it will remain valid. validFor bounds how long the returned revision may be
+// cached before a new one must be computed.
+type OptimizedRevisionFunction func(context.Context) (rev datastore.Revision, validFor time.Duration, err error)
+
+// NewCachedOptimizedRevisions returns a CachedOptimizedRevisions for the given
+// configuration. Callers must invoke SetOptimizedRevisionFunc before using the
+// returned value, since no revision function is installed here.
+func NewCachedOptimizedRevisions(maxRevisionStaleness time.Duration) *CachedOptimizedRevisions {
+	return &CachedOptimizedRevisions{
+		maxRevisionStaleness: maxRevisionStaleness,
+		clockFn:              clock.New(),
+	}
+}
+
+// SetOptimizedRevisionFunc must be called after construction, and is the method
+// by which one specializes this helper for a specific datastore. It is not
+// guarded by the mutex, so it should be called before concurrent use begins.
+func (cor *CachedOptimizedRevisions) SetOptimizedRevisionFunc(revisionFunc OptimizedRevisionFunction) {
+	cor.optimizedFunc = revisionFunc
+}
+
+// OptimizedRevision returns a revision suitable for reads, serving a cached
+// candidate when one is still fresh enough (within a randomized staleness
+// window) and otherwise computing a new one via optimizedFunc. Concurrent
+// cache misses are collapsed into a single datastore call by the singleflight
+// group.
+func (cor *CachedOptimizedRevisions) OptimizedRevision(ctx context.Context) (datastore.Revision, error) {
+	span := trace.SpanFromContext(ctx)
+	localNow := cor.clockFn.Now()
+
+	// Subtract a random amount of time from now, to let barely expired candidates get selected
+	adjustedNow := localNow
+	if cor.maxRevisionStaleness > 0 {
+		// nolint:gosec
+		// G404 use of non cryptographically secure random number generator is not a security concern here,
+		// as we are using it to introduce randomness to the accepted staleness of a revision and reduce the odds of
+		// a thundering herd to the datastore
+		adjustedNow = localNow.Add(-1 * time.Duration(rand.Int63n(cor.maxRevisionStaleness.Nanoseconds())) * time.Nanosecond)
+	}
+
+	// Fast path: under the read lock, return the first candidate still valid
+	// at the (staleness-adjusted) current time.
+	cor.RLock()
+	for _, candidate := range cor.candidates {
+		if candidate.validThrough.After(adjustedNow) {
+			cor.RUnlock()
+			log.Ctx(ctx).Debug().Time("now", localNow).Time("valid", candidate.validThrough).Msg("returning cached revision")
+			span.AddEvent("returning cached revision")
+			return candidate.revision, nil
+		}
+	}
+	cor.RUnlock()
+
+	// Slow path: compute a fresh revision. The constant "" key means all
+	// concurrent callers share one in-flight computation.
+	newQuantizedRevision, err, _ := cor.updateGroup.Do("", func() (interface{}, error) {
+		log.Ctx(ctx).Debug().Time("now", localNow).Msg("computing new revision")
+		span.AddEvent("computing new revision")
+
+		optimized, validFor, err := cor.optimizedFunc(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("unable to compute optimized revision: %w", err)
+		}
+
+		rvt := localNow.Add(validFor)
+
+		// Prune the candidates that have definitely expired
+		// NOTE(review): the early break assumes candidates are ordered by
+		// increasing validThrough (they are appended in computation order) —
+		// confirm if computation order can ever be non-monotonic.
+		cor.Lock()
+		var numToDrop uint
+		for _, candidate := range cor.candidates {
+			if candidate.validThrough.Add(cor.maxRevisionStaleness).Before(localNow) {
+				numToDrop++
+			} else {
+				break
+			}
+		}
+
+		cor.candidates = cor.candidates[numToDrop:]
+		cor.candidates = append(cor.candidates, validRevision{optimized, rvt})
+		cor.Unlock()
+
+		log.Ctx(ctx).Debug().Time("now", localNow).Time("valid", rvt).Stringer("validFor", validFor).Msg("setting valid through")
+		return optimized, nil
+	})
+	if err != nil {
+		return datastore.NoRevision, err
+	}
+	// err is nil at this point; the assertion is safe because the closure only
+	// ever returns a datastore.Revision on success.
+	return newQuantizedRevision.(datastore.Revision), err
+}
+
+// CachedOptimizedRevisions does caching and deduplication for requests for optimized revisions.
+// Its embedded RWMutex guards the candidates slice; the zero value is not
+// usable — construct via NewCachedOptimizedRevisions.
+type CachedOptimizedRevisions struct {
+	sync.RWMutex
+
+	maxRevisionStaleness time.Duration            // extra window during which an expired candidate may still be served
+	optimizedFunc        OptimizedRevisionFunction // datastore-specific revision computation, set via SetOptimizedRevisionFunc
+	clockFn              clock.Clock               // injectable clock, enabling deterministic tests
+
+	// these values are read and set by multiple consumers
+	candidates []validRevision // GUARDED_BY(RWMutex)
+
+	// the updategroup consolidates concurrent requests to the database into 1
+	updateGroup singleflight.Group
+}
+
+// validRevision pairs a cached revision with the instant through which it may
+// be served.
+type validRevision struct {
+	revision     datastore.Revision
+	validThrough time.Time
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go
new file mode 100644
index 0000000..ef793c8
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go
@@ -0,0 +1,125 @@
+package revisions
+
+import (
+ "context"
+ "time"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+// RemoteNowFunction queries the datastore to get a current revision, i.e. the
+// datastore's own notion of "now".
+type RemoteNowFunction func(context.Context) (datastore.Revision, error)
+
+// RemoteClockRevisions handles revision calculation for datastores that provide
+// their own clocks. It embeds CachedOptimizedRevisions for caching and stores
+// the tuning windows as nanosecond counts to avoid per-call Duration math.
+type RemoteClockRevisions struct {
+	*CachedOptimizedRevisions
+
+	gcWindowNanos          int64             // revisions older than now minus this are considered stale
+	nowFunc                RemoteNowFunction // set via SetNowFunc after construction
+	followerReadDelayNanos int64             // delay subtracted from "now" to favor follower reads
+	quantizationNanos      int64             // bucket size used to round revisions down
+}
+
+// NewRemoteClockRevisions returns a RemoteClockRevisions for the given
+// configuration. Callers must also invoke SetNowFunc before use, since the
+// now-function cannot be supplied here.
+func NewRemoteClockRevisions(gcWindow, maxRevisionStaleness, followerReadDelay, quantization time.Duration) *RemoteClockRevisions {
+	// Ensure the max revision staleness never exceeds the GC window.
+	if maxRevisionStaleness > gcWindow {
+		log.Warn().
+			Dur("maxRevisionStaleness", maxRevisionStaleness).
+			Dur("gcWindow", gcWindow).
+			Msg("the configured maximum revision staleness exceeds the configured gc window, so capping to gcWindow")
+		// Cap strictly below the window so a maximally-stale revision is
+		// never already garbage collected.
+		maxRevisionStaleness = gcWindow - 1
+	}
+
+	revisions := &RemoteClockRevisions{
+		CachedOptimizedRevisions: NewCachedOptimizedRevisions(
+			maxRevisionStaleness,
+		),
+		gcWindowNanos:          gcWindow.Nanoseconds(),
+		followerReadDelayNanos: followerReadDelay.Nanoseconds(),
+		quantizationNanos:      quantization.Nanoseconds(),
+	}
+
+	// Wire this type's quantizing revision function into the embedded cache.
+	revisions.SetOptimizedRevisionFunc(revisions.optimizedRevisionFunc)
+
+	return revisions
+}
+
+// optimizedRevisionFunc asks the datastore for its current revision, applies
+// the follower-read delay, and rounds the result down to the nearest
+// quantization bucket. The returned duration is how long the quantized
+// revision remains current, i.e. the time left in the bucket.
+func (rcr *RemoteClockRevisions) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, error) {
+	nowRev, err := rcr.nowFunc(ctx)
+	if err != nil {
+		return datastore.NoRevision, 0, err
+	}
+
+	if nowRev == datastore.NoRevision {
+		return datastore.NoRevision, 0, datastore.NewInvalidRevisionErr(nowRev, datastore.CouldNotDetermineRevision)
+	}
+
+	// The remote-clock strategy only works with revisions carrying timestamps.
+	nowTS, ok := nowRev.(WithTimestampRevision)
+	if !ok {
+		return datastore.NoRevision, 0, spiceerrors.MustBugf("expected with-timestamp revision, got %T", nowRev)
+	}
+
+	// Shift "now" backwards so reads can be served by followers, then round
+	// down to the start of the current quantization bucket.
+	delayedNow := nowTS.TimestampNanoSec() - rcr.followerReadDelayNanos
+	quantized := delayedNow
+	validForNanos := int64(0)
+	if rcr.quantizationNanos > 0 {
+		afterLastQuantization := delayedNow % rcr.quantizationNanos
+		quantized -= afterLastQuantization
+		validForNanos = rcr.quantizationNanos - afterLastQuantization
+	}
+	log.Ctx(ctx).Debug().
+		Time("quantized", time.Unix(0, quantized)).
+		Int64("readSkew", rcr.followerReadDelayNanos).
+		Int64("totalSkew", nowTS.TimestampNanoSec()-quantized).
+		Msg("revision skews")
+
+	return nowTS.ConstructForTimestamp(quantized), time.Duration(validForNanos) * time.Nanosecond, nil
+}
+
+// SetNowFunc sets the function used to determine the head revision. It must
+// be called before the type is used; it is not guarded by the mutex.
+func (rcr *RemoteClockRevisions) SetNowFunc(nowFunc RemoteNowFunction) {
+	rcr.nowFunc = nowFunc
+}
+
+// CheckRevision validates that dsRevision is usable: it must carry a
+// timestamp, must not be older than the GC window, and must not be in the
+// future relative to the datastore's current clock.
+func (rcr *RemoteClockRevisions) CheckRevision(ctx context.Context, dsRevision datastore.Revision) error {
+	if dsRevision == datastore.NoRevision {
+		return datastore.NewInvalidRevisionErr(dsRevision, datastore.CouldNotDetermineRevision)
+	}
+
+	// Assumed to hold for all revisions produced by this strategy; panics on
+	// any other concrete type.
+	revision := dsRevision.(WithTimestampRevision)
+
+	ctx, span := tracer.Start(ctx, "CheckRevision")
+	defer span.End()
+
+	// Make sure the system time indicated is within the software GC window
+	now, err := rcr.nowFunc(ctx)
+	if err != nil {
+		return err
+	}
+
+	nowTS, ok := now.(WithTimestampRevision)
+	if !ok {
+		return spiceerrors.MustBugf("expected HLC revision, got %T", now)
+	}
+
+	nowNanos := nowTS.TimestampNanoSec()
+	revisionNanos := revision.TimestampNanoSec()
+
+	// Too old: garbage collection may already have removed its data.
+	isStale := revisionNanos < (nowNanos - rcr.gcWindowNanos)
+	if isStale {
+		log.Ctx(ctx).Debug().Stringer("now", now).Stringer("revision", revision).Msg("stale revision")
+		return datastore.NewInvalidRevisionErr(revision, datastore.RevisionStale)
+	}
+
+	// In the future: cannot have been produced by this datastore yet.
+	isUnknown := revisionNanos > nowNanos
+	if isUnknown {
+		log.Ctx(ctx).Debug().Stringer("now", now).Stringer("revision", revision).Msg("unknown revision")
+		return datastore.NewInvalidRevisionErr(revision, datastore.CouldNotDetermineRevision)
+	}
+
+	return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go
new file mode 100644
index 0000000..fc2a250
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go
@@ -0,0 +1,97 @@
+package revisions
+
+import (
+ "fmt"
+ "strconv"
+ "time"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
+// TimestampRevision is a revision that is a timestamp, stored as nanoseconds
+// since the Unix epoch.
+type TimestampRevision int64
+
+// zeroTimestampRevision substitutes for datastore.NoRevision in comparisons.
+var zeroTimestampRevision = TimestampRevision(0)
+
+// NewForTime creates a new revision for the given time, using its
+// nanosecond-precision Unix timestamp.
+func NewForTime(time time.Time) TimestampRevision {
+	return TimestampRevision(time.UnixNano())
+}
+
+// NewForTimestamp creates a new revision for the given timestamp, expressed
+// in nanoseconds since the Unix epoch.
+func NewForTimestamp(timestampNanosec int64) TimestampRevision {
+	return TimestampRevision(timestampNanosec)
+}
+
+// parseTimestampRevisionString parses a string into a timestamp revision,
+// expecting a base-10 signed 64-bit integer of epoch nanoseconds.
+func parseTimestampRevisionString(revisionStr string) (rev datastore.Revision, err error) {
+	parsed, err := strconv.ParseInt(revisionStr, 10, 64)
+	if err != nil {
+		return nil, fmt.Errorf("invalid integer revision: %w", err)
+	}
+
+	return TimestampRevision(parsed), nil
+}
+
+// ByteSortable reports that the serialized form of these revisions sorts in
+// the same order as the revisions themselves.
+func (ir TimestampRevision) ByteSortable() bool {
+	return true
+}
+
+// Equal returns true if rhs refers to the same instant. datastore.NoRevision
+// compares as the zero timestamp; any other non-TimestampRevision type panics
+// on the assertion.
+func (ir TimestampRevision) Equal(rhs datastore.Revision) bool {
+	if rhs == datastore.NoRevision {
+		rhs = zeroTimestampRevision
+	}
+
+	return int64(ir) == int64(rhs.(TimestampRevision))
+}
+
+// GreaterThan returns true if this revision is strictly after rhs.
+// datastore.NoRevision compares as the zero timestamp.
+func (ir TimestampRevision) GreaterThan(rhs datastore.Revision) bool {
+	if rhs == datastore.NoRevision {
+		rhs = zeroTimestampRevision
+	}
+
+	return int64(ir) > int64(rhs.(TimestampRevision))
+}
+
+// LessThan returns true if this revision is strictly before rhs.
+// datastore.NoRevision compares as the zero timestamp.
+func (ir TimestampRevision) LessThan(rhs datastore.Revision) bool {
+	if rhs == datastore.NoRevision {
+		rhs = zeroTimestampRevision
+	}
+
+	return int64(ir) < int64(rhs.(TimestampRevision))
+}
+
+// TimestampNanoSec returns the revision as nanoseconds since the Unix epoch.
+func (ir TimestampRevision) TimestampNanoSec() int64 {
+	return int64(ir)
+}
+
+// String renders the revision as its base-10 nanosecond timestamp.
+func (ir TimestampRevision) String() string {
+	return strconv.FormatInt(int64(ir), 10)
+}
+
+// Time converts the revision back into a time.Time.
+func (ir TimestampRevision) Time() time.Time {
+	return time.Unix(0, int64(ir))
+}
+
+// WithInexactFloat64 returns the timestamp as a float64; lossy for values
+// beyond float64's 53-bit integer range, hence "inexact".
+func (ir TimestampRevision) WithInexactFloat64() float64 {
+	return float64(ir)
+}
+
+// ConstructForTimestamp builds a new revision directly from the given
+// nanosecond timestamp.
+func (ir TimestampRevision) ConstructForTimestamp(timestamp int64) WithTimestampRevision {
+	return TimestampRevision(timestamp)
+}
+
+// Compile-time assertions that TimestampRevision implements the revision interfaces.
+var (
+	_ datastore.Revision    = TimestampRevision(0)
+	_ WithTimestampRevision = TimestampRevision(0)
+)
+
+// TimestampIDKeyFunc is used to create keys for timestamps, converting a
+// revision into its raw int64 form for use in maps.
+func TimestampIDKeyFunc(r TimestampRevision) int64 {
+	return int64(r)
+}
+
+// TimestampIDKeyLessThanFunc is used to compare keys created by
+// TimestampIDKeyFunc.
+func TimestampIDKeyLessThanFunc(l, r int64) bool {
+	return l < r
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go
new file mode 100644
index 0000000..31d837f
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go
@@ -0,0 +1,80 @@
+package revisions
+
+import (
+ "fmt"
+ "strconv"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
+// TransactionIDRevision is a revision that is a transaction ID, an unsigned
+// 64-bit counter with no timestamp component.
+type TransactionIDRevision uint64
+
+// zeroTransactionIDRevision substitutes for datastore.NoRevision in comparisons.
+var zeroTransactionIDRevision = TransactionIDRevision(0)
+
+// NewForTransactionID creates a new revision for the given transaction ID.
+func NewForTransactionID(transactionID uint64) TransactionIDRevision {
+	return TransactionIDRevision(transactionID)
+}
+
+// parseTransactionIDRevisionString parses a string into a transaction ID
+// revision, expecting a base-10 unsigned 64-bit integer.
+func parseTransactionIDRevisionString(revisionStr string) (rev datastore.Revision, err error) {
+	parsed, err := strconv.ParseUint(revisionStr, 10, 64)
+	if err != nil {
+		return nil, fmt.Errorf("invalid integer revision: %w", err)
+	}
+
+	return TransactionIDRevision(parsed), nil
+}
+
+// ByteSortable reports that the serialized form of these revisions sorts in
+// the same order as the revisions themselves.
+func (ir TransactionIDRevision) ByteSortable() bool {
+	return true
+}
+
+// Equal returns true if rhs is the same transaction ID. datastore.NoRevision
+// compares as transaction ID zero; any other non-TransactionIDRevision type
+// panics on the assertion.
+func (ir TransactionIDRevision) Equal(rhs datastore.Revision) bool {
+	if rhs == datastore.NoRevision {
+		rhs = zeroTransactionIDRevision
+	}
+
+	return uint64(ir) == uint64(rhs.(TransactionIDRevision))
+}
+
+// GreaterThan returns true if this revision's transaction ID is strictly
+// greater than rhs's. datastore.NoRevision compares as transaction ID zero.
+func (ir TransactionIDRevision) GreaterThan(rhs datastore.Revision) bool {
+	if rhs == datastore.NoRevision {
+		rhs = zeroTransactionIDRevision
+	}
+
+	return uint64(ir) > uint64(rhs.(TransactionIDRevision))
+}
+
+// LessThan returns true if this revision's transaction ID is strictly less
+// than rhs's. datastore.NoRevision compares as transaction ID zero.
+func (ir TransactionIDRevision) LessThan(rhs datastore.Revision) bool {
+	if rhs == datastore.NoRevision {
+		rhs = zeroTransactionIDRevision
+	}
+
+	return uint64(ir) < uint64(rhs.(TransactionIDRevision))
+}
+
+// TransactionID returns the underlying transaction ID.
+func (ir TransactionIDRevision) TransactionID() uint64 {
+	return uint64(ir)
+}
+
+// String renders the revision as its base-10 transaction ID.
+func (ir TransactionIDRevision) String() string {
+	return strconv.FormatUint(uint64(ir), 10)
+}
+
+// WithInexactFloat64 returns the transaction ID as a float64; lossy for IDs
+// beyond float64's 53-bit integer range, hence "inexact".
+func (ir TransactionIDRevision) WithInexactFloat64() float64 {
+	return float64(ir)
+}
+
+// Compile-time assertion that TransactionIDRevision implements datastore.Revision.
+var _ datastore.Revision = TransactionIDRevision(0)
+
+// TransactionIDKeyFunc is used to create keys for transaction IDs, converting
+// a revision into its raw uint64 form for use in maps.
+func TransactionIDKeyFunc(r TransactionIDRevision) uint64 {
+	return uint64(r)
+}
+
+// TransactionIDKeyLessThanFunc is used to compare keys created by
+// TransactionIDKeyFunc.
+func TransactionIDKeyLessThanFunc(l, r uint64) bool {
+	return l < r
+}