diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/spicedb/internal/datastore | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/spicedb/internal/datastore')
30 files changed, 5414 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go new file mode 100644 index 0000000..291abb5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go @@ -0,0 +1,352 @@ +package common + +import ( + "context" + "sort" + + "golang.org/x/exp/maps" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/ccoveille/go-safecast" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +const ( + nsPrefix = "n$" + caveatPrefix = "c$" +) + +// Changes represents a set of datastore mutations that are kept self-consistent +// across one or more transaction revisions. +type Changes[R datastore.Revision, K comparable] struct { + records map[K]changeRecord[R] + keyFunc func(R) K + content datastore.WatchContent + maxByteSize uint64 + currentByteSize int64 +} + +type changeRecord[R datastore.Revision] struct { + rev R + relTouches map[string]tuple.Relationship + relDeletes map[string]tuple.Relationship + definitionsChanged map[string]datastore.SchemaDefinition + namespacesDeleted map[string]struct{} + caveatsDeleted map[string]struct{} + metadata map[string]any +} + +// NewChanges creates a new Changes object for change tracking and de-duplication. +func NewChanges[R datastore.Revision, K comparable](keyFunc func(R) K, content datastore.WatchContent, maxByteSize uint64) *Changes[R, K] { + return &Changes[R, K]{ + records: make(map[K]changeRecord[R], 0), + keyFunc: keyFunc, + content: content, + maxByteSize: maxByteSize, + currentByteSize: 0, + } +} + +// IsEmpty returns if the change set is empty. +func (ch *Changes[R, K]) IsEmpty() bool { + return len(ch.records) == 0 +} + +// AddRelationshipChange adds a specific change to the complete list of tracked changes +func (ch *Changes[R, K]) AddRelationshipChange( + ctx context.Context, + rev R, + rel tuple.Relationship, + op tuple.UpdateOperation, +) error { + if ch.content&datastore.WatchRelationships != datastore.WatchRelationships { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + key := tuple.StringWithoutCaveatOrExpiration(rel) + + switch op { + case tuple.UpdateOperationTouch: + // If there was a delete for the same tuple at the same revision, drop it + existing, ok := record.relDeletes[key] + if ok { + delete(record.relDeletes, key) + if err := ch.adjustByteSize(existing, -1); err != nil { + return err + } + } + + record.relTouches[key] = rel + if err := ch.adjustByteSize(rel, 1); err != nil { + return err + } + + case tuple.UpdateOperationDelete: + _, alreadyTouched := record.relTouches[key] + if !alreadyTouched { + record.relDeletes[key] = rel + if err := ch.adjustByteSize(rel, 1); err != nil { + return err + } + } + + default: + return spiceerrors.MustBugf("unknown change operation") + } + + return nil +} + +type sized interface { + SizeVT() int +} + +func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error { + if ch.maxByteSize == 0 { + return nil + } + + size := item.SizeVT() * delta + ch.currentByteSize += int64(size) + if ch.currentByteSize < 0 { + return spiceerrors.MustBugf("byte size underflow") + } + + currentByteSize, err := safecast.ToUint64(ch.currentByteSize) + if err != nil { + return spiceerrors.MustBugf("could not cast currentByteSize to uint64: %v", err) + } + + if currentByteSize > ch.maxByteSize { + return datastore.NewMaximumChangesSizeExceededError(ch.maxByteSize) + } + + return nil +} + +// SetRevisionMetadata sets the metadata for the given revision. +func (ch *Changes[R, K]) SetRevisionMetadata(ctx context.Context, rev R, metadata map[string]any) error { + if len(metadata) == 0 { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + if len(record.metadata) > 0 { + return spiceerrors.MustBugf("metadata already set for revision") + } + + maps.Copy(record.metadata, metadata) + return nil +} + +func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) { + k := ch.keyFunc(rev) + revisionChanges, ok := ch.records[k] + if !ok { + revisionChanges = changeRecord[R]{ + rev, + make(map[string]tuple.Relationship), + make(map[string]tuple.Relationship), + make(map[string]datastore.SchemaDefinition), + make(map[string]struct{}), + make(map[string]struct{}), + make(map[string]any), + } + ch.records[k] = revisionChanges + } + + return revisionChanges, nil +} + +// AddDeletedNamespace adds a change indicating that the namespace with the name was deleted. +func (ch *Changes[R, K]) AddDeletedNamespace( + _ context.Context, + rev R, + namespaceName string, +) error { + if ch.content&datastore.WatchSchema != datastore.WatchSchema { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + // if a delete happens in the same transaction as a change, we assume it was a change in the first place + // because that's how namespace changes are implemented in the MVCC + if _, ok := record.definitionsChanged[nsPrefix+namespaceName]; ok { + return nil + } + + delete(record.definitionsChanged, nsPrefix+namespaceName) + record.namespacesDeleted[namespaceName] = struct{}{} + return nil +} + +// AddDeletedCaveat adds a change indicating that the caveat with the name was deleted. +func (ch *Changes[R, K]) AddDeletedCaveat( + _ context.Context, + rev R, + caveatName string, +) error { + if ch.content&datastore.WatchSchema != datastore.WatchSchema { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + // if a delete happens in the same transaction as a change, we assume it was a change in the first place + // because that's how namespace changes are implemented in the MVCC + if _, ok := record.definitionsChanged[caveatPrefix+caveatName]; ok { + return nil + } + + delete(record.definitionsChanged, caveatPrefix+caveatName) + record.caveatsDeleted[caveatName] = struct{}{} + return nil +} + +// AddChangedDefinition adds a change indicating that the schema definition (namespace or caveat) +// was changed to the definition given. +func (ch *Changes[R, K]) AddChangedDefinition( + ctx context.Context, + rev R, + def datastore.SchemaDefinition, +) error { + if ch.content&datastore.WatchSchema != datastore.WatchSchema { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + switch t := def.(type) { + case *core.NamespaceDefinition: + delete(record.namespacesDeleted, t.Name) + + if existing, ok := record.definitionsChanged[nsPrefix+t.Name]; ok { + if err := ch.adjustByteSize(existing, -1); err != nil { + return err + } + } + + record.definitionsChanged[nsPrefix+t.Name] = t + + if err := ch.adjustByteSize(t, 1); err != nil { + return err + } + + case *core.CaveatDefinition: + delete(record.caveatsDeleted, t.Name) + + if existing, ok := record.definitionsChanged[nsPrefix+t.Name]; ok { + if err := ch.adjustByteSize(existing, -1); err != nil { + return err + } + } + + record.definitionsChanged[caveatPrefix+t.Name] = t + + if err := ch.adjustByteSize(t, 1); err != nil { + return err + } + + default: + log.Ctx(ctx).Fatal().Msg("unknown schema definition kind") + } + + return nil +} + +// AsRevisionChanges returns the list of changes processed so far as a datastore watch +// compatible, ordered, changelist. +func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) ([]datastore.RevisionChanges, error) { + return ch.revisionChanges(lessThanFunc, *new(R), false) +} + +// FilterAndRemoveRevisionChanges filters a list of changes processed up to the bound revision from the changes list, removing them +// and returning the filtered changes. +func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) ([]datastore.RevisionChanges, error) { + changes, err := ch.revisionChanges(lessThanFunc, boundRev, true) + if err != nil { + return nil, err + } + + ch.removeAllChangesBefore(boundRev) + return changes, nil +} + +func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) ([]datastore.RevisionChanges, error) { + if ch.IsEmpty() { + return nil, nil + } + + revisionsWithChanges := make([]K, 0, len(ch.records)) + for rk, cr := range ch.records { + if !withBound || boundRev.GreaterThan(cr.rev) { + revisionsWithChanges = append(revisionsWithChanges, rk) + } + } + + if len(revisionsWithChanges) == 0 { + return nil, nil + } + + sort.Slice(revisionsWithChanges, func(i int, j int) bool { + return lessThanFunc(revisionsWithChanges[i], revisionsWithChanges[j]) + }) + + changes := make([]datastore.RevisionChanges, len(revisionsWithChanges)) + for i, k := range revisionsWithChanges { + revisionChangeRecord := ch.records[k] + changes[i].Revision = revisionChangeRecord.rev + for _, rel := range revisionChangeRecord.relTouches { + changes[i].RelationshipChanges = append(changes[i].RelationshipChanges, tuple.Touch(rel)) + } + for _, rel := range revisionChangeRecord.relDeletes { + changes[i].RelationshipChanges = append(changes[i].RelationshipChanges, tuple.Delete(rel)) + } + changes[i].ChangedDefinitions = maps.Values(revisionChangeRecord.definitionsChanged) + changes[i].DeletedNamespaces = maps.Keys(revisionChangeRecord.namespacesDeleted) + changes[i].DeletedCaveats = maps.Keys(revisionChangeRecord.caveatsDeleted) + + if len(revisionChangeRecord.metadata) > 0 { + metadata, err := structpb.NewStruct(revisionChangeRecord.metadata) + if err != nil { + return nil, spiceerrors.MustBugf("failed to convert metadata to structpb: %v", err) + } + + changes[i].Metadata = metadata + } + } + + return changes, nil +} + +func (ch *Changes[R, K]) removeAllChangesBefore(boundRev R) { + for rk, cr := range ch.records { + if boundRev.GreaterThan(cr.rev) { + delete(ch.records, rk) + } + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go new file mode 100644 index 0000000..af0b229 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go @@ -0,0 +1,154 @@ +package common + +import ( + "context" + "errors" + "fmt" + "regexp" + "strings" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// SerializationError is returned when there's been a serialization +// error while performing a datastore operation +type SerializationError struct { + error +} + +func (err SerializationError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.Aborted, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_SERIALIZATION_FAILURE, + map[string]string{}, + ), + ) +} + +func (err SerializationError) Unwrap() error { + return err.error +} + +// NewSerializationError creates a new SerializationError +func NewSerializationError(err error) error { + return SerializationError{err} +} + +// ReadOnlyTransactionError is returned when an otherwise read-write +// transaction fails on writes with an error indicating that the datastore +// is currently in a read-only mode. +type ReadOnlyTransactionError struct { + error +} + +func (err ReadOnlyTransactionError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.Aborted, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_SERVICE_READ_ONLY, + map[string]string{}, + ), + ) +} + +// NewReadOnlyTransactionError creates a new ReadOnlyTransactionError. +func NewReadOnlyTransactionError(err error) error { + return ReadOnlyTransactionError{ + fmt.Errorf("could not perform write operation, as the datastore is currently in read-only mode: %w. This may indicate that the datastore has been put into maintenance mode", err), + } +} + +// CreateRelationshipExistsError is an error returned when attempting to CREATE an already-existing +// relationship. +type CreateRelationshipExistsError struct { + error + + // Relationship is the relationship that caused the error. May be nil, depending on the datastore. + Relationship *tuple.Relationship +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err CreateRelationshipExistsError) GRPCStatus() *status.Status { + if err.Relationship == nil { + return spiceerrors.WithCodeAndDetails( + err, + codes.AlreadyExists, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_ATTEMPT_TO_RECREATE_RELATIONSHIP, + map[string]string{}, + ), + ) + } + + relationship := tuple.ToV1Relationship(*err.Relationship) + return spiceerrors.WithCodeAndDetails( + err, + codes.AlreadyExists, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_ATTEMPT_TO_RECREATE_RELATIONSHIP, + map[string]string{ + "relationship": tuple.V1StringRelationshipWithoutCaveatOrExpiration(relationship), + "resource_type": relationship.Resource.ObjectType, + "resource_object_id": relationship.Resource.ObjectId, + "resource_relation": relationship.Relation, + "subject_type": relationship.Subject.Object.ObjectType, + "subject_object_id": relationship.Subject.Object.ObjectId, + "subject_relation": relationship.Subject.OptionalRelation, + }, + ), + ) +} + +// NewCreateRelationshipExistsError creates a new CreateRelationshipExistsError. +func NewCreateRelationshipExistsError(relationship *tuple.Relationship) error { + msg := "could not CREATE one or more relationships, as they already existed. If this is persistent, please switch to TOUCH operations or specify a precondition" + if relationship != nil { + msg = fmt.Sprintf("could not CREATE relationship `%s`, as it already existed. If this is persistent, please switch to TOUCH operations or specify a precondition", tuple.StringWithoutCaveatOrExpiration(*relationship)) + } + + return CreateRelationshipExistsError{ + errors.New(msg), + relationship, + } +} + +var ( + portMatchRegex = regexp.MustCompile("invalid port \\\"(.+)\\\" after host") + parseMatchRegex = regexp.MustCompile("parse \\\"(.+)\\\":") +) + +// RedactAndLogSensitiveConnString elides the given error, logging it only at trace +// level (after being redacted). +func RedactAndLogSensitiveConnString(ctx context.Context, baseErr string, err error, pgURL string) error { + if err == nil { + return errors.New(baseErr) + } + + // See: https://github.com/jackc/pgx/issues/1271 + filtered := err.Error() + filtered = strings.ReplaceAll(filtered, pgURL, "(redacted)") + filtered = portMatchRegex.ReplaceAllString(filtered, "(redacted)") + filtered = parseMatchRegex.ReplaceAllString(filtered, "(redacted)") + log.Ctx(ctx).Trace().Msg(baseErr + ": " + filtered) + return fmt.Errorf("%s. To view details of this error (that may contain sensitive information), please run with --log-level=trace", baseErr) +} + +// RevisionUnavailableError is returned when a revision is not available on a replica. +type RevisionUnavailableError struct { + error +} + +// NewRevisionUnavailableError creates a new RevisionUnavailableError. +func NewRevisionUnavailableError(err error) error { + return RevisionUnavailableError{err} +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go new file mode 100644 index 0000000..5788134 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go @@ -0,0 +1,269 @@ +package common + +import ( + "context" + "fmt" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/prometheus/client_golang/prometheus" + "github.com/rs/zerolog" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" +) + +var ( + gcDurationHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_duration_seconds", + Help: "The duration of datastore garbage collection.", + Buckets: []float64{0.01, 0.1, 0.5, 1, 5, 10, 25, 60, 120}, + }) + + gcRelationshipsCounter = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_relationships_total", + Help: "The number of stale relationships deleted by the datastore garbage collection.", + }) + + gcExpiredRelationshipsCounter = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_expired_relationships_total", + Help: "The number of expired relationships deleted by the datastore garbage collection.", + }) + + gcTransactionsCounter = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_transactions_total", + Help: "The number of stale transactions deleted by the datastore garbage collection.", + }) + + gcNamespacesCounter = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_namespaces_total", + Help: "The number of stale namespaces deleted by the datastore garbage collection.", + }) + + gcFailureCounterConfig = prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_failure_total", + Help: "The number of failed runs of the datastore garbage collection.", + } + gcFailureCounter = prometheus.NewCounter(gcFailureCounterConfig) +) + +// RegisterGCMetrics registers garbage collection metrics to the default +// registry. +func RegisterGCMetrics() error { + for _, metric := range []prometheus.Collector{ + gcDurationHistogram, + gcRelationshipsCounter, + gcTransactionsCounter, + gcNamespacesCounter, + gcFailureCounter, + } { + if err := prometheus.Register(metric); err != nil { + return err + } + } + + return nil +} + +// GarbageCollector represents any datastore that supports external garbage +// collection. +type GarbageCollector interface { + // HasGCRun returns true if a garbage collection run has been completed. + HasGCRun() bool + + // MarkGCCompleted marks that a garbage collection run has been completed. + MarkGCCompleted() + + // ResetGCCompleted resets the state of the garbage collection run. + ResetGCCompleted() + + // LockForGCRun attempts to acquire a lock for garbage collection. This lock + // is typically done at the datastore level, to ensure that no other nodes are + // running garbage collection at the same time. + LockForGCRun(ctx context.Context) (bool, error) + + // UnlockAfterGCRun releases the lock after a garbage collection run. + // NOTE: this method does not take a context, as the context used for the + // reset of the GC run can be canceled/timed out and the unlock will still need to happen. + UnlockAfterGCRun() error + + // ReadyState returns the current state of the datastore. + ReadyState(context.Context) (datastore.ReadyState, error) + + // Now returns the current time from the datastore. + Now(context.Context) (time.Time, error) + + // TxIDBefore returns the highest transaction ID before the provided time. + TxIDBefore(context.Context, time.Time) (datastore.Revision, error) + + // DeleteBeforeTx deletes all data before the provided transaction ID. + DeleteBeforeTx(ctx context.Context, txID datastore.Revision) (DeletionCounts, error) + + // DeleteExpiredRels deletes all relationships that have expired. + DeleteExpiredRels(ctx context.Context) (int64, error) +} + +// DeletionCounts tracks the amount of deletions that occurred when calling +// DeleteBeforeTx. +type DeletionCounts struct { + Relationships int64 + Transactions int64 + Namespaces int64 +} + +func (g DeletionCounts) MarshalZerologObject(e *zerolog.Event) { + e. + Int64("relationships", g.Relationships). + Int64("transactions", g.Transactions). + Int64("namespaces", g.Namespaces) +} + +var MaxGCInterval = 60 * time.Minute + +// StartGarbageCollector loops forever until the context is canceled and +// performs garbage collection on the provided interval. +func StartGarbageCollector(ctx context.Context, gc GarbageCollector, interval, window, timeout time.Duration) error { + return startGarbageCollectorWithMaxElapsedTime(ctx, gc, interval, window, 0, timeout, gcFailureCounter) +} + +func startGarbageCollectorWithMaxElapsedTime(ctx context.Context, gc GarbageCollector, interval, window, maxElapsedTime, timeout time.Duration, failureCounter prometheus.Counter) error { + backoffInterval := backoff.NewExponentialBackOff() + backoffInterval.InitialInterval = interval + backoffInterval.MaxInterval = max(MaxGCInterval, interval) + backoffInterval.MaxElapsedTime = maxElapsedTime + backoffInterval.Reset() + + nextInterval := interval + + log.Ctx(ctx).Info(). + Dur("interval", nextInterval). + Msg("datastore garbage collection worker started") + + for { + select { + case <-ctx.Done(): + log.Ctx(ctx).Info(). + Msg("shutting down datastore garbage collection worker") + return ctx.Err() + + case <-time.After(nextInterval): + log.Ctx(ctx).Info(). + Dur("interval", nextInterval). + Dur("window", window). + Dur("timeout", timeout). + Msg("running garbage collection worker") + + err := RunGarbageCollection(gc, window, timeout) + if err != nil { + failureCounter.Inc() + nextInterval = backoffInterval.NextBackOff() + log.Ctx(ctx).Warn().Err(err). + Dur("next-attempt-in", nextInterval). + Msg("error attempting to perform garbage collection") + continue + } + + backoffInterval.Reset() + nextInterval = interval + + log.Ctx(ctx).Debug(). + Dur("next-run-in", interval). + Msg("datastore garbage collection scheduled for next run") + } + } +} + +// RunGarbageCollection runs garbage collection for the datastore. +func RunGarbageCollection(gc GarbageCollector, window, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + ctx, span := tracer.Start(ctx, "RunGarbageCollection") + defer span.End() + + // Before attempting anything, check if the datastore is ready. + startTime := time.Now() + ready, err := gc.ReadyState(ctx) + if err != nil { + return err + } + if !ready.IsReady { + log.Ctx(ctx).Warn(). + Msgf("datastore wasn't ready when attempting garbage collection: %s", ready.Message) + return nil + } + + ok, err := gc.LockForGCRun(ctx) + if err != nil { + return fmt.Errorf("error locking for gc run: %w", err) + } + + if !ok { + log.Info(). + Msg("datastore garbage collection already in progress on another node") + return nil + } + + defer func() { + err := gc.UnlockAfterGCRun() + if err != nil { + log.Error(). + Err(err). + Msg("error unlocking after gc run") + } + }() + + now, err := gc.Now(ctx) + if err != nil { + return fmt.Errorf("error retrieving now: %w", err) + } + + watermark, err := gc.TxIDBefore(ctx, now.Add(-1*window)) + if err != nil { + return fmt.Errorf("error retrieving watermark: %w", err) + } + + collected, err := gc.DeleteBeforeTx(ctx, watermark) + + expiredRelationshipsCount, eerr := gc.DeleteExpiredRels(ctx) + + // even if an error happened, garbage would have been collected. This makes sure these are reflected even if the + // worker eventually fails or times out. + gcRelationshipsCounter.Add(float64(collected.Relationships)) + gcTransactionsCounter.Add(float64(collected.Transactions)) + gcNamespacesCounter.Add(float64(collected.Namespaces)) + gcExpiredRelationshipsCounter.Add(float64(expiredRelationshipsCount)) + collectionDuration := time.Since(startTime) + gcDurationHistogram.Observe(collectionDuration.Seconds()) + + if err != nil { + return fmt.Errorf("error deleting in gc: %w", err) + } + + if eerr != nil { + return fmt.Errorf("error deleting expired relationships in gc: %w", eerr) + } + + log.Ctx(ctx).Info(). + Stringer("highestTxID", watermark). + Dur("duration", collectionDuration). + Time("nowTime", now). + Interface("collected", collected). + Int64("expiredRelationships", expiredRelationshipsCount). + Msg("datastore garbage collection completed successfully") + + gc.MarkGCCompleted() + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go new file mode 100644 index 0000000..8f34134 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go @@ -0,0 +1,49 @@ +package common + +import ( + "context" + "fmt" + + "google.golang.org/protobuf/types/known/structpb" + + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +// WriteRelationships is a convenience method to perform the same update operation on a set of relationships +func WriteRelationships(ctx context.Context, ds datastore.Datastore, op tuple.UpdateOperation, rels ...tuple.Relationship) (datastore.Revision, error) { + updates := make([]tuple.RelationshipUpdate, 0, len(rels)) + for _, rel := range rels { + ru := tuple.RelationshipUpdate{ + Operation: op, + Relationship: rel, + } + updates = append(updates, ru) + } + return UpdateRelationshipsInDatastore(ctx, ds, updates...) +} + +// UpdateRelationshipsInDatastore is a convenience method to perform multiple relation update operations on a Datastore +func UpdateRelationshipsInDatastore(ctx context.Context, ds datastore.Datastore, updates ...tuple.RelationshipUpdate) (datastore.Revision, error) { + return ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + return rwt.WriteRelationships(ctx, updates) + }) +} + +// ContextualizedCaveatFrom convenience method that handles creation of a contextualized caveat +// given the possibility of arguments with zero-values. +func ContextualizedCaveatFrom(name string, context map[string]any) (*core.ContextualizedCaveat, error) { + var caveat *core.ContextualizedCaveat + if name != "" { + strct, err := structpb.NewStruct(context) + if err != nil { + return nil, fmt.Errorf("malformed caveat context: %w", err) + } + caveat = &core.ContextualizedCaveat{ + CaveatName: name, + Context: strct, + } + } + return caveat, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go new file mode 100644 index 0000000..1eb64d1 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go @@ -0,0 +1,28 @@ +package common + +import "github.com/authzed/spicedb/pkg/datastore/queryshape" + +// IndexDefinition is a definition of an index for a datastore. +type IndexDefinition struct { + // Name is the unique name for the index. + Name string + + // ColumnsSQL is the SQL fragment of the columns over which this index will apply. + ColumnsSQL string + + // Shapes are those query shapes for which this index should be used. + Shapes []queryshape.Shape + + // IsDeprecated is true if this index is deprecated and should not be used. + IsDeprecated bool +} + +// matchesShape returns true if the index matches the given shape. +func (id IndexDefinition) matchesShape(shape queryshape.Shape) bool { + for _, s := range id.Shapes { + if s == shape { + return true + } + } + return false +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go new file mode 100644 index 0000000..6e84549 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go @@ -0,0 +1,15 @@ +package common + +import ( + "context" + + log "github.com/authzed/spicedb/internal/logging" +) + +// LogOnError executes the function and logs the error. +// Useful to avoid silently ignoring errors in defer statements +func LogOnError(ctx context.Context, f func() error) { + if err := f(); err != nil { + log.Ctx(ctx).Err(err).Msg("datastore error") + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go new file mode 100644 index 0000000..304f62c --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go @@ -0,0 +1,42 @@ +package common + +import ( + "fmt" + "slices" + "strings" + + "github.com/authzed/spicedb/pkg/datastore" +) + +type MigrationValidator struct { + additionalAllowedMigrations []string + headMigration string +} + +func NewMigrationValidator(headMigration string, additionalAllowedMigrations []string) *MigrationValidator { + return &MigrationValidator{ + additionalAllowedMigrations: additionalAllowedMigrations, + headMigration: headMigration, + } +} + +// MigrationReadyState returns the readiness of the datastore for the given version. +func (mv *MigrationValidator) MigrationReadyState(version string) datastore.ReadyState { + if version == mv.headMigration { + return datastore.ReadyState{IsReady: true} + } + if slices.Contains(mv.additionalAllowedMigrations, version) { + return datastore.ReadyState{IsReady: true} + } + var msgBuilder strings.Builder + msgBuilder.WriteString(fmt.Sprintf("datastore is not migrated: currently at revision %q, but requires %q", version, mv.headMigration)) + + if len(mv.additionalAllowedMigrations) > 0 { + msgBuilder.WriteString(fmt.Sprintf(" (additional allowed migrations: %v)", mv.additionalAllowedMigrations)) + } + msgBuilder.WriteString(". Please run \"spicedb datastore migrate\".") + return datastore.ReadyState{ + Message: msgBuilder.String(), + IsReady: false, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go new file mode 100644 index 0000000..dee0ad5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go @@ -0,0 +1,214 @@ +package common + +import ( + "context" + "database/sql" + "fmt" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/authzed/spicedb/pkg/datastore" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +const errUnableToQueryRels = "unable to query relationships: %w" + +// Querier is an interface for querying the database. +type Querier[R Rows] interface { + QueryFunc(ctx context.Context, f func(context.Context, R) error, sql string, args ...any) error +} + +// Rows is a common interface for database rows reading. +type Rows interface { + Scan(dest ...any) error + Next() bool + Err() error +} + +type closeRowsWithError interface { + Rows + Close() error +} + +type closeRows interface { + Rows + Close() +} + +func runExplainIfNecessary[R Rows](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R], explainable datastore.Explainable) error { + if builder.SQLExplainCallbackForTest == nil { + return nil + } + + // Determine the expected index names via the schema. + expectedIndexes := builder.Schema.expectedIndexesForShape(builder.queryShape) + + // Run any pre-explain statements. + for _, statement := range explainable.PreExplainStatements() { + if err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { + rows.Next() + return nil + }, statement); err != nil { + return fmt.Errorf(errUnableToQueryRels, err) + } + } + + // Run the query with EXPLAIN ANALYZE. + sqlString, args, err := builder.SelectSQL() + if err != nil { + return fmt.Errorf(errUnableToQueryRels, err) + } + + explainSQL, explainArgs, err := explainable.BuildExplainQuery(sqlString, args) + if err != nil { + return fmt.Errorf(errUnableToQueryRels, err) + } + + err = tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { + explainString := "" + for rows.Next() { + var explain string + if err := rows.Scan(&explain); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err)) + } + explainString += explain + "\n" + } + if explainString == "" { + return fmt.Errorf("received empty explain") + } + + return builder.SQLExplainCallbackForTest(ctx, sqlString, args, builder.queryShape, explainString, expectedIndexes) + }, explainSQL, explainArgs...) + if err != nil { + return fmt.Errorf(errUnableToQueryRels, err) + } + + return nil +} + +// QueryRelationships queries relationships for the given query and transaction. +func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R], explainable datastore.Explainable) (datastore.RelationshipIterator, error) { + span := trace.SpanFromContext(ctx) + sqlString, args, err := builder.SelectSQL() + if err != nil { + return nil, fmt.Errorf(errUnableToQueryRels, err) + } + + if err := runExplainIfNecessary(ctx, builder, tx, explainable); err != nil { + return nil, err + } + + var resourceObjectType string + var resourceObjectID string + var resourceRelation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName sql.NullString + var caveatCtx C + var expiration *time.Time + + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + span.AddEvent("Selecting columns") + colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, ×tamp) + if err != nil { + return nil, fmt.Errorf(errUnableToQueryRels, err) + } + + span.AddEvent("Returning iterator", trace.WithAttributes(attribute.Int("column-count", len(colsToSelect)))) + return func(yield func(tuple.Relationship, error) bool) { + span.AddEvent("Issuing query to database") + err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { + span.AddEvent("Query issued to database") + + var r Rows = rows + if crwe, ok := r.(closeRowsWithError); ok { + defer LogOnError(ctx, crwe.Close) + } else if cr, ok := r.(closeRows); ok { + defer cr.Close() + } + + relCount := 0 + for rows.Next() { + if relCount == 0 { + span.AddEvent("First row returned") + } + + if err := rows.Scan(colsToSelect...); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err)) + } + + if relCount == 0 { + span.AddEvent("First row scanned") + } + + var caveat *corev1.ContextualizedCaveat + if !builder.SkipCaveats || builder.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + if caveatName.Valid { + var err error + caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx) + if err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("unable to fetch caveat context: %w", err)) + } + } + } + + var integrity *corev1.RelationshipIntegrity + if integrityKeyID != "" { + integrity = &corev1.RelationshipIntegrity{ + KeyId: integrityKeyID, + Hash: integrityHash, + HashedAt: timestamppb.New(timestamp), + } + } + + if expiration != nil { + // Ensure the expiration is always read in UTC, since some datastores (like CRDB) + // will normalize to local time. + t := expiration.UTC() + expiration = &t + } + + relCount++ + if !yield(tuple.Relationship{ + RelationshipReference: tuple.RelationshipReference{ + Resource: tuple.ObjectAndRelation{ + ObjectType: resourceObjectType, + ObjectID: resourceObjectID, + Relation: resourceRelation, + }, + Subject: tuple.ObjectAndRelation{ + ObjectType: subjectObjectType, + ObjectID: subjectObjectID, + Relation: subjectRelation, + }, + }, + OptionalCaveat: caveat, + OptionalExpiration: expiration, + OptionalIntegrity: integrity, + }, nil) { + return nil + } + } + + span.AddEvent("Relationships loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) + if err := rows.Err(); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("rows err: %w", err)) + } + + return nil + }, sqlString, args...) + if err != nil { + if !yield(tuple.Relationship{}, err) { + return + } + } + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go new file mode 100644 index 0000000..6e44d0b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go @@ -0,0 +1,188 @@ +package common + +import ( + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +const ( + relationshipStandardColumnCount = 6 // ColNamespace, ColObjectID, ColRelation, ColUsersetNamespace, ColUsersetObjectID, ColUsersetRelation + relationshipCaveatColumnCount = 2 // ColCaveatName, ColCaveatContext + relationshipExpirationColumnCount = 1 // ColExpiration + relationshipIntegrityColumnCount = 3 // ColIntegrityKeyID, ColIntegrityHash, ColIntegrityTimestamp +) + +// SchemaInformation holds the schema information from the SQL datastore implementation. +// +//go:generate go run github.com/ecordell/optgen -output zz_generated.schema_options.go . SchemaInformation +type SchemaInformation struct { + RelationshipTableName string `debugmap:"visible"` + + ColNamespace string `debugmap:"visible"` + ColObjectID string `debugmap:"visible"` + ColRelation string `debugmap:"visible"` + ColUsersetNamespace string `debugmap:"visible"` + ColUsersetObjectID string `debugmap:"visible"` + ColUsersetRelation string `debugmap:"visible"` + + ColCaveatName string `debugmap:"visible"` + ColCaveatContext string `debugmap:"visible"` + + ColExpiration string `debugmap:"visible"` + + ColIntegrityKeyID string `debugmap:"visible"` + ColIntegrityHash string `debugmap:"visible"` + ColIntegrityTimestamp string `debugmap:"visible"` + + // Indexes are the indexes to use for this schema. + Indexes []IndexDefinition `debugmap:"visible"` + + // PaginationFilterType is the type of pagination filter to use for this schema. + PaginationFilterType PaginationFilterType `debugmap:"visible"` + + // PlaceholderFormat is the format of placeholders to use for this schema. + PlaceholderFormat sq.PlaceholderFormat `debugmap:"visible"` + + // NowFunction is the function to use to get the current time in the datastore. + NowFunction string `debugmap:"visible"` + + // ColumnOptimization is the optimization to use for columns in the schema, if any. + ColumnOptimization ColumnOptimizationOption `debugmap:"visible"` + + // IntegrityEnabled is a flag to indicate if the schema has integrity columns. + IntegrityEnabled bool `debugmap:"visible"` + + // ExpirationDisabled is a flag to indicate whether expiration support is disabled. + ExpirationDisabled bool `debugmap:"visible"` + + // SortByResourceColumnOrder is the order of the resource columns in the schema to use + // when sorting by resource. If unspecified, the default will be used. + SortByResourceColumnOrder []string `debugmap:"visible"` + + // SortBySubjectColumnOrder is the order of the subject columns in the schema to use + // when sorting by subject. If unspecified, the default will be used. + SortBySubjectColumnOrder []string `debugmap:"visible"` +} + +// expectedIndexesForShape returns the expected index names for a given query shape. +func (si SchemaInformation) expectedIndexesForShape(shape queryshape.Shape) options.SQLIndexInformation { + expectedIndexes := options.SQLIndexInformation{} + for _, index := range si.Indexes { + if index.matchesShape(shape) { + expectedIndexes.ExpectedIndexNames = append(expectedIndexes.ExpectedIndexNames, index.Name) + } + } + return expectedIndexes +} + +func (si SchemaInformation) debugValidate() { + spiceerrors.DebugAssert(func() bool { + si.mustValidate() + return true + }, "SchemaInformation failed to validate") +} + +func (si SchemaInformation) sortByResourceColumnOrderColumns() []string { + if len(si.SortByResourceColumnOrder) > 0 { + return si.SortByResourceColumnOrder + } + + return []string{ + si.ColNamespace, + si.ColObjectID, + si.ColRelation, + si.ColUsersetNamespace, + si.ColUsersetObjectID, + si.ColUsersetRelation, + } +} + +func (si SchemaInformation) sortBySubjectColumnOrderColumns() []string { + if len(si.SortBySubjectColumnOrder) > 0 { + return si.SortBySubjectColumnOrder + } + + return []string{ + si.ColUsersetNamespace, + si.ColUsersetObjectID, + si.ColUsersetRelation, + si.ColNamespace, + si.ColObjectID, + si.ColRelation, + } +} + +func (si SchemaInformation) mustValidate() { + if si.RelationshipTableName == "" { + panic("RelationshipTableName is required") + } + + if si.ColNamespace == "" { + panic("ColNamespace is required") + } + + if si.ColObjectID == "" { + panic("ColObjectID is required") + } + + if si.ColRelation == "" { + panic("ColRelation is required") + } + + if si.ColUsersetNamespace == "" { + panic("ColUsersetNamespace is required") + } + + if si.ColUsersetObjectID == "" { + panic("ColUsersetObjectID is required") + } + + if si.ColUsersetRelation == "" { + panic("ColUsersetRelation is required") + } + + if si.ColCaveatName == "" { + panic("ColCaveatName is required") + } + + if si.ColCaveatContext == "" { + panic("ColCaveatContext is required") + } + + if si.ColExpiration == "" { + panic("ColExpiration is required") + } + + if si.IntegrityEnabled { + if si.ColIntegrityKeyID == "" { + panic("ColIntegrityKeyID is required") + } + + if si.ColIntegrityHash == "" { + panic("ColIntegrityHash is required") + } + + if si.ColIntegrityTimestamp == "" { + panic("ColIntegrityTimestamp is required") + } + } + + if si.NowFunction == "" { + panic("NowFunction is required") + } + + if si.ColumnOptimization == ColumnOptimizationOptionUnknown { + panic("ColumnOptimization is required") + } + + if si.PaginationFilterType == PaginationFilterTypeUnknown { + panic("PaginationFilterType is required") + } + + if si.PlaceholderFormat == nil { + panic("PlaceholderFormat is required") + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go new file mode 100644 index 0000000..4972700 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go @@ -0,0 +1,17 @@ +package common + +import ( + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/tuple" +) + +// NewSliceRelationshipIterator creates a datastore.RelationshipIterator instance from a materialized slice of tuples. +func NewSliceRelationshipIterator(rels []tuple.Relationship) datastore.RelationshipIterator { + return func(yield func(tuple.Relationship, error) bool) { + for _, rel := range rels { + if !yield(rel, nil) { + break + } + } + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go new file mode 100644 index 0000000..ba9c4f6 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go @@ -0,0 +1,961 @@ +package common + +import ( + "context" + "fmt" + "maps" + "math" + "strings" + "time" + + sq "github.com/Masterminds/squirrel" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/jzelinskie/stringz" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +var ( + // CaveatNameKey is a tracing attribute representing a caveat name + CaveatNameKey = attribute.Key("authzed.com/spicedb/sql/caveatName") + + // ObjNamespaceNameKey is a tracing attribute representing the resource + // object type. + ObjNamespaceNameKey = attribute.Key("authzed.com/spicedb/sql/objNamespaceName") + + // ObjRelationNameKey is a tracing attribute representing the resource + // relation. + ObjRelationNameKey = attribute.Key("authzed.com/spicedb/sql/objRelationName") + + // ObjIDKey is a tracing attribute representing the resource object ID. + ObjIDKey = attribute.Key("authzed.com/spicedb/sql/objId") + + // SubNamespaceNameKey is a tracing attribute representing the subject object + // type. + SubNamespaceNameKey = attribute.Key("authzed.com/spicedb/sql/subNamespaceName") + + // SubRelationNameKey is a tracing attribute representing the subject + // relation. + SubRelationNameKey = attribute.Key("authzed.com/spicedb/sql/subRelationName") + + // SubObjectIDKey is a tracing attribute representing the the subject object + // ID. + SubObjectIDKey = attribute.Key("authzed.com/spicedb/sql/subObjectId") + + tracer = otel.Tracer("spicedb/internal/datastore/common") +) + +// PaginationFilterType is an enumerator for pagination filter types. +type PaginationFilterType uint8 + +const ( + PaginationFilterTypeUnknown PaginationFilterType = iota + + // TupleComparison uses a comparison with a compound key, + // e.g. (namespace, object_id, relation) > ('ns', '123', 'viewer') + // which is not compatible with all datastores. + TupleComparison = 1 + + // ExpandedLogicComparison comparison uses a nested tree of ANDs and ORs to properly + // filter out already received relationships. Useful for databases that do not support + // tuple comparison, or do not execute it efficiently + ExpandedLogicComparison = 2 +) + +// ColumnOptimizationOption is an enumerator for column optimization options. +type ColumnOptimizationOption int + +const ( + ColumnOptimizationOptionUnknown ColumnOptimizationOption = iota + + // ColumnOptimizationOptionNone is the default option, which does not optimize the static columns. + ColumnOptimizationOptionNone + + // ColumnOptimizationOptionStaticValues is an option that optimizes columns for static values. + ColumnOptimizationOptionStaticValues +) + +type columnTracker struct { + SingleValue *string +} + +type columnTrackerMap map[string]columnTracker + +func (ctm columnTrackerMap) hasStaticValue(columnName string) bool { + if r, ok := ctm[columnName]; ok && r.SingleValue != nil { + return true + } + return false +} + +// SchemaQueryFilterer wraps a SchemaInformation and SelectBuilder to give an opinionated +// way to build query objects. +type SchemaQueryFilterer struct { + schema SchemaInformation + queryBuilder sq.SelectBuilder + filteringColumnTracker columnTrackerMap + filterMaximumIDCount uint16 + isCustomQuery bool + extraFields []string + fromSuffix string + fromTable string + indexingHint IndexingHint +} + +// IndexingHint is an interface that can be implemented to provide a hint for the SQL query. +type IndexingHint interface { + // SQLPrefix returns the SQL prefix to be used for the indexing hint, if any. + SQLPrefix() (string, error) + + // FromTable returns the table name to be used for the indexing hint, if any. + FromTable(existingTableName string) (string, error) + + // FromSQLSuffix returns the suffix to be used for the indexing hint, if any. + FromSQLSuffix() (string, error) +} + +// NewSchemaQueryFiltererForRelationshipsSelect creates a new SchemaQueryFilterer object for selecting +// relationships. This method will automatically filter the columns retrieved from the database, only +// selecting the columns that are not already specified with a single static value in the query. +func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filterMaximumIDCount uint16, extraFields ...string) SchemaQueryFilterer { + schema.debugValidate() + + if filterMaximumIDCount == 0 { + filterMaximumIDCount = 100 + log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") + } + + queryBuilder := sq.StatementBuilder.PlaceholderFormat(schema.PlaceholderFormat).Select() + return SchemaQueryFilterer{ + schema: schema, + queryBuilder: queryBuilder, + filteringColumnTracker: map[string]columnTracker{}, + filterMaximumIDCount: filterMaximumIDCount, + isCustomQuery: false, + extraFields: extraFields, + fromTable: "", + } +} + +// NewSchemaQueryFiltererWithStartingQuery creates a new SchemaQueryFilterer object for selecting +// relationships, with a custom starting query. Unlike NewSchemaQueryFiltererForRelationshipsSelect, +// this method will not auto-filter the columns retrieved from the database. +func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer { + schema.debugValidate() + + if filterMaximumIDCount == 0 { + filterMaximumIDCount = 100 + log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") + } + + return SchemaQueryFilterer{ + schema: schema, + queryBuilder: startingQuery, + filteringColumnTracker: map[string]columnTracker{}, + filterMaximumIDCount: filterMaximumIDCount, + isCustomQuery: true, + extraFields: nil, + fromTable: "", + } +} + +// WithAdditionalFilter returns the SchemaQueryFilterer with an additional filter applied to the query. +func (sqf SchemaQueryFilterer) WithAdditionalFilter(filter func(original sq.SelectBuilder) sq.SelectBuilder) SchemaQueryFilterer { + sqf.queryBuilder = filter(sqf.queryBuilder) + return sqf +} + +// WithFromTable returns the SchemaQueryFilterer with a custom FROM table. +func (sqf SchemaQueryFilterer) WithFromTable(fromTable string) SchemaQueryFilterer { + sqf.fromTable = fromTable + return sqf +} + +// WithFromSuffix returns the SchemaQueryFilterer with a suffix added to the FROM clause. +func (sqf SchemaQueryFilterer) WithFromSuffix(fromSuffix string) SchemaQueryFilterer { + sqf.fromSuffix = fromSuffix + return sqf +} + +// WithIndexingHint returns the SchemaQueryFilterer with an indexing hint applied to the query. +func (sqf SchemaQueryFilterer) WithIndexingHint(indexingHint IndexingHint) SchemaQueryFilterer { + sqf.indexingHint = indexingHint + return sqf +} + +func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder { + spiceerrors.DebugAssert(func() bool { + return sqf.isCustomQuery + }, "UnderlyingQueryBuilder should only be called on custom queries") + return sqf.queryBuilderWithMaybeExpirationFilter(false) +} + +// queryBuilderWithMaybeExpirationFilter returns the query builder with the expiration filter applied, when necessary. +// Note that this adds the clause to the existing builder. +func (sqf SchemaQueryFilterer) queryBuilderWithMaybeExpirationFilter(skipExpiration bool) sq.SelectBuilder { + if sqf.schema.ExpirationDisabled || skipExpiration { + return sqf.queryBuilder + } + + // Filter out any expired relationships. + return sqf.queryBuilder.Where(sq.Or{ + sq.Eq{sqf.schema.ColExpiration: nil}, + sq.Expr(sqf.schema.ColExpiration + " > " + sqf.schema.NowFunction + "()"), + }) +} + +func (sqf SchemaQueryFilterer) TupleOrder(order options.SortOrder) SchemaQueryFilterer { + switch order { + case options.ByResource: + sqf.queryBuilder = sqf.queryBuilder.OrderBy(sqf.schema.sortByResourceColumnOrderColumns()...) + + case options.BySubject: + sqf.queryBuilder = sqf.queryBuilder.OrderBy(sqf.schema.sortBySubjectColumnOrderColumns()...) + } + + return sqf +} + +type nameAndValue struct { + name string + value string +} + +func columnsAndValuesForSort( + order options.SortOrder, + schema SchemaInformation, + cursor options.Cursor, +) ([]nameAndValue, error) { + var columnNames []string + + switch order { + case options.ByResource: + columnNames = schema.sortByResourceColumnOrderColumns() + + case options.BySubject: + columnNames = schema.sortBySubjectColumnOrderColumns() + + default: + return nil, spiceerrors.MustBugf("invalid sort order %q", order) + } + + nameAndValues := make([]nameAndValue, 0, len(columnNames)) + for _, columnName := range columnNames { + switch columnName { + case schema.ColNamespace: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Resource.ObjectType, + }) + + case schema.ColObjectID: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Resource.ObjectID, + }) + + case schema.ColRelation: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Resource.Relation, + }) + + case schema.ColUsersetNamespace: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Subject.ObjectType, + }) + + case schema.ColUsersetObjectID: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Subject.ObjectID, + }) + + case schema.ColUsersetRelation: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Subject.Relation, + }) + + default: + return nil, spiceerrors.MustBugf("invalid column name %q", columnName) + } + } + + return nameAndValues, nil +} + +func (sqf SchemaQueryFilterer) MustAfter(cursor options.Cursor, order options.SortOrder) SchemaQueryFilterer { + updated, err := sqf.After(cursor, order) + if err != nil { + panic(err) + } + return updated +} + +func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOrder) (SchemaQueryFilterer, error) { + spiceerrors.DebugAssertNotNil(cursor, "cursor cannot be nil") + + // NOTE: The ordering of these columns can affect query performance, be aware when changing. + columnsAndValues, err := columnsAndValuesForSort(order, sqf.schema, cursor) + if err != nil { + return sqf, err + } + + switch sqf.schema.PaginationFilterType { + case TupleComparison: + // For performance reasons, remove any column names that have static values in the query. + columnNames := make([]string, 0, len(columnsAndValues)) + valueSlots := make([]any, 0, len(columnsAndValues)) + comparisonSlotCount := 0 + + for _, cav := range columnsAndValues { + if !sqf.filteringColumnTracker.hasStaticValue(cav.name) { + columnNames = append(columnNames, cav.name) + valueSlots = append(valueSlots, cav.value) + comparisonSlotCount++ + } + } + + if comparisonSlotCount > 0 { + comparisonTuple := "(" + strings.Join(columnNames, ",") + ") > (" + strings.Repeat(",?", comparisonSlotCount)[1:] + ")" + sqf.queryBuilder = sqf.queryBuilder.Where( + comparisonTuple, + valueSlots..., + ) + } + + case ExpandedLogicComparison: + // For performance reasons, remove any column names that have static values in the query. + orClause := sq.Or{} + + for index, cav := range columnsAndValues { + if !sqf.filteringColumnTracker.hasStaticValue(cav.name) { + andClause := sq.And{} + for _, previous := range columnsAndValues[0:index] { + if !sqf.filteringColumnTracker.hasStaticValue(previous.name) { + andClause = append(andClause, sq.Eq{previous.name: previous.value}) + } + } + + andClause = append(andClause, sq.Gt{cav.name: cav.value}) + orClause = append(orClause, andClause) + } + } + + if len(orClause) > 0 { + sqf.queryBuilder = sqf.queryBuilder.Where(orClause) + } + } + + return sqf, nil +} + +// FilterToResourceType returns a new SchemaQueryFilterer that is limited to resources of the +// specified type. +func (sqf SchemaQueryFilterer) FilterToResourceType(resourceType string) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColNamespace: resourceType}) + sqf.recordColumnValue(sqf.schema.ColNamespace, resourceType) + return sqf +} + +func (sqf SchemaQueryFilterer) recordColumnValue(colName string, colValue string) { + existing, ok := sqf.filteringColumnTracker[colName] + if ok { + if existing.SingleValue != nil && *existing.SingleValue != colValue { + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil} + } + } else { + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: &colValue} + } +} + +func (sqf SchemaQueryFilterer) recordVaryingColumnValue(colName string) { + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil} +} + +// FilterToResourceID returns a new SchemaQueryFilterer that is limited to resources with the +// specified ID. +func (sqf SchemaQueryFilterer) FilterToResourceID(objectID string) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColObjectID: objectID}) + sqf.recordColumnValue(sqf.schema.ColObjectID, objectID) + return sqf +} + +func (sqf SchemaQueryFilterer) MustFilterToResourceIDs(resourceIds []string) SchemaQueryFilterer { + updated, err := sqf.FilterToResourceIDs(resourceIds) + if err != nil { + panic(err) + } + return updated +} + +// FilterWithResourceIDPrefix returns new SchemaQueryFilterer that is limited to resources whose ID +// starts with the specified prefix. +func (sqf SchemaQueryFilterer) FilterWithResourceIDPrefix(prefix string) (SchemaQueryFilterer, error) { + if strings.Contains(prefix, "%") { + return sqf, spiceerrors.MustBugf("prefix cannot contain the percent sign") + } + if prefix == "" { + return sqf, spiceerrors.MustBugf("prefix cannot be empty") + } + + prefix = strings.ReplaceAll(prefix, `\`, `\\`) + prefix = strings.ReplaceAll(prefix, "_", `\_`) + + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Like{sqf.schema.ColObjectID: prefix + "%"}) + + // NOTE: we do *not* record the use of the resource ID column here, because it is not used + // statically and thus is necessary for sorting operations. + return sqf, nil +} + +func (sqf SchemaQueryFilterer) MustFilterWithResourceIDPrefix(prefix string) SchemaQueryFilterer { + updated, err := sqf.FilterWithResourceIDPrefix(prefix) + if err != nil { + panic(err) + } + return updated +} + +// FilterToResourceIDs returns a new SchemaQueryFilterer that is limited to resources with any of the +// specified IDs. +func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (SchemaQueryFilterer, error) { + spiceerrors.DebugAssert(func() bool { + return len(resourceIds) <= int(sqf.filterMaximumIDCount) + }, "cannot have more than %d resource IDs in a single filter", sqf.filterMaximumIDCount) + + var builder strings.Builder + builder.WriteString(sqf.schema.ColObjectID) + builder.WriteString(" IN (") + args := make([]any, 0, len(resourceIds)) + + for _, resourceID := range resourceIds { + if len(resourceID) == 0 { + return sqf, spiceerrors.MustBugf("got empty resource ID") + } + + args = append(args, resourceID) + sqf.recordColumnValue(sqf.schema.ColObjectID, resourceID) + } + + builder.WriteString("?") + if len(resourceIds) > 1 { + builder.WriteString(strings.Repeat(",?", len(resourceIds)-1)) + } + builder.WriteString(")") + + sqf.queryBuilder = sqf.queryBuilder.Where(builder.String(), args...) + return sqf, nil +} + +// FilterToRelation returns a new SchemaQueryFilterer that is limited to resources with the +// specified relation. +func (sqf SchemaQueryFilterer) FilterToRelation(relation string) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColRelation: relation}) + sqf.recordColumnValue(sqf.schema.ColRelation, relation) + return sqf +} + +// MustFilterWithRelationshipsFilter returns a new SchemaQueryFilterer that is limited to resources with +// resources that match the specified filter. +func (sqf SchemaQueryFilterer) MustFilterWithRelationshipsFilter(filter datastore.RelationshipsFilter) SchemaQueryFilterer { + updated, err := sqf.FilterWithRelationshipsFilter(filter) + if err != nil { + panic(err) + } + return updated +} + +func (sqf SchemaQueryFilterer) FilterWithRelationshipsFilter(filter datastore.RelationshipsFilter) (SchemaQueryFilterer, error) { + csqf := sqf + + if filter.OptionalResourceType != "" { + csqf = csqf.FilterToResourceType(filter.OptionalResourceType) + } + + if filter.OptionalResourceRelation != "" { + csqf = csqf.FilterToRelation(filter.OptionalResourceRelation) + } + + if len(filter.OptionalResourceIds) > 0 && filter.OptionalResourceIDPrefix != "" { + return csqf, spiceerrors.MustBugf("cannot filter by both resource IDs and ID prefix") + } + + if len(filter.OptionalResourceIds) > 0 { + usqf, err := csqf.FilterToResourceIDs(filter.OptionalResourceIds) + if err != nil { + return csqf, err + } + csqf = usqf + } + + if len(filter.OptionalResourceIDPrefix) > 0 { + usqf, err := csqf.FilterWithResourceIDPrefix(filter.OptionalResourceIDPrefix) + if err != nil { + return csqf, err + } + csqf = usqf + } + + if len(filter.OptionalSubjectsSelectors) > 0 { + usqf, err := csqf.FilterWithSubjectsSelectors(filter.OptionalSubjectsSelectors...) + if err != nil { + return csqf, err + } + csqf = usqf + } + + switch filter.OptionalCaveatNameFilter.Option { + case datastore.CaveatFilterOptionHasMatchingCaveat: + spiceerrors.DebugAssert(func() bool { + return filter.OptionalCaveatNameFilter.CaveatName != "" + }, "caveat name must be set when using HasMatchingCaveat") + csqf = csqf.FilterWithCaveatName(filter.OptionalCaveatNameFilter.CaveatName) + + case datastore.CaveatFilterOptionNoCaveat: + csqf = csqf.FilterWithNoCaveat() + + case datastore.CaveatFilterOptionNone: + // No action needed, as this is the default behavior. + + default: + return csqf, spiceerrors.MustBugf("unknown caveat filter option: %v", filter.OptionalCaveatNameFilter.Option) + } + + if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionHasExpiration { + csqf.queryBuilder = csqf.queryBuilder.Where(sq.NotEq{csqf.schema.ColExpiration: nil}) + spiceerrors.DebugAssert(func() bool { return !sqf.schema.ExpirationDisabled }, "expiration filter requested but schema does not support expiration") + } else if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionNoExpiration { + csqf.queryBuilder = csqf.queryBuilder.Where(sq.Eq{csqf.schema.ColExpiration: nil}) + } + + return csqf, nil +} + +// MustFilterWithSubjectsSelectors returns a new SchemaQueryFilterer that is limited to resources with +// subjects that match the specified selector(s). +func (sqf SchemaQueryFilterer) MustFilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) SchemaQueryFilterer { + usqf, err := sqf.FilterWithSubjectsSelectors(selectors...) + if err != nil { + panic(err) + } + return usqf +} + +// FilterWithSubjectsSelectors returns a new SchemaQueryFilterer that is limited to resources with +// subjects that match the specified selector(s). +func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) (SchemaQueryFilterer, error) { + selectorsOrClause := sq.Or{} + + // If there is more than a single filter, record all the subjects as varying, as the subjects returned + // can differ for each branch. + // TODO(jschorr): Optimize this further where applicable. + if len(selectors) > 1 { + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetNamespace) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetObjectID) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation) + } + + for _, selector := range selectors { + selectorClause := sq.And{} + + if len(selector.OptionalSubjectType) > 0 { + selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetNamespace: selector.OptionalSubjectType}) + sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, selector.OptionalSubjectType) + } + + if len(selector.OptionalSubjectIds) > 0 { + spiceerrors.DebugAssert(func() bool { + return len(selector.OptionalSubjectIds) <= int(sqf.filterMaximumIDCount) + }, "cannot have more than %d subject IDs in a single filter", sqf.filterMaximumIDCount) + + var builder strings.Builder + builder.WriteString(sqf.schema.ColUsersetObjectID) + builder.WriteString(" IN (") + args := make([]any, 0, len(selector.OptionalSubjectIds)) + + for _, subjectID := range selector.OptionalSubjectIds { + if len(subjectID) == 0 { + return sqf, spiceerrors.MustBugf("got empty subject ID") + } + + args = append(args, subjectID) + sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, subjectID) + } + + builder.WriteString("?") + if len(selector.OptionalSubjectIds) > 1 { + builder.WriteString(strings.Repeat(",?", len(selector.OptionalSubjectIds)-1)) + } + + builder.WriteString(")") + selectorClause = append(selectorClause, sq.Expr(builder.String(), args...)) + } + + if !selector.RelationFilter.IsEmpty() { + if selector.RelationFilter.OnlyNonEllipsisRelations { + selectorClause = append(selectorClause, sq.NotEq{sqf.schema.ColUsersetRelation: datastore.Ellipsis}) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation) + } else { + relations := make([]string, 0, 2) + if selector.RelationFilter.IncludeEllipsisRelation { + relations = append(relations, datastore.Ellipsis) + } + + if selector.RelationFilter.NonEllipsisRelation != "" { + relations = append(relations, selector.RelationFilter.NonEllipsisRelation) + } + + if len(relations) == 1 { + relName := relations[0] + selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetRelation: relName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, relName) + } else { + orClause := sq.Or{} + for _, relationName := range relations { + dsRelationName := stringz.DefaultEmpty(relationName, datastore.Ellipsis) + orClause = append(orClause, sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, dsRelationName) + } + + selectorClause = append(selectorClause, orClause) + } + } + } + + selectorsOrClause = append(selectorsOrClause, selectorClause) + } + + sqf.queryBuilder = sqf.queryBuilder.Where(selectorsOrClause) + return sqf, nil +} + +// FilterToSubjectFilter returns a new SchemaQueryFilterer that is limited to resources with +// subjects that match the specified filter. +func (sqf SchemaQueryFilterer) FilterToSubjectFilter(filter *v1.SubjectFilter) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetNamespace: filter.SubjectType}) + sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, filter.SubjectType) + + if filter.OptionalSubjectId != "" { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetObjectID: filter.OptionalSubjectId}) + sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, filter.OptionalSubjectId) + } + + if filter.OptionalRelation != nil { + dsRelationName := stringz.DefaultEmpty(filter.OptionalRelation.Relation, datastore.Ellipsis) + + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, datastore.Ellipsis) + } + + return sqf +} + +// FilterWithCaveatName returns a new SchemaQueryFilterer that is limited to resources with the +// specified caveat name. +func (sqf SchemaQueryFilterer) FilterWithCaveatName(caveatName string) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColCaveatName: caveatName}) + sqf.recordColumnValue(sqf.schema.ColCaveatName, caveatName) + return sqf +} + +// FilterWithNoCaveat returns a new SchemaQueryFilterer that is limited to resources with no caveat. +func (sqf SchemaQueryFilterer) FilterWithNoCaveat() SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where( + sq.Or{ + sq.Eq{sqf.schema.ColCaveatName: nil}, + sq.Eq{sqf.schema.ColCaveatName: ""}, + }) + sqf.recordVaryingColumnValue(sqf.schema.ColCaveatName) + return sqf +} + +// Limit returns a new SchemaQueryFilterer which is limited to the specified number of results. +func (sqf SchemaQueryFilterer) limit(limit uint64) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Limit(limit) + return sqf +} + +// QueryRelationshipsExecutor is a relationships query runner shared by SQL implementations of the datastore. +type QueryRelationshipsExecutor struct { + Executor ExecuteReadRelsQueryFunc +} + +// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query. +type ExecuteReadRelsQueryFunc func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) + +// ExecuteQuery executes the query. +func (exc QueryRelationshipsExecutor) ExecuteQuery( + ctx context.Context, + query SchemaQueryFilterer, + opts ...options.QueryOptionsOption, +) (datastore.RelationshipIterator, error) { + if query.isCustomQuery { + return nil, spiceerrors.MustBugf("ExecuteQuery should not be called on custom queries") + } + + queryOpts := options.NewQueryOptionsWithOptions(opts...) + + // Add sort order. + query = query.TupleOrder(queryOpts.Sort) + + // Add cursor. + if queryOpts.After != nil { + if queryOpts.Sort == options.Unsorted { + return nil, datastore.ErrCursorsWithoutSorting + } + + q, err := query.After(queryOpts.After, queryOpts.Sort) + if err != nil { + return nil, err + } + query = q + } + + // Add limit. + var limit uint64 + // NOTE: we use a uint here because it lines up with the + // assignments in this function, but we set it to MaxInt64 + // because that's the biggest value that postgres and friends + // treat as valid. + limit = math.MaxInt64 + if queryOpts.Limit != nil { + limit = *queryOpts.Limit + } + + if limit < math.MaxInt64 { + query = query.limit(limit) + } + + // Add FROM clause. + from := query.schema.RelationshipTableName + if query.fromTable != "" { + from = query.fromTable + } + + // Add index hints, if any. + if query.indexingHint != nil { + // Check for a SQL prefix (pg_hint_plan). + sqlPrefix, err := query.indexingHint.SQLPrefix() + if err != nil { + return nil, fmt.Errorf("error getting SQL prefix for indexing hint: %w", err) + } + + if sqlPrefix != "" { + query.queryBuilder = query.queryBuilder.Prefix(sqlPrefix) + } + + // Check for an adjusting FROM table name (CRDB). + fromTableName, err := query.indexingHint.FromTable(from) + if err != nil { + return nil, fmt.Errorf("error getting FROM table name for indexing hint: %w", err) + } + from = fromTableName + + // Check for a SQL suffix (MySQL, Spanner). + fromSuffix, err := query.indexingHint.FromSQLSuffix() + if err != nil { + return nil, fmt.Errorf("error getting SQL suffix for indexing hint: %w", err) + } + + if fromSuffix != "" { + from += " " + fromSuffix + } + } + + if query.fromSuffix != "" { + from += " " + query.fromSuffix + } + + query.queryBuilder = query.queryBuilder.From(from) + + builder := RelationshipsQueryBuilder{ + Schema: query.schema, + SkipCaveats: queryOpts.SkipCaveats, + SkipExpiration: queryOpts.SkipExpiration, + SQLCheckAssertionForTest: queryOpts.SQLCheckAssertionForTest, + SQLExplainCallbackForTest: queryOpts.SQLExplainCallbackForTest, + filteringValues: query.filteringColumnTracker, + queryShape: queryOpts.QueryShape, + baseQueryBuilder: query, + } + + return exc.Executor(ctx, builder) +} + +// RelationshipsQueryBuilder is a builder for producing the SQL and arguments necessary for reading +// relationships. +type RelationshipsQueryBuilder struct { + Schema SchemaInformation + SkipCaveats bool + SkipExpiration bool + + filteringValues columnTrackerMap + baseQueryBuilder SchemaQueryFilterer + SQLCheckAssertionForTest options.SQLCheckAssertionForTest + SQLExplainCallbackForTest options.SQLExplainCallbackForTest + queryShape queryshape.Shape +} + +// withCaveats returns true if caveats should be included in the query. +func (b RelationshipsQueryBuilder) withCaveats() bool { + return !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone +} + +// withExpiration returns true if expiration should be included in the query. +func (b RelationshipsQueryBuilder) withExpiration() bool { + return !b.SkipExpiration && !b.Schema.ExpirationDisabled +} + +// integrityEnabled returns true if integrity columns should be included in the query. +func (b RelationshipsQueryBuilder) integrityEnabled() bool { + return b.Schema.IntegrityEnabled +} + +// columnCount returns the number of columns that will be selected in the query. +func (b RelationshipsQueryBuilder) columnCount() int { + columnCount := relationshipStandardColumnCount + if b.withCaveats() { + columnCount += relationshipCaveatColumnCount + } + if b.withExpiration() { + columnCount += relationshipExpirationColumnCount + } + if b.integrityEnabled() { + columnCount += relationshipIntegrityColumnCount + } + return columnCount +} + +// SelectSQL returns the SQL and arguments necessary for reading relationships. +func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { + // Set the column names to select. + columnNamesToSelect := make([]string, 0, b.columnCount()) + + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColNamespace) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColObjectID) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColRelation) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetNamespace) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetObjectID) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetRelation) + + if b.withCaveats() { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColCaveatName, b.Schema.ColCaveatContext) + } + + if b.withExpiration() { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration) + } + + if b.integrityEnabled() { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColIntegrityKeyID, b.Schema.ColIntegrityHash, b.Schema.ColIntegrityTimestamp) + } + + if len(columnNamesToSelect) == 0 { + columnNamesToSelect = append(columnNamesToSelect, "1") + } + + sqlBuilder := b.baseQueryBuilder.queryBuilderWithMaybeExpirationFilter(b.SkipExpiration) + sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...) + + sql, args, err := sqlBuilder.ToSql() + if err != nil { + return "", nil, err + } + + if b.SQLCheckAssertionForTest != nil { + b.SQLCheckAssertionForTest(sql) + } + + return sql, args, nil +} + +// FilteringValuesForTesting returns the filtering values. For test use only. +func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]columnTracker { + return maps.Clone(b.filteringValues) +} + +func (b RelationshipsQueryBuilder) checkColumn(columns []string, colName string) []string { + if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + return append(columns, colName) + } + + if !b.filteringValues.hasStaticValue(colName) { + return append(columns, colName) + } + + return columns +} + +func (b RelationshipsQueryBuilder) staticValueOrAddColumnForSelect(colsToSelect []any, colName string, field *string) []any { + if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + // If column optimization is disabled, always add the column to the list of columns to select. + colsToSelect = append(colsToSelect, field) + return colsToSelect + } + + // If the value is static, set the field to it and return. + if found, ok := b.filteringValues[colName]; ok && found.SingleValue != nil { + *field = *found.SingleValue + return colsToSelect + } + + // Otherwise, add the column to the list of columns to select, as the value is not static. + colsToSelect = append(colsToSelect, field) + return colsToSelect +} + +// ColumnsToSelect returns the columns to select for a given query. The columns provided are +// the references to the slots in which the values for each relationship will be placed. +func ColumnsToSelect[CN any, CC any, EC any]( + b RelationshipsQueryBuilder, + resourceObjectType *string, + resourceObjectID *string, + resourceRelation *string, + subjectObjectType *string, + subjectObjectID *string, + subjectRelation *string, + caveatName *CN, + caveatCtx *CC, + expiration EC, + + integrityKeyID *string, + integrityHash *[]byte, + timestamp *time.Time, +) ([]any, error) { + colsToSelect := make([]any, 0, b.columnCount()) + + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColNamespace, resourceObjectType) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColObjectID, resourceObjectID) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColRelation, resourceRelation) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetNamespace, subjectObjectType) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetObjectID, subjectObjectID) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetRelation, subjectRelation) + + if b.withCaveats() { + colsToSelect = append(colsToSelect, caveatName, caveatCtx) + } + + if b.withExpiration() { + colsToSelect = append(colsToSelect, expiration) + } + + if b.Schema.IntegrityEnabled { + colsToSelect = append(colsToSelect, integrityKeyID, integrityHash, timestamp) + } + + if len(colsToSelect) == 0 { + var unused int64 + colsToSelect = append(colsToSelect, &unused) + } + + return colsToSelect, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go new file mode 100644 index 0000000..fa23efc --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go @@ -0,0 +1,31 @@ +package common + +import ( + "context" + "errors" + "strings" +) + +// IsCancellationError determines if an error returned by pgx has been caused by context cancellation. +func IsCancellationError(err error) bool { + if errors.Is(err, context.Canceled) || + errors.Is(err, context.DeadlineExceeded) || + err.Error() == "conn closed" { // conns are sometimes closed async upon cancellation + return true + } + return false +} + +// IsResettableError returns whether the given error is a resettable error. +func IsResettableError(err error) bool { + // detect when an error is likely due to a node taken out of service + if strings.Contains(err.Error(), "broken pipe") || + strings.Contains(err.Error(), "unexpected EOF") || + strings.Contains(err.Error(), "conn closed") || + strings.Contains(err.Error(), "connection refused") || + strings.Contains(err.Error(), "connection reset by peer") { + return true + } + + return false +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go new file mode 100644 index 0000000..be665ed --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go @@ -0,0 +1,19 @@ +package common + +import ( + "errors" + "net/url" +) + +// MetricsIDFromURL extracts the metrics ID from a given datastore URL. +func MetricsIDFromURL(dsURL string) (string, error) { + if dsURL == "" { + return "", errors.New("datastore URL is empty") + } + + u, err := url.Parse(dsURL) + if err != nil { + return "", errors.New("could not parse datastore URL") + } + return u.Host + u.Path, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go new file mode 100644 index 0000000..2caa57a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go @@ -0,0 +1,276 @@ +// Code generated by github.com/ecordell/optgen. DO NOT EDIT. +package common + +import ( + squirrel "github.com/Masterminds/squirrel" + defaults "github.com/creasty/defaults" + helpers "github.com/ecordell/optgen/helpers" +) + +type SchemaInformationOption func(s *SchemaInformation) + +// NewSchemaInformationWithOptions creates a new SchemaInformation with the passed in options set +func NewSchemaInformationWithOptions(opts ...SchemaInformationOption) *SchemaInformation { + s := &SchemaInformation{} + for _, o := range opts { + o(s) + } + return s +} + +// NewSchemaInformationWithOptionsAndDefaults creates a new SchemaInformation with the passed in options set starting from the defaults +func NewSchemaInformationWithOptionsAndDefaults(opts ...SchemaInformationOption) *SchemaInformation { + s := &SchemaInformation{} + defaults.MustSet(s) + for _, o := range opts { + o(s) + } + return s +} + +// ToOption returns a new SchemaInformationOption that sets the values from the passed in SchemaInformation +func (s *SchemaInformation) ToOption() SchemaInformationOption { + return func(to *SchemaInformation) { + to.RelationshipTableName = s.RelationshipTableName + to.ColNamespace = s.ColNamespace + to.ColObjectID = s.ColObjectID + to.ColRelation = s.ColRelation + to.ColUsersetNamespace = s.ColUsersetNamespace + to.ColUsersetObjectID = s.ColUsersetObjectID + to.ColUsersetRelation = s.ColUsersetRelation + to.ColCaveatName = s.ColCaveatName + to.ColCaveatContext = s.ColCaveatContext + to.ColExpiration = s.ColExpiration + to.ColIntegrityKeyID = s.ColIntegrityKeyID + to.ColIntegrityHash = s.ColIntegrityHash + to.ColIntegrityTimestamp = s.ColIntegrityTimestamp + to.Indexes = s.Indexes + to.PaginationFilterType = s.PaginationFilterType + to.PlaceholderFormat = s.PlaceholderFormat + to.NowFunction = s.NowFunction + to.ColumnOptimization = s.ColumnOptimization + to.IntegrityEnabled = s.IntegrityEnabled + to.ExpirationDisabled = s.ExpirationDisabled + to.SortByResourceColumnOrder = s.SortByResourceColumnOrder + to.SortBySubjectColumnOrder = s.SortBySubjectColumnOrder + } +} + +// DebugMap returns a map form of SchemaInformation for debugging +func (s SchemaInformation) DebugMap() map[string]any { + debugMap := map[string]any{} + debugMap["RelationshipTableName"] = helpers.DebugValue(s.RelationshipTableName, false) + debugMap["ColNamespace"] = helpers.DebugValue(s.ColNamespace, false) + debugMap["ColObjectID"] = helpers.DebugValue(s.ColObjectID, false) + debugMap["ColRelation"] = helpers.DebugValue(s.ColRelation, false) + debugMap["ColUsersetNamespace"] = helpers.DebugValue(s.ColUsersetNamespace, false) + debugMap["ColUsersetObjectID"] = helpers.DebugValue(s.ColUsersetObjectID, false) + debugMap["ColUsersetRelation"] = helpers.DebugValue(s.ColUsersetRelation, false) + debugMap["ColCaveatName"] = helpers.DebugValue(s.ColCaveatName, false) + debugMap["ColCaveatContext"] = helpers.DebugValue(s.ColCaveatContext, false) + debugMap["ColExpiration"] = helpers.DebugValue(s.ColExpiration, false) + debugMap["ColIntegrityKeyID"] = helpers.DebugValue(s.ColIntegrityKeyID, false) + debugMap["ColIntegrityHash"] = helpers.DebugValue(s.ColIntegrityHash, false) + debugMap["ColIntegrityTimestamp"] = helpers.DebugValue(s.ColIntegrityTimestamp, false) + debugMap["Indexes"] = helpers.DebugValue(s.Indexes, false) + debugMap["PaginationFilterType"] = helpers.DebugValue(s.PaginationFilterType, false) + debugMap["PlaceholderFormat"] = helpers.DebugValue(s.PlaceholderFormat, false) + debugMap["NowFunction"] = helpers.DebugValue(s.NowFunction, false) + debugMap["ColumnOptimization"] = helpers.DebugValue(s.ColumnOptimization, false) + debugMap["IntegrityEnabled"] = helpers.DebugValue(s.IntegrityEnabled, false) + debugMap["ExpirationDisabled"] = helpers.DebugValue(s.ExpirationDisabled, false) + debugMap["SortByResourceColumnOrder"] = helpers.DebugValue(s.SortByResourceColumnOrder, false) + debugMap["SortBySubjectColumnOrder"] = helpers.DebugValue(s.SortBySubjectColumnOrder, false) + return debugMap +} + +// SchemaInformationWithOptions configures an existing SchemaInformation with the passed in options set +func SchemaInformationWithOptions(s *SchemaInformation, opts ...SchemaInformationOption) *SchemaInformation { + for _, o := range opts { + o(s) + } + return s +} + +// WithOptions configures the receiver SchemaInformation with the passed in options set +func (s *SchemaInformation) WithOptions(opts ...SchemaInformationOption) *SchemaInformation { + for _, o := range opts { + o(s) + } + return s +} + +// WithRelationshipTableName returns an option that can set RelationshipTableName on a SchemaInformation +func WithRelationshipTableName(relationshipTableName string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.RelationshipTableName = relationshipTableName + } +} + +// WithColNamespace returns an option that can set ColNamespace on a SchemaInformation +func WithColNamespace(colNamespace string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColNamespace = colNamespace + } +} + +// WithColObjectID returns an option that can set ColObjectID on a SchemaInformation +func WithColObjectID(colObjectID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColObjectID = colObjectID + } +} + +// WithColRelation returns an option that can set ColRelation on a SchemaInformation +func WithColRelation(colRelation string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColRelation = colRelation + } +} + +// WithColUsersetNamespace returns an option that can set ColUsersetNamespace on a SchemaInformation +func WithColUsersetNamespace(colUsersetNamespace string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetNamespace = colUsersetNamespace + } +} + +// WithColUsersetObjectID returns an option that can set ColUsersetObjectID on a SchemaInformation +func WithColUsersetObjectID(colUsersetObjectID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetObjectID = colUsersetObjectID + } +} + +// WithColUsersetRelation returns an option that can set ColUsersetRelation on a SchemaInformation +func WithColUsersetRelation(colUsersetRelation string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetRelation = colUsersetRelation + } +} + +// WithColCaveatName returns an option that can set ColCaveatName on a SchemaInformation +func WithColCaveatName(colCaveatName string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColCaveatName = colCaveatName + } +} + +// WithColCaveatContext returns an option that can set ColCaveatContext on a SchemaInformation +func WithColCaveatContext(colCaveatContext string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColCaveatContext = colCaveatContext + } +} + +// WithColExpiration returns an option that can set ColExpiration on a SchemaInformation +func WithColExpiration(colExpiration string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColExpiration = colExpiration + } +} + +// WithColIntegrityKeyID returns an option that can set ColIntegrityKeyID on a SchemaInformation +func WithColIntegrityKeyID(colIntegrityKeyID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityKeyID = colIntegrityKeyID + } +} + +// WithColIntegrityHash returns an option that can set ColIntegrityHash on a SchemaInformation +func WithColIntegrityHash(colIntegrityHash string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityHash = colIntegrityHash + } +} + +// WithColIntegrityTimestamp returns an option that can set ColIntegrityTimestamp on a SchemaInformation +func WithColIntegrityTimestamp(colIntegrityTimestamp string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityTimestamp = colIntegrityTimestamp + } +} + +// WithIndexes returns an option that can append Indexess to SchemaInformation.Indexes +func WithIndexes(indexes IndexDefinition) SchemaInformationOption { + return func(s *SchemaInformation) { + s.Indexes = append(s.Indexes, indexes) + } +} + +// SetIndexes returns an option that can set Indexes on a SchemaInformation +func SetIndexes(indexes []IndexDefinition) SchemaInformationOption { + return func(s *SchemaInformation) { + s.Indexes = indexes + } +} + +// WithPaginationFilterType returns an option that can set PaginationFilterType on a SchemaInformation +func WithPaginationFilterType(paginationFilterType PaginationFilterType) SchemaInformationOption { + return func(s *SchemaInformation) { + s.PaginationFilterType = paginationFilterType + } +} + +// WithPlaceholderFormat returns an option that can set PlaceholderFormat on a SchemaInformation +func WithPlaceholderFormat(placeholderFormat squirrel.PlaceholderFormat) SchemaInformationOption { + return func(s *SchemaInformation) { + s.PlaceholderFormat = placeholderFormat + } +} + +// WithNowFunction returns an option that can set NowFunction on a SchemaInformation +func WithNowFunction(nowFunction string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.NowFunction = nowFunction + } +} + +// WithColumnOptimization returns an option that can set ColumnOptimization on a SchemaInformation +func WithColumnOptimization(columnOptimization ColumnOptimizationOption) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColumnOptimization = columnOptimization + } +} + +// WithIntegrityEnabled returns an option that can set IntegrityEnabled on a SchemaInformation +func WithIntegrityEnabled(integrityEnabled bool) SchemaInformationOption { + return func(s *SchemaInformation) { + s.IntegrityEnabled = integrityEnabled + } +} + +// WithExpirationDisabled returns an option that can set ExpirationDisabled on a SchemaInformation +func WithExpirationDisabled(expirationDisabled bool) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ExpirationDisabled = expirationDisabled + } +} + +// WithSortByResourceColumnOrder returns an option that can append SortByResourceColumnOrders to SchemaInformation.SortByResourceColumnOrder +func WithSortByResourceColumnOrder(sortByResourceColumnOrder string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.SortByResourceColumnOrder = append(s.SortByResourceColumnOrder, sortByResourceColumnOrder) + } +} + +// SetSortByResourceColumnOrder returns an option that can set SortByResourceColumnOrder on a SchemaInformation +func SetSortByResourceColumnOrder(sortByResourceColumnOrder []string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.SortByResourceColumnOrder = sortByResourceColumnOrder + } +} + +// WithSortBySubjectColumnOrder returns an option that can append SortBySubjectColumnOrders to SchemaInformation.SortBySubjectColumnOrder +func WithSortBySubjectColumnOrder(sortBySubjectColumnOrder string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.SortBySubjectColumnOrder = append(s.SortBySubjectColumnOrder, sortBySubjectColumnOrder) + } +} + +// SetSortBySubjectColumnOrder returns an option that can set SortBySubjectColumnOrder on a SchemaInformation +func SetSortBySubjectColumnOrder(sortBySubjectColumnOrder []string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.SortBySubjectColumnOrder = sortBySubjectColumnOrder + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md new file mode 100644 index 0000000..de32e34 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md @@ -0,0 +1,23 @@ +# MemDB Datastore Implementation + +The `memdb` datastore implementation is based on Hashicorp's [go-memdb library](https://github.com/hashicorp/go-memdb). +Its implementation most closely mimics that of `spanner`, or `crdb`, where there is a single immutable datastore that supports querying at any point in time. +The `memdb` datastore is used for validating and rapidly iterating on concepts from consumers of other datastores. +It is 100% compliant with the datastore acceptance test suite and it should be possible to use it in place of any other datastore for development purposes. +Differences between the `memdb` datastore and other implementations that manifest themselves as differences visible to the caller should be reported as bugs. + +**The memdb datastore can NOT be used in a production setting!** + +## Implementation Caveats + +### No Garbage Collection + +This implementation of the datastore has no garbage collection, meaning that memory usage will grow monotonically with mutations. + +### No Durable Storage + +The `memdb` datastore, as its name implies, stores information entirely in memory, and therefore will lose all data when the host process terminates. + +### Cannot be used for multi-node dispatch + +If you attempt to run SpiceDB with multi-node dispatch enabled using the memory datastore, each independent node will get a separate copy of the datastore, and you will end up very confused. diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go new file mode 100644 index 0000000..2b4baca --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go @@ -0,0 +1,156 @@ +package memdb + +import ( + "context" + "fmt" + + "github.com/hashicorp/go-memdb" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +const tableCaveats = "caveats" + +type caveat struct { + name string + definition []byte + revision datastore.Revision +} + +func (c *caveat) Unwrap() (*core.CaveatDefinition, error) { + definition := core.CaveatDefinition{} + err := definition.UnmarshalVT(c.definition) + return &definition, err +} + +func (r *memdbReader) ReadCaveatByName(_ context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, datastore.NoRevision, err + } + return r.readUnwrappedCaveatByName(tx, name) +} + +func (r *memdbReader) readCaveatByName(tx *memdb.Txn, name string) (*caveat, datastore.Revision, error) { + found, err := tx.First(tableCaveats, indexID, name) + if err != nil { + return nil, datastore.NoRevision, err + } + if found == nil { + return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name) + } + cvt := found.(*caveat) + return cvt, cvt.revision, nil +} + +func (r *memdbReader) readUnwrappedCaveatByName(tx *memdb.Txn, name string) (*core.CaveatDefinition, datastore.Revision, error) { + c, rev, err := r.readCaveatByName(tx, name) + if err != nil { + return nil, datastore.NoRevision, err + } + unwrapped, err := c.Unwrap() + if err != nil { + return nil, datastore.NoRevision, err + } + return unwrapped, rev, nil +} + +func (r *memdbReader) ListAllCaveats(_ context.Context) ([]datastore.RevisionedCaveat, error) { + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + var caveats []datastore.RevisionedCaveat + it, err := tx.LowerBound(tableCaveats, indexID) + if err != nil { + return nil, err + } + + for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { + rawCaveat := foundRaw.(*caveat) + definition, err := rawCaveat.Unwrap() + if err != nil { + return nil, err + } + caveats = append(caveats, datastore.RevisionedCaveat{ + Definition: definition, + LastWrittenRevision: rawCaveat.revision, + }) + } + + return caveats, nil +} + +func (r *memdbReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { + allCaveats, err := r.ListAllCaveats(ctx) + if err != nil { + return nil, err + } + + allowedCaveatNames := mapz.NewSet[string]() + allowedCaveatNames.Extend(caveatNames) + + toReturn := make([]datastore.RevisionedCaveat, 0, len(caveatNames)) + for _, caveat := range allCaveats { + if allowedCaveatNames.Has(caveat.Definition.Name) { + toReturn = append(toReturn, caveat) + } + } + return toReturn, nil +} + +func (rwt *memdbReadWriteTx) WriteCaveats(_ context.Context, caveats []*core.CaveatDefinition) error { + rwt.mustLock() + defer rwt.Unlock() + tx, err := rwt.txSource() + if err != nil { + return err + } + return rwt.writeCaveat(tx, caveats) +} + +func (rwt *memdbReadWriteTx) writeCaveat(tx *memdb.Txn, caveats []*core.CaveatDefinition) error { + caveatNames := mapz.NewSet[string]() + for _, coreCaveat := range caveats { + if !caveatNames.Add(coreCaveat.Name) { + return fmt.Errorf("duplicate caveat %s", coreCaveat.Name) + } + marshalled, err := coreCaveat.MarshalVT() + if err != nil { + return err + } + c := caveat{ + name: coreCaveat.Name, + definition: marshalled, + revision: rwt.newRevision, + } + if err := tx.Insert(tableCaveats, &c); err != nil { + return err + } + } + return nil +} + +func (rwt *memdbReadWriteTx) DeleteCaveats(_ context.Context, names []string) error { + rwt.mustLock() + defer rwt.Unlock() + tx, err := rwt.txSource() + if err != nil { + return err + } + for _, name := range names { + if err := tx.Delete(tableCaveats, caveat{name: name}); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go new file mode 100644 index 0000000..0ef4b8b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go @@ -0,0 +1,37 @@ +package memdb + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// SerializationMaxRetriesReachedError occurs when a write request has reached its maximum number +// of retries due to serialization errors. +type SerializationMaxRetriesReachedError struct { + error +} + +// NewSerializationMaxRetriesReachedErr constructs a new max retries reached error. +func NewSerializationMaxRetriesReachedErr(baseErr error) error { + return SerializationMaxRetriesReachedError{ + error: baseErr, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err SerializationMaxRetriesReachedError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.DeadlineExceeded, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNSPECIFIED, + map[string]string{ + "details": "too many updates were made to the in-memory datastore at once; this datastore has limited write throughput capability", + }, + ), + ) +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go new file mode 100644 index 0000000..61eba84 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go @@ -0,0 +1,386 @@ +package memdb + +import ( + "context" + "errors" + "fmt" + "math" + "sort" + "sync" + "time" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" + + "github.com/google/uuid" + "github.com/hashicorp/go-memdb" + + "github.com/authzed/spicedb/internal/datastore/revisions" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +const ( + Engine = "memory" + defaultWatchBufferLength = 128 + numAttempts = 10 +) + +var ( + ErrMemDBIsClosed = errors.New("datastore is closed") + ErrSerialization = errors.New("serialization error") +) + +// DisableGC is a convenient constant for setting the garbage collection +// interval high enough that it will never run. +const DisableGC = time.Duration(math.MaxInt64) + +// NewMemdbDatastore creates a new Datastore compliant datastore backed by memdb. +// +// If the watchBufferLength value of 0 is set then a default value of 128 will be used. +func NewMemdbDatastore( + watchBufferLength uint16, + revisionQuantization, + gcWindow time.Duration, +) (datastore.Datastore, error) { + if revisionQuantization > gcWindow { + return nil, errors.New("gc window must be larger than quantization interval") + } + + if revisionQuantization <= 1 { + revisionQuantization = 1 + } + + db, err := memdb.NewMemDB(schema) + if err != nil { + return nil, err + } + + if watchBufferLength == 0 { + watchBufferLength = defaultWatchBufferLength + } + + uniqueID := uuid.NewString() + return &memdbDatastore{ + CommonDecoder: revisions.CommonDecoder{ + Kind: revisions.Timestamp, + }, + db: db, + revisions: []snapshot{ + { + revision: nowRevision(), + db: db, + }, + }, + + negativeGCWindow: gcWindow.Nanoseconds() * -1, + quantizationPeriod: revisionQuantization.Nanoseconds(), + watchBufferLength: watchBufferLength, + watchBufferWriteTimeout: 100 * time.Millisecond, + uniqueID: uniqueID, + }, nil +} + +type memdbDatastore struct { + sync.RWMutex + revisions.CommonDecoder + + // NOTE: call checkNotClosed before using + db *memdb.MemDB // GUARDED_BY(RWMutex) + revisions []snapshot // GUARDED_BY(RWMutex) + activeWriteTxn *memdb.Txn // GUARDED_BY(RWMutex) + + negativeGCWindow int64 + quantizationPeriod int64 + watchBufferLength uint16 + watchBufferWriteTimeout time.Duration + uniqueID string +} + +type snapshot struct { + revision revisions.TimestampRevision + db *memdb.MemDB +} + +func (mdb *memdbDatastore) MetricsID() (string, error) { + return "memdb", nil +} + +func (mdb *memdbDatastore) SnapshotReader(dr datastore.Revision) datastore.Reader { + mdb.RLock() + defer mdb.RUnlock() + + if err := mdb.checkNotClosed(); err != nil { + return &memdbReader{nil, nil, err, time.Now()} + } + + if len(mdb.revisions) == 0 { + return &memdbReader{nil, nil, fmt.Errorf("memdb datastore is not ready"), time.Now()} + } + + if err := mdb.checkRevisionLocalCallerMustLock(dr); err != nil { + return &memdbReader{nil, nil, err, time.Now()} + } + + revIndex := sort.Search(len(mdb.revisions), func(i int) bool { + return mdb.revisions[i].revision.GreaterThan(dr) || mdb.revisions[i].revision.Equal(dr) + }) + + // handle the case when there is no revision snapshot newer than the requested revision + if revIndex == len(mdb.revisions) { + revIndex = len(mdb.revisions) - 1 + } + + rev := mdb.revisions[revIndex] + if rev.db == nil { + return &memdbReader{nil, nil, fmt.Errorf("memdb datastore is already closed"), time.Now()} + } + + roTxn := rev.db.Txn(false) + txSrc := func() (*memdb.Txn, error) { + return roTxn, nil + } + + return &memdbReader{noopTryLocker{}, txSrc, nil, time.Now()} +} + +func (mdb *memdbDatastore) SupportsIntegrity() bool { + return true +} + +func (mdb *memdbDatastore) ReadWriteTx( + ctx context.Context, + f datastore.TxUserFunc, + opts ...options.RWTOptionsOption, +) (datastore.Revision, error) { + config := options.NewRWTOptionsWithOptions(opts...) + txNumAttempts := numAttempts + if config.DisableRetries { + txNumAttempts = 1 + } + + for i := 0; i < txNumAttempts; i++ { + var tx *memdb.Txn + createTxOnce := sync.Once{} + txSrc := func() (*memdb.Txn, error) { + var err error + createTxOnce.Do(func() { + mdb.Lock() + defer mdb.Unlock() + + if mdb.activeWriteTxn != nil { + err = ErrSerialization + return + } + + if err = mdb.checkNotClosed(); err != nil { + return + } + + tx = mdb.db.Txn(true) + tx.TrackChanges() + mdb.activeWriteTxn = tx + }) + + return tx, err + } + + newRevision := mdb.newRevisionID() + rwt := &memdbReadWriteTx{memdbReader{&sync.Mutex{}, txSrc, nil, time.Now()}, newRevision} + if err := f(ctx, rwt); err != nil { + mdb.Lock() + if tx != nil { + tx.Abort() + mdb.activeWriteTxn = nil + } + + // If the error was a serialization error, retry the transaction + if errors.Is(err, ErrSerialization) { + mdb.Unlock() + + // If we don't sleep here, we run out of retries instantaneously + time.Sleep(1 * time.Millisecond) + continue + } + defer mdb.Unlock() + + // We *must* return the inner error unmodified in case it's not an error type + // that supports unwrapping (e.g. gRPC errors) + return datastore.NoRevision, err + } + + mdb.Lock() + defer mdb.Unlock() + + tracked := common.NewChanges(revisions.TimestampIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) + if tx != nil { + if config.Metadata != nil && len(config.Metadata.GetFields()) > 0 { + if err := tracked.SetRevisionMetadata(ctx, newRevision, config.Metadata.AsMap()); err != nil { + return datastore.NoRevision, err + } + } + + for _, change := range tx.Changes() { + switch change.Table { + case tableRelationship: + if change.After != nil { + rt, err := change.After.(*relationship).Relationship() + if err != nil { + return datastore.NoRevision, err + } + + if err := tracked.AddRelationshipChange(ctx, newRevision, rt, tuple.UpdateOperationTouch); err != nil { + return datastore.NoRevision, err + } + } else if change.After == nil && change.Before != nil { + rt, err := change.Before.(*relationship).Relationship() + if err != nil { + return datastore.NoRevision, err + } + + if err := tracked.AddRelationshipChange(ctx, newRevision, rt, tuple.UpdateOperationDelete); err != nil { + return datastore.NoRevision, err + } + } else { + return datastore.NoRevision, spiceerrors.MustBugf("unexpected relationship change") + } + case tableNamespace: + if change.After != nil { + loaded := &corev1.NamespaceDefinition{} + if err := loaded.UnmarshalVT(change.After.(*namespace).configBytes); err != nil { + return datastore.NoRevision, err + } + + err := tracked.AddChangedDefinition(ctx, newRevision, loaded) + if err != nil { + return datastore.NoRevision, err + } + } else if change.After == nil && change.Before != nil { + err := tracked.AddDeletedNamespace(ctx, newRevision, change.Before.(*namespace).name) + if err != nil { + return datastore.NoRevision, err + } + } else { + return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change") + } + case tableCaveats: + if change.After != nil { + loaded := &corev1.CaveatDefinition{} + if err := loaded.UnmarshalVT(change.After.(*caveat).definition); err != nil { + return datastore.NoRevision, err + } + + err := tracked.AddChangedDefinition(ctx, newRevision, loaded) + if err != nil { + return datastore.NoRevision, err + } + } else if change.After == nil && change.Before != nil { + err := tracked.AddDeletedCaveat(ctx, newRevision, change.Before.(*caveat).name) + if err != nil { + return datastore.NoRevision, err + } + } else { + return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change") + } + } + } + + var rc datastore.RevisionChanges + changes, err := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + if err != nil { + return datastore.NoRevision, err + } + + if len(changes) > 1 { + return datastore.NoRevision, spiceerrors.MustBugf("unexpected MemDB transaction with multiple revision changes") + } else if len(changes) == 1 { + rc = changes[0] + } + + change := &changelog{ + revisionNanos: newRevision.TimestampNanoSec(), + changes: rc, + } + if err := tx.Insert(tableChangelog, change); err != nil { + return datastore.NoRevision, fmt.Errorf("error writing changelog: %w", err) + } + + tx.Commit() + } + mdb.activeWriteTxn = nil + + if err := mdb.checkNotClosed(); err != nil { + return datastore.NoRevision, err + } + + // Create a snapshot and add it to the revisions slice + snap := mdb.db.Snapshot() + mdb.revisions = append(mdb.revisions, snapshot{newRevision, snap}) + return newRevision, nil + } + + return datastore.NoRevision, NewSerializationMaxRetriesReachedErr(errors.New("serialization max retries exceeded; please reduce your parallel writes")) +} + +func (mdb *memdbDatastore) ReadyState(_ context.Context) (datastore.ReadyState, error) { + mdb.RLock() + defer mdb.RUnlock() + + return datastore.ReadyState{ + Message: "missing expected initial revision", + IsReady: len(mdb.revisions) > 0, + }, nil +} + +func (mdb *memdbDatastore) OfflineFeatures() (*datastore.Features, error) { + return &datastore.Features{ + Watch: datastore.Feature{ + Status: datastore.FeatureSupported, + }, + IntegrityData: datastore.Feature{ + Status: datastore.FeatureSupported, + }, + ContinuousCheckpointing: datastore.Feature{ + Status: datastore.FeatureUnsupported, + }, + WatchEmitsImmediately: datastore.Feature{ + Status: datastore.FeatureUnsupported, + }, + }, nil +} + +func (mdb *memdbDatastore) Features(_ context.Context) (*datastore.Features, error) { + return mdb.OfflineFeatures() +} + +func (mdb *memdbDatastore) Close() error { + mdb.Lock() + defer mdb.Unlock() + + if db := mdb.db; db != nil { + mdb.revisions = []snapshot{ + { + revision: nowRevision(), + db: db, + }, + } + } else { + mdb.revisions = []snapshot{} + } + + mdb.db = nil + + return nil +} + +// This code assumes that the RWMutex has been acquired. +func (mdb *memdbDatastore) checkNotClosed() error { + if mdb.db == nil { + return ErrMemDBIsClosed + } + return nil +} + +var _ datastore.Datastore = &memdbDatastore{} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go new file mode 100644 index 0000000..fdd224a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go @@ -0,0 +1,597 @@ +package memdb + +import ( + "context" + "fmt" + "slices" + "sort" + "strings" + "time" + + "github.com/hashicorp/go-memdb" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +type txFactory func() (*memdb.Txn, error) + +type memdbReader struct { + TryLocker + txSource txFactory + initErr error + now time.Time +} + +func (r *memdbReader) CountRelationships(ctx context.Context, name string) (int, error) { + counters, err := r.LookupCounters(ctx) + if err != nil { + return 0, err + } + + var found *core.RelationshipFilter + for _, counter := range counters { + if counter.Name == name { + found = counter.Filter + break + } + } + + if found == nil { + return 0, datastore.NewCounterNotRegisteredErr(name) + } + + coreFilter, err := datastore.RelationshipsFilterFromCoreFilter(found) + if err != nil { + return 0, err + } + + iter, err := r.QueryRelationships(ctx, coreFilter) + if err != nil { + return 0, err + } + + count := 0 + for _, err := range iter { + if err != nil { + return 0, err + } + + count++ + } + return count, nil +} + +func (r *memdbReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { + if r.initErr != nil { + return nil, r.initErr + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + var counters []datastore.RelationshipCounter + + it, err := tx.LowerBound(tableCounters, indexID) + if err != nil { + return nil, err + } + + for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { + found := foundRaw.(*counter) + + loaded := &core.RelationshipFilter{} + if err := loaded.UnmarshalVT(found.filterBytes); err != nil { + return nil, err + } + + counters = append(counters, datastore.RelationshipCounter{ + Name: found.name, + Filter: loaded, + Count: found.count, + ComputedAtRevision: found.updated, + }) + } + + return counters, nil +} + +// QueryRelationships reads relationships starting from the resource side. +func (r *memdbReader) QueryRelationships( + _ context.Context, + filter datastore.RelationshipsFilter, + opts ...options.QueryOptionsOption, +) (datastore.RelationshipIterator, error) { + if r.initErr != nil { + return nil, r.initErr + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + queryOpts := options.NewQueryOptionsWithOptions(opts...) + + bestIterator, err := iteratorForFilter(tx, filter) + if err != nil { + return nil, err + } + + if queryOpts.After != nil && queryOpts.Sort == options.Unsorted { + return nil, datastore.ErrCursorsWithoutSorting + } + + matchingRelationshipsFilterFunc := filterFuncForFilters( + filter.OptionalResourceType, + filter.OptionalResourceIds, + filter.OptionalResourceIDPrefix, + filter.OptionalResourceRelation, + filter.OptionalSubjectsSelectors, + filter.OptionalCaveatNameFilter, + filter.OptionalExpirationOption, + makeCursorFilterFn(queryOpts.After, queryOpts.Sort), + ) + filteredIterator := memdb.NewFilterIterator(bestIterator, matchingRelationshipsFilterFunc) + + switch queryOpts.Sort { + case options.Unsorted: + fallthrough + + case options.ByResource: + iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration) + return iter, nil + + case options.BySubject: + return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration) + + default: + return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.Sort) + } +} + +// ReverseQueryRelationships reads relationships starting from the subject. +func (r *memdbReader) ReverseQueryRelationships( + _ context.Context, + subjectsFilter datastore.SubjectsFilter, + opts ...options.ReverseQueryOptionsOption, +) (datastore.RelationshipIterator, error) { + if r.initErr != nil { + return nil, r.initErr + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + queryOpts := options.NewReverseQueryOptionsWithOptions(opts...) + + iterator, err := tx.Get( + tableRelationship, + indexSubjectNamespace, + subjectsFilter.SubjectType, + ) + if err != nil { + return nil, err + } + + filterObjectType, filterRelation := "", "" + if queryOpts.ResRelation != nil { + filterObjectType = queryOpts.ResRelation.Namespace + filterRelation = queryOpts.ResRelation.Relation + } + + matchingRelationshipsFilterFunc := filterFuncForFilters( + filterObjectType, + nil, + "", + filterRelation, + []datastore.SubjectsSelector{subjectsFilter.AsSelector()}, + datastore.CaveatNameFilter{}, + datastore.ExpirationFilterOptionNone, + makeCursorFilterFn(queryOpts.AfterForReverse, queryOpts.SortForReverse), + ) + filteredIterator := memdb.NewFilterIterator(iterator, matchingRelationshipsFilterFunc) + + switch queryOpts.SortForReverse { + case options.Unsorted: + fallthrough + + case options.ByResource: + iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false) + return iter, nil + + case options.BySubject: + return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false) + + default: + return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.SortForReverse) + } +} + +// ReadNamespace reads a namespace definition and version and returns it, and the revision at +// which it was created or last written, if found. +func (r *memdbReader) ReadNamespaceByName(_ context.Context, nsName string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) { + if r.initErr != nil { + return nil, datastore.NoRevision, r.initErr + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, datastore.NoRevision, err + } + + foundRaw, err := tx.First(tableNamespace, indexID, nsName) + if err != nil { + return nil, datastore.NoRevision, err + } + + if foundRaw == nil { + return nil, datastore.NoRevision, datastore.NewNamespaceNotFoundErr(nsName) + } + + found := foundRaw.(*namespace) + + loaded := &core.NamespaceDefinition{} + if err := loaded.UnmarshalVT(found.configBytes); err != nil { + return nil, datastore.NoRevision, err + } + + return loaded, found.updated, nil +} + +// ListNamespaces lists all namespaces defined. +func (r *memdbReader) ListAllNamespaces(_ context.Context) ([]datastore.RevisionedNamespace, error) { + if r.initErr != nil { + return nil, r.initErr + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + var nsDefs []datastore.RevisionedNamespace + + it, err := tx.LowerBound(tableNamespace, indexID) + if err != nil { + return nil, err + } + + for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { + found := foundRaw.(*namespace) + + loaded := &core.NamespaceDefinition{} + if err := loaded.UnmarshalVT(found.configBytes); err != nil { + return nil, err + } + + nsDefs = append(nsDefs, datastore.RevisionedNamespace{ + Definition: loaded, + LastWrittenRevision: found.updated, + }) + } + + return nsDefs, nil +} + +func (r *memdbReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { + if r.initErr != nil { + return nil, r.initErr + } + + if len(nsNames) == 0 { + return nil, nil + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + it, err := tx.LowerBound(tableNamespace, indexID) + if err != nil { + return nil, err + } + + nsNameMap := make(map[string]struct{}, len(nsNames)) + for _, nsName := range nsNames { + nsNameMap[nsName] = struct{}{} + } + + nsDefs := make([]datastore.RevisionedNamespace, 0, len(nsNames)) + + for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { + found := foundRaw.(*namespace) + + loaded := &core.NamespaceDefinition{} + if err := loaded.UnmarshalVT(found.configBytes); err != nil { + return nil, err + } + + if _, ok := nsNameMap[loaded.Name]; ok { + nsDefs = append(nsDefs, datastore.RevisionedNamespace{ + Definition: loaded, + LastWrittenRevision: found.updated, + }) + } + } + + return nsDefs, nil +} + +func (r *memdbReader) mustLock() { + if !r.TryLock() { + panic("detected concurrent use of ReadWriteTransaction") + } +} + +func iteratorForFilter(txn *memdb.Txn, filter datastore.RelationshipsFilter) (memdb.ResultIterator, error) { + // "_prefix" is a specialized index suffix used by github.com/hashicorp/go-memdb to match on + // a prefix of a string. + // See: https://github.com/hashicorp/go-memdb/blob/9940d4a14258e3b887bfb4bc6ebc28f65461a01c/txn.go#L531 + index := indexNamespace + "_prefix" + + var args []any + if filter.OptionalResourceType != "" { + args = append(args, filter.OptionalResourceType) + index = indexNamespace + } else { + args = append(args, "") + } + + if filter.OptionalResourceType != "" && filter.OptionalResourceRelation != "" { + args = append(args, filter.OptionalResourceRelation) + index = indexNamespaceAndRelation + } + + if len(args) == 0 { + return nil, spiceerrors.MustBugf("cannot specify an empty filter") + } + + iter, err := txn.Get(tableRelationship, index, args...) + if err != nil { + return nil, fmt.Errorf("unable to get iterator for filter: %w", err) + } + + return iter, err +} + +func filterFuncForFilters( + optionalResourceType string, + optionalResourceIds []string, + optionalResourceIDPrefix string, + optionalRelation string, + optionalSubjectsSelectors []datastore.SubjectsSelector, + optionalCaveatFilter datastore.CaveatNameFilter, + optionalExpirationFilter datastore.ExpirationFilterOption, + cursorFilter func(*relationship) bool, +) memdb.FilterFunc { + return func(tupleRaw interface{}) bool { + tuple := tupleRaw.(*relationship) + + switch { + case optionalResourceType != "" && optionalResourceType != tuple.namespace: + return true + case len(optionalResourceIds) > 0 && !slices.Contains(optionalResourceIds, tuple.resourceID): + return true + case optionalResourceIDPrefix != "" && !strings.HasPrefix(tuple.resourceID, optionalResourceIDPrefix): + return true + case optionalRelation != "" && optionalRelation != tuple.relation: + return true + case optionalCaveatFilter.Option == datastore.CaveatFilterOptionHasMatchingCaveat && (tuple.caveat == nil || tuple.caveat.caveatName != optionalCaveatFilter.CaveatName): + return true + case optionalCaveatFilter.Option == datastore.CaveatFilterOptionNoCaveat && (tuple.caveat != nil && tuple.caveat.caveatName != ""): + return true + case optionalExpirationFilter == datastore.ExpirationFilterOptionHasExpiration && tuple.expiration == nil: + return true + case optionalExpirationFilter == datastore.ExpirationFilterOptionNoExpiration && tuple.expiration != nil: + return true + } + + applySubjectSelector := func(selector datastore.SubjectsSelector) bool { + switch { + case len(selector.OptionalSubjectType) > 0 && selector.OptionalSubjectType != tuple.subjectNamespace: + return false + case len(selector.OptionalSubjectIds) > 0 && !slices.Contains(selector.OptionalSubjectIds, tuple.subjectObjectID): + return false + } + + if selector.RelationFilter.OnlyNonEllipsisRelations { + return tuple.subjectRelation != datastore.Ellipsis + } + + relations := make([]string, 0, 2) + if selector.RelationFilter.IncludeEllipsisRelation { + relations = append(relations, datastore.Ellipsis) + } + + if selector.RelationFilter.NonEllipsisRelation != "" { + relations = append(relations, selector.RelationFilter.NonEllipsisRelation) + } + + return len(relations) == 0 || slices.Contains(relations, tuple.subjectRelation) + } + + if len(optionalSubjectsSelectors) > 0 { + hasMatchingSelector := false + for _, selector := range optionalSubjectsSelectors { + if applySubjectSelector(selector) { + hasMatchingSelector = true + break + } + } + + if !hasMatchingSelector { + return true + } + } + + return cursorFilter(tuple) + } +} + +func makeCursorFilterFn(after options.Cursor, order options.SortOrder) func(tpl *relationship) bool { + if after != nil { + switch order { + case options.ByResource: + return func(tpl *relationship) bool { + return less(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) || + (eq(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) && + (less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) || + eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject))) + } + case options.BySubject: + return func(tpl *relationship) bool { + return less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) || + (eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) && + (less(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) || + eq(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource))) + } + } + } + return noopCursorFilter +} + +func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) (datastore.RelationshipIterator, error) { + results := make([]tuple.Relationship, 0) + + // Coalesce all of the results into memory + for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { + rt, err := foundRaw.(*relationship).Relationship() + if err != nil { + return nil, err + } + + if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) { + continue + } + + if skipCaveats && rt.OptionalCaveat != nil { + return nil, spiceerrors.MustBugf("unexpected caveat in result for relationship: %v", rt) + } + + if skipExpiration && rt.OptionalExpiration != nil { + return nil, spiceerrors.MustBugf("unexpected expiration in result for relationship: %v", rt) + } + + results = append(results, rt) + } + + // Sort them by subject + sort.Slice(results, func(i, j int) bool { + lhsRes := results[i].Resource + lhsSub := results[i].Subject + rhsRes := results[j].Resource + rhsSub := results[j].Subject + return less(lhsSub.ObjectType, lhsSub.ObjectID, lhsSub.Relation, rhsSub) || + (eq(lhsSub.ObjectType, lhsSub.ObjectID, lhsSub.Relation, rhsSub) && + (less(lhsRes.ObjectType, lhsRes.ObjectID, lhsRes.Relation, rhsRes))) + }) + + // Limit them if requested + if limit != nil && uint64(len(results)) > *limit { + results = results[0:*limit] + } + + return common.NewSliceRelationshipIterator(results), nil +} + +func noopCursorFilter(_ *relationship) bool { + return false +} + +func less(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelation) bool { + return lhsNamespace < rhs.ObjectType || + (lhsNamespace == rhs.ObjectType && lhsObjectID < rhs.ObjectID) || + (lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation < rhs.Relation) +} + +func eq(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelation) bool { + return lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation == rhs.Relation +} + +func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) datastore.RelationshipIterator { + var count uint64 + return func(yield func(tuple.Relationship, error) bool) { + for { + foundRaw := it.Next() + if foundRaw == nil { + return + } + + if limit != nil && count >= *limit { + return + } + + rt, err := foundRaw.(*relationship).Relationship() + if err != nil { + if !yield(tuple.Relationship{}, err) { + return + } + continue + } + + if skipCaveats && rt.OptionalCaveat != nil { + yield(rt, fmt.Errorf("unexpected caveat in result for relationship: %v", rt)) + return + } + + if skipExpiration && rt.OptionalExpiration != nil { + yield(rt, fmt.Errorf("unexpected expiration in result for relationship: %v", rt)) + return + } + + if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) { + continue + } + + if !yield(rt, err) { + return + } + count++ + } + } +} + +var _ datastore.Reader = &memdbReader{} + +type TryLocker interface { + TryLock() bool + Unlock() +} + +type noopTryLocker struct{} + +func (ntl noopTryLocker) TryLock() bool { + return true +} + +func (ntl noopTryLocker) Unlock() {} + +var _ TryLocker = noopTryLocker{} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go new file mode 100644 index 0000000..8929e84 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go @@ -0,0 +1,386 @@ +package memdb + +import ( + "context" + "fmt" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/hashicorp/go-memdb" + "github.com/jzelinskie/stringz" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +type memdbReadWriteTx struct { + memdbReader + newRevision datastore.Revision +} + +func (rwt *memdbReadWriteTx) WriteRelationships(_ context.Context, mutations []tuple.RelationshipUpdate) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + return rwt.write(tx, mutations...) +} + +func (rwt *memdbReadWriteTx) toIntegrity(mutation tuple.RelationshipUpdate) *relationshipIntegrity { + var ri *relationshipIntegrity + if mutation.Relationship.OptionalIntegrity != nil { + ri = &relationshipIntegrity{ + keyID: mutation.Relationship.OptionalIntegrity.KeyId, + hash: mutation.Relationship.OptionalIntegrity.Hash, + timestamp: mutation.Relationship.OptionalIntegrity.HashedAt.AsTime(), + } + } + return ri +} + +// Caller must already hold the concurrent access lock! +func (rwt *memdbReadWriteTx) write(tx *memdb.Txn, mutations ...tuple.RelationshipUpdate) error { + // Apply the mutations + for _, mutation := range mutations { + rel := &relationship{ + mutation.Relationship.Resource.ObjectType, + mutation.Relationship.Resource.ObjectID, + mutation.Relationship.Resource.Relation, + mutation.Relationship.Subject.ObjectType, + mutation.Relationship.Subject.ObjectID, + mutation.Relationship.Subject.Relation, + rwt.toCaveatReference(mutation), + rwt.toIntegrity(mutation), + mutation.Relationship.OptionalExpiration, + } + + found, err := tx.First( + tableRelationship, + indexID, + rel.namespace, + rel.resourceID, + rel.relation, + rel.subjectNamespace, + rel.subjectObjectID, + rel.subjectRelation, + ) + if err != nil { + return fmt.Errorf("error loading existing relationship: %w", err) + } + + var existing *relationship + if found != nil { + existing = found.(*relationship) + } + + switch mutation.Operation { + case tuple.UpdateOperationCreate: + if existing != nil { + rt, err := existing.Relationship() + if err != nil { + return err + } + return common.NewCreateRelationshipExistsError(&rt) + } + if err := tx.Insert(tableRelationship, rel); err != nil { + return fmt.Errorf("error inserting relationship: %w", err) + } + + case tuple.UpdateOperationTouch: + if existing != nil { + rt, err := existing.Relationship() + if err != nil { + return err + } + if tuple.MustString(rt) == tuple.MustString(mutation.Relationship) { + continue + } + } + + if err := tx.Insert(tableRelationship, rel); err != nil { + return fmt.Errorf("error inserting relationship: %w", err) + } + + case tuple.UpdateOperationDelete: + if existing != nil { + if err := tx.Delete(tableRelationship, existing); err != nil { + return fmt.Errorf("error deleting relationship: %w", err) + } + } + default: + return spiceerrors.MustBugf("unknown tuple mutation operation type: %v", mutation.Operation) + } + } + + return nil +} + +func (rwt *memdbReadWriteTx) toCaveatReference(mutation tuple.RelationshipUpdate) *contextualizedCaveat { + var cr *contextualizedCaveat + if mutation.Relationship.OptionalCaveat != nil { + cr = &contextualizedCaveat{ + caveatName: mutation.Relationship.OptionalCaveat.CaveatName, + context: mutation.Relationship.OptionalCaveat.Context.AsMap(), + } + } + return cr +} + +func (rwt *memdbReadWriteTx) DeleteRelationships(_ context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return 0, false, err + } + + delOpts := options.NewDeleteOptionsWithOptionsAndDefaults(opts...) + var delLimit uint64 + if delOpts.DeleteLimit != nil && *delOpts.DeleteLimit > 0 { + delLimit = *delOpts.DeleteLimit + } + + return rwt.deleteWithLock(tx, filter, delLimit) +} + +// caller must already hold the concurrent access lock +func (rwt *memdbReadWriteTx) deleteWithLock(tx *memdb.Txn, filter *v1.RelationshipFilter, limit uint64) (uint64, bool, error) { + // Create an iterator to find the relevant tuples + dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter) + if err != nil { + return 0, false, err + } + + bestIter, err := iteratorForFilter(tx, dsFilter) + if err != nil { + return 0, false, err + } + filteredIter := memdb.NewFilterIterator(bestIter, relationshipFilterFilterFunc(filter)) + + // Collect the tuples into a slice of mutations for the changelog + var mutations []tuple.RelationshipUpdate + var counter uint64 + + metLimit := false + for row := filteredIter.Next(); row != nil; row = filteredIter.Next() { + rt, err := row.(*relationship).Relationship() + if err != nil { + return 0, false, err + } + mutations = append(mutations, tuple.Delete(rt)) + counter++ + + if limit > 0 && counter == limit { + metLimit = true + break + } + } + + return counter, metLimit, rwt.write(tx, mutations...) +} + +func (rwt *memdbReadWriteTx) RegisterCounter(ctx context.Context, name string, filter *core.RelationshipFilter) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + foundRaw, err := tx.First(tableCounters, indexID, name) + if err != nil { + return err + } + + if foundRaw != nil { + return datastore.NewCounterAlreadyRegisteredErr(name, filter) + } + + filterBytes, err := filter.MarshalVT() + if err != nil { + return err + } + + // Insert the counter + counter := &counter{ + name, + filterBytes, + 0, + datastore.NoRevision, + } + + return tx.Insert(tableCounters, counter) +} + +func (rwt *memdbReadWriteTx) UnregisterCounter(ctx context.Context, name string) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + // Check if the counter exists + foundRaw, err := tx.First(tableCounters, indexID, name) + if err != nil { + return err + } + + if foundRaw == nil { + return datastore.NewCounterNotRegisteredErr(name) + } + + return tx.Delete(tableCounters, foundRaw) +} + +func (rwt *memdbReadWriteTx) StoreCounterValue(ctx context.Context, name string, value int, computedAtRevision datastore.Revision) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + // Check if the counter exists + foundRaw, err := tx.First(tableCounters, indexID, name) + if err != nil { + return err + } + + if foundRaw == nil { + return datastore.NewCounterNotRegisteredErr(name) + } + + counter := foundRaw.(*counter) + counter.count = value + counter.updated = computedAtRevision + + return tx.Insert(tableCounters, counter) +} + +func (rwt *memdbReadWriteTx) WriteNamespaces(_ context.Context, newConfigs ...*core.NamespaceDefinition) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + for _, newConfig := range newConfigs { + serialized, err := newConfig.MarshalVT() + if err != nil { + return err + } + + newConfigEntry := &namespace{newConfig.Name, serialized, rwt.newRevision} + + err = tx.Insert(tableNamespace, newConfigEntry) + if err != nil { + return err + } + } + + return nil +} + +func (rwt *memdbReadWriteTx) DeleteNamespaces(_ context.Context, nsNames ...string) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + for _, nsName := range nsNames { + foundRaw, err := tx.First(tableNamespace, indexID, nsName) + if err != nil { + return err + } + + if foundRaw == nil { + return fmt.Errorf("namespace not found") + } + + if err := tx.Delete(tableNamespace, foundRaw); err != nil { + return err + } + + // Delete the relationships from the namespace + if _, _, err := rwt.deleteWithLock(tx, &v1.RelationshipFilter{ + ResourceType: nsName, + }, 0); err != nil { + return fmt.Errorf("unable to delete relationships from deleted namespace: %w", err) + } + } + + return nil +} + +func (rwt *memdbReadWriteTx) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) { + var numCopied uint64 + var next *tuple.Relationship + var err error + + updates := []tuple.RelationshipUpdate{{ + Operation: tuple.UpdateOperationCreate, + }} + + for next, err = iter.Next(ctx); next != nil && err == nil; next, err = iter.Next(ctx) { + updates[0].Relationship = *next + if err := rwt.WriteRelationships(ctx, updates); err != nil { + return 0, err + } + numCopied++ + } + + return numCopied, err +} + +func relationshipFilterFilterFunc(filter *v1.RelationshipFilter) func(interface{}) bool { + return func(tupleRaw interface{}) bool { + tuple := tupleRaw.(*relationship) + + // If it doesn't match one of the resource filters, filter it. + switch { + case filter.ResourceType != "" && filter.ResourceType != tuple.namespace: + return true + case filter.OptionalResourceId != "" && filter.OptionalResourceId != tuple.resourceID: + return true + case filter.OptionalResourceIdPrefix != "" && !strings.HasPrefix(tuple.resourceID, filter.OptionalResourceIdPrefix): + return true + case filter.OptionalRelation != "" && filter.OptionalRelation != tuple.relation: + return true + } + + // If it doesn't match one of the subject filters, filter it. + if subjectFilter := filter.OptionalSubjectFilter; subjectFilter != nil { + switch { + case subjectFilter.SubjectType != tuple.subjectNamespace: + return true + case subjectFilter.OptionalSubjectId != "" && subjectFilter.OptionalSubjectId != tuple.subjectObjectID: + return true + case subjectFilter.OptionalRelation != nil && + stringz.DefaultEmpty(subjectFilter.OptionalRelation.Relation, datastore.Ellipsis) != tuple.subjectRelation: + return true + } + } + + return false + } +} + +var _ datastore.ReadWriteTransaction = &memdbReadWriteTx{} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go new file mode 100644 index 0000000..be79771 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go @@ -0,0 +1,118 @@ +package memdb + +import ( + "context" + "time" + + "github.com/authzed/spicedb/internal/datastore/revisions" + "github.com/authzed/spicedb/pkg/datastore" +) + +var ParseRevisionString = revisions.RevisionParser(revisions.Timestamp) + +func nowRevision() revisions.TimestampRevision { + return revisions.NewForTime(time.Now().UTC()) +} + +func (mdb *memdbDatastore) newRevisionID() revisions.TimestampRevision { + mdb.Lock() + defer mdb.Unlock() + + existing := mdb.revisions[len(mdb.revisions)-1].revision + created := nowRevision() + + // NOTE: The time.Now().UTC() only appears to have *microsecond* level + // precision on macOS Monterey in Go 1.19.1. This means that HeadRevision + // and the result of a ReadWriteTx could return the *same* transaction ID + // if both are executed in sequence without any other forms of delay on + // macOS. We therefore check if the created transaction ID matches that + // previously created and, if not, add to it. + // + // See: https://github.com/golang/go/issues/22037 which appeared to fix + // this in Go 1.9.2, but there appears to have been a reversion with either + // the new version of macOS or Go. + if created.Equal(existing) { + return revisions.NewForTimestamp(created.TimestampNanoSec() + 1) + } + + return created +} + +func (mdb *memdbDatastore) HeadRevision(_ context.Context) (datastore.Revision, error) { + mdb.RLock() + defer mdb.RUnlock() + if err := mdb.checkNotClosed(); err != nil { + return nil, err + } + + return mdb.headRevisionNoLock(), nil +} + +func (mdb *memdbDatastore) SquashRevisionsForTesting() { + mdb.revisions = []snapshot{ + { + revision: nowRevision(), + db: mdb.db, + }, + } +} + +func (mdb *memdbDatastore) headRevisionNoLock() revisions.TimestampRevision { + return mdb.revisions[len(mdb.revisions)-1].revision +} + +func (mdb *memdbDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, error) { + mdb.RLock() + defer mdb.RUnlock() + if err := mdb.checkNotClosed(); err != nil { + return nil, err + } + + now := nowRevision() + return revisions.NewForTimestamp(now.TimestampNanoSec() - now.TimestampNanoSec()%mdb.quantizationPeriod), nil +} + +func (mdb *memdbDatastore) CheckRevision(_ context.Context, dr datastore.Revision) error { + mdb.RLock() + defer mdb.RUnlock() + if err := mdb.checkNotClosed(); err != nil { + return err + } + + return mdb.checkRevisionLocalCallerMustLock(dr) +} + +func (mdb *memdbDatastore) checkRevisionLocalCallerMustLock(dr datastore.Revision) error { + now := nowRevision() + + // Ensure the revision has not fallen outside of the GC window. If it has, it is considered + // invalid. + if mdb.revisionOutsideGCWindow(now, dr) { + return datastore.NewInvalidRevisionErr(dr, datastore.RevisionStale) + } + + // If the revision <= now and later than the GC window, it is assumed to be valid, even if + // HEAD revision is behind it. + if dr.GreaterThan(now) { + // If the revision is in the "future", then check to ensure that it is <= of HEAD to handle + // the microsecond granularity on macos (see comment above in newRevisionID) + headRevision := mdb.headRevisionNoLock() + if dr.LessThan(headRevision) || dr.Equal(headRevision) { + return nil + } + + return datastore.NewInvalidRevisionErr(dr, datastore.CouldNotDetermineRevision) + } + + return nil +} + +func (mdb *memdbDatastore) revisionOutsideGCWindow(now revisions.TimestampRevision, revisionRaw datastore.Revision) bool { + // make an exception for head revision - it will be acceptable even if outside GC Window + if revisionRaw.Equal(mdb.headRevisionNoLock()) { + return false + } + + oldest := revisions.NewForTimestamp(now.TimestampNanoSec() + mdb.negativeGCWindow) + return revisionRaw.LessThan(oldest) +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go new file mode 100644 index 0000000..7905d48 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go @@ -0,0 +1,232 @@ +package memdb + +import ( + "time" + + "github.com/hashicorp/go-memdb" + "github.com/rs/zerolog" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +const ( + tableNamespace = "namespace" + + tableRelationship = "relationship" + indexID = "id" + indexNamespace = "namespace" + indexNamespaceAndRelation = "namespaceAndRelation" + indexSubjectNamespace = "subjectNamespace" + + tableCounters = "counters" + + tableChangelog = "changelog" + indexRevision = "id" +) + +type namespace struct { + name string + configBytes []byte + updated datastore.Revision +} + +func (ns namespace) MarshalZerologObject(e *zerolog.Event) { + e.Stringer("rev", ns.updated).Str("name", ns.name) +} + +type counter struct { + name string + filterBytes []byte + count int + updated datastore.Revision +} + +type relationship struct { + namespace string + resourceID string + relation string + subjectNamespace string + subjectObjectID string + subjectRelation string + caveat *contextualizedCaveat + integrity *relationshipIntegrity + expiration *time.Time +} + +type relationshipIntegrity struct { + keyID string + hash []byte + timestamp time.Time +} + +func (ri relationshipIntegrity) MarshalZerologObject(e *zerolog.Event) { + e.Str("keyID", ri.keyID).Bytes("hash", ri.hash).Time("timestamp", ri.timestamp) +} + +func (ri relationshipIntegrity) RelationshipIntegrity() *core.RelationshipIntegrity { + return &core.RelationshipIntegrity{ + KeyId: ri.keyID, + Hash: ri.hash, + HashedAt: timestamppb.New(ri.timestamp), + } +} + +type contextualizedCaveat struct { + caveatName string + context map[string]any +} + +func (cr *contextualizedCaveat) ContextualizedCaveat() (*core.ContextualizedCaveat, error) { + if cr == nil { + return nil, nil + } + v, err := structpb.NewStruct(cr.context) + if err != nil { + return nil, err + } + return &core.ContextualizedCaveat{ + CaveatName: cr.caveatName, + Context: v, + }, nil +} + +func (r relationship) String() string { + caveat := "" + if r.caveat != nil { + caveat = "[" + r.caveat.caveatName + "]" + } + + expiration := "" + if r.expiration != nil { + expiration = "[expiration:" + r.expiration.Format(time.RFC3339Nano) + "]" + } + + return r.namespace + ":" + r.resourceID + "#" + r.relation + "@" + r.subjectNamespace + ":" + r.subjectObjectID + "#" + r.subjectRelation + caveat + expiration +} + +func (r relationship) MarshalZerologObject(e *zerolog.Event) { + e.Str("rel", r.String()) +} + +func (r relationship) Relationship() (tuple.Relationship, error) { + cr, err := r.caveat.ContextualizedCaveat() + if err != nil { + return tuple.Relationship{}, err + } + + var ig *core.RelationshipIntegrity + if r.integrity != nil { + ig = r.integrity.RelationshipIntegrity() + } + + return tuple.Relationship{ + RelationshipReference: tuple.RelationshipReference{ + Resource: tuple.ObjectAndRelation{ + ObjectType: r.namespace, + ObjectID: r.resourceID, + Relation: r.relation, + }, + Subject: tuple.ObjectAndRelation{ + ObjectType: r.subjectNamespace, + ObjectID: r.subjectObjectID, + Relation: r.subjectRelation, + }, + }, + OptionalCaveat: cr, + OptionalIntegrity: ig, + OptionalExpiration: r.expiration, + }, nil +} + +type changelog struct { + revisionNanos int64 + changes datastore.RevisionChanges +} + +var schema = &memdb.DBSchema{ + Tables: map[string]*memdb.TableSchema{ + tableNamespace: { + Name: tableNamespace, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "name"}, + }, + }, + }, + tableChangelog: { + Name: tableChangelog, + Indexes: map[string]*memdb.IndexSchema{ + indexRevision: { + Name: indexRevision, + Unique: true, + Indexer: &memdb.IntFieldIndex{Field: "revisionNanos"}, + }, + }, + }, + tableRelationship: { + Name: tableRelationship, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{Field: "namespace"}, + &memdb.StringFieldIndex{Field: "resourceID"}, + &memdb.StringFieldIndex{Field: "relation"}, + &memdb.StringFieldIndex{Field: "subjectNamespace"}, + &memdb.StringFieldIndex{Field: "subjectObjectID"}, + &memdb.StringFieldIndex{Field: "subjectRelation"}, + }, + }, + }, + indexNamespace: { + Name: indexNamespace, + Unique: false, + Indexer: &memdb.StringFieldIndex{Field: "namespace"}, + }, + indexNamespaceAndRelation: { + Name: indexNamespaceAndRelation, + Unique: false, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{Field: "namespace"}, + &memdb.StringFieldIndex{Field: "relation"}, + }, + }, + }, + indexSubjectNamespace: { + Name: indexSubjectNamespace, + Unique: false, + Indexer: &memdb.StringFieldIndex{Field: "subjectNamespace"}, + }, + }, + }, + tableCaveats: { + Name: tableCaveats, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "name"}, + }, + }, + }, + tableCounters: { + Name: tableCounters, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "name"}, + }, + }, + }, + }, +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go new file mode 100644 index 0000000..33665a1 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go @@ -0,0 +1,51 @@ +package memdb + +import ( + "context" + "fmt" + + "github.com/authzed/spicedb/pkg/datastore" +) + +func (mdb *memdbDatastore) Statistics(ctx context.Context) (datastore.Stats, error) { + head, err := mdb.HeadRevision(ctx) + if err != nil { + return datastore.Stats{}, fmt.Errorf("unable to compute head revision: %w", err) + } + + count, err := mdb.countRelationships(ctx) + if err != nil { + return datastore.Stats{}, fmt.Errorf("unable to count relationships: %w", err) + } + + objTypes, err := mdb.SnapshotReader(head).ListAllNamespaces(ctx) + if err != nil { + return datastore.Stats{}, fmt.Errorf("unable to list object types: %w", err) + } + + return datastore.Stats{ + UniqueID: mdb.uniqueID, + EstimatedRelationshipCount: count, + ObjectTypeStatistics: datastore.ComputeObjectTypeStats(objTypes), + }, nil +} + +func (mdb *memdbDatastore) countRelationships(_ context.Context) (uint64, error) { + mdb.RLock() + defer mdb.RUnlock() + + txn := mdb.db.Txn(false) + defer txn.Abort() + + it, err := txn.LowerBound(tableRelationship, indexID) + if err != nil { + return 0, err + } + + var count uint64 + for row := it.Next(); row != nil; row = it.Next() { + count++ + } + + return count, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go new file mode 100644 index 0000000..eaa4812 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go @@ -0,0 +1,148 @@ +package memdb + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/hashicorp/go-memdb" + + "github.com/authzed/spicedb/internal/datastore/revisions" + "github.com/authzed/spicedb/pkg/datastore" +) + +const errWatchError = "watch error: %w" + +func (mdb *memdbDatastore) Watch(ctx context.Context, ar datastore.Revision, options datastore.WatchOptions) (<-chan datastore.RevisionChanges, <-chan error) { + watchBufferLength := options.WatchBufferLength + if watchBufferLength == 0 { + watchBufferLength = mdb.watchBufferLength + } + + updates := make(chan datastore.RevisionChanges, watchBufferLength) + errs := make(chan error, 1) + + if options.EmissionStrategy == datastore.EmitImmediatelyStrategy { + close(updates) + errs <- errors.New("emit immediately strategy is unsupported in MemDB") + return updates, errs + } + + watchBufferWriteTimeout := options.WatchBufferWriteTimeout + if watchBufferWriteTimeout == 0 { + watchBufferWriteTimeout = mdb.watchBufferWriteTimeout + } + + sendChange := func(change datastore.RevisionChanges) bool { + select { + case updates <- change: + return true + + default: + // If we cannot immediately write, setup the timer and try again. + } + + timer := time.NewTimer(watchBufferWriteTimeout) + defer timer.Stop() + + select { + case updates <- change: + return true + + case <-timer.C: + errs <- datastore.NewWatchDisconnectedErr() + return false + } + } + + go func() { + defer close(updates) + defer close(errs) + + currentTxn := ar.(revisions.TimestampRevision).TimestampNanoSec() + + for { + var stagedUpdates []datastore.RevisionChanges + var watchChan <-chan struct{} + var err error + stagedUpdates, currentTxn, watchChan, err = mdb.loadChanges(ctx, currentTxn, options) + if err != nil { + errs <- err + return + } + + // Write the staged updates to the channel + for _, changeToWrite := range stagedUpdates { + if !sendChange(changeToWrite) { + return + } + } + + // Wait for new changes + ws := memdb.NewWatchSet() + ws.Add(watchChan) + + err = ws.WatchCtx(ctx) + if err != nil { + switch { + case errors.Is(err, context.Canceled): + errs <- datastore.NewWatchCanceledErr() + default: + errs <- fmt.Errorf(errWatchError, err) + } + return + } + } + }() + + return updates, errs +} + +func (mdb *memdbDatastore) loadChanges(_ context.Context, currentTxn int64, options datastore.WatchOptions) ([]datastore.RevisionChanges, int64, <-chan struct{}, error) { + mdb.RLock() + defer mdb.RUnlock() + + if err := mdb.checkNotClosed(); err != nil { + return nil, 0, nil, err + } + + loadNewTxn := mdb.db.Txn(false) + defer loadNewTxn.Abort() + + it, err := loadNewTxn.LowerBound(tableChangelog, indexRevision, currentTxn+1) + if err != nil { + return nil, 0, nil, fmt.Errorf(errWatchError, err) + } + + var changes []datastore.RevisionChanges + lastRevision := currentTxn + for changeRaw := it.Next(); changeRaw != nil; changeRaw = it.Next() { + change := changeRaw.(*changelog) + + if options.Content&datastore.WatchRelationships == datastore.WatchRelationships && len(change.changes.RelationshipChanges) > 0 { + changes = append(changes, change.changes) + } + + if options.Content&datastore.WatchSchema == datastore.WatchSchema && + len(change.changes.ChangedDefinitions) > 0 || len(change.changes.DeletedCaveats) > 0 || len(change.changes.DeletedNamespaces) > 0 { + changes = append(changes, change.changes) + } + + if options.Content&datastore.WatchCheckpoints == datastore.WatchCheckpoints && change.revisionNanos > lastRevision { + changes = append(changes, datastore.RevisionChanges{ + Revision: revisions.NewForTimestamp(change.revisionNanos), + IsCheckpoint: true, + }) + } + + lastRevision = change.revisionNanos + } + + watchChan, _, err := loadNewTxn.LastWatch(tableChangelog, indexRevision) + if err != nil { + return nil, 0, nil, fmt.Errorf(errWatchError, err) + } + + return changes, lastRevision, watchChan, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go new file mode 100644 index 0000000..7092728 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go @@ -0,0 +1,79 @@ +package revisions + +import ( + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// RevisionKind is an enum of the different kinds of revisions that can be used. +type RevisionKind string + +const ( + // Timestamp is a revision that is a timestamp. + Timestamp RevisionKind = "timestamp" + + // TransactionID is a revision that is a transaction ID. + TransactionID = "txid" + + // HybridLogicalClock is a revision that is a hybrid logical clock. + HybridLogicalClock = "hlc" +) + +// ParsingFunc is a function that can parse a string into a revision. +type ParsingFunc func(revisionStr string) (rev datastore.Revision, err error) + +// RevisionParser returns a ParsingFunc for the given RevisionKind. +func RevisionParser(kind RevisionKind) ParsingFunc { + switch kind { + case TransactionID: + return parseTransactionIDRevisionString + + case Timestamp: + return parseTimestampRevisionString + + case HybridLogicalClock: + return parseHLCRevisionString + + default: + return func(revisionStr string) (rev datastore.Revision, err error) { + return nil, spiceerrors.MustBugf("unknown revision kind: %v", kind) + } + } +} + +// CommonDecoder is a revision decoder that can decode revisions of a given kind. +type CommonDecoder struct { + Kind RevisionKind +} + +func (cd CommonDecoder) RevisionFromString(s string) (datastore.Revision, error) { + switch cd.Kind { + case TransactionID: + return parseTransactionIDRevisionString(s) + + case Timestamp: + return parseTimestampRevisionString(s) + + case HybridLogicalClock: + return parseHLCRevisionString(s) + + default: + return nil, spiceerrors.MustBugf("unknown revision kind in decoder: %v", cd.Kind) + } +} + +// WithInexactFloat64 is an interface that can be implemented by a revision to +// provide an inexact float64 representation of the revision. +type WithInexactFloat64 interface { + // InexactFloat64 returns a float64 that is an inexact representation of the + // revision. + InexactFloat64() float64 +} + +// WithTimestampRevision is an interface that can be implemented by a revision to +// provide a timestamp. +type WithTimestampRevision interface { + datastore.Revision + TimestampNanoSec() int64 + ConstructForTimestamp(timestampNanoSec int64) WithTimestampRevision +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go new file mode 100644 index 0000000..e4f7fc6 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go @@ -0,0 +1,166 @@ +package revisions + +import ( + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/ccoveille/go-safecast" + "github.com/shopspring/decimal" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +var zeroHLC = HLCRevision{} + +// NOTE: This *must* match the length defined in CRDB or the implementation below will break. +const logicalClockLength = 10 + +var logicalClockOffset = uint32(math.Pow10(logicalClockLength + 1)) + +// HLCRevision is a revision that is a hybrid logical clock, stored as two integers. +// The first integer is the timestamp in nanoseconds, and the second integer is the +// logical clock defined as 11 digits, with the first digit being ignored to ensure +// precision of the given logical clock. +type HLCRevision struct { + time int64 + logicalclock uint32 +} + +// parseHLCRevisionString parses a string into a hybrid logical clock revision. +func parseHLCRevisionString(revisionStr string) (datastore.Revision, error) { + pieces := strings.Split(revisionStr, ".") + if len(pieces) == 1 { + // If there is no decimal point, assume the revision is a timestamp. + timestamp, err := strconv.ParseInt(pieces[0], 10, 64) + if err != nil { + return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr) + } + return HLCRevision{timestamp, logicalClockOffset}, nil + } + + if len(pieces) != 2 { + return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr) + } + + timestamp, err := strconv.ParseInt(pieces[0], 10, 64) + if err != nil { + return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr) + } + + if len(pieces[1]) > logicalClockLength { + return datastore.NoRevision, spiceerrors.MustBugf("invalid revision string due to unexpected logical clock size (%d): %q", len(pieces[1]), revisionStr) + } + + paddedLogicalClockStr := pieces[1] + strings.Repeat("0", logicalClockLength-len(pieces[1])) + logicalclock, err := strconv.ParseUint(paddedLogicalClockStr, 10, 64) + if err != nil { + return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr) + } + + if logicalclock > math.MaxUint32 { + return datastore.NoRevision, spiceerrors.MustBugf("received logical lock that exceeds MaxUint32 (%d > %d): revision %q", logicalclock, math.MaxUint32, revisionStr) + } + + uintLogicalClock, err := safecast.ToUint32(logicalclock) + if err != nil { + return datastore.NoRevision, spiceerrors.MustBugf("could not cast logicalclock to uint32: %v", err) + } + + return HLCRevision{timestamp, uintLogicalClock + logicalClockOffset}, nil +} + +// HLCRevisionFromString parses a string into a hybrid logical clock revision. +func HLCRevisionFromString(revisionStr string) (HLCRevision, error) { + rev, err := parseHLCRevisionString(revisionStr) + if err != nil { + return zeroHLC, err + } + + return rev.(HLCRevision), nil +} + +// NewForHLC creates a new revision for the given hybrid logical clock. +func NewForHLC(decimal decimal.Decimal) (HLCRevision, error) { + rev, err := HLCRevisionFromString(decimal.String()) + if err != nil { + return zeroHLC, fmt.Errorf("invalid HLC decimal: %v (%s) => %w", decimal, decimal.String(), err) + } + + return rev, nil +} + +// NewHLCForTime creates a new revision for the given time. +func NewHLCForTime(time time.Time) HLCRevision { + return HLCRevision{time.UnixNano(), logicalClockOffset} +} + +func (hlc HLCRevision) ByteSortable() bool { + return true +} + +func (hlc HLCRevision) Equal(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroHLC + } + + rhsHLC := rhs.(HLCRevision) + return hlc.time == rhsHLC.time && hlc.logicalclock == rhsHLC.logicalclock +} + +func (hlc HLCRevision) GreaterThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroHLC + } + + rhsHLC := rhs.(HLCRevision) + return hlc.time > rhsHLC.time || (hlc.time == rhsHLC.time && hlc.logicalclock > rhsHLC.logicalclock) +} + +func (hlc HLCRevision) LessThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroHLC + } + + rhsHLC := rhs.(HLCRevision) + return hlc.time < rhsHLC.time || (hlc.time == rhsHLC.time && hlc.logicalclock < rhsHLC.logicalclock) +} + +func (hlc HLCRevision) String() string { + logicalClockString := strconv.FormatInt(int64(hlc.logicalclock)-int64(logicalClockOffset), 10) + return strconv.FormatInt(hlc.time, 10) + "." + strings.Repeat("0", logicalClockLength-len(logicalClockString)) + logicalClockString +} + +func (hlc HLCRevision) TimestampNanoSec() int64 { + return hlc.time +} + +func (hlc HLCRevision) InexactFloat64() float64 { + return float64(hlc.time) + float64(hlc.logicalclock-logicalClockOffset)/math.Pow10(logicalClockLength) +} + +func (hlc HLCRevision) ConstructForTimestamp(timestamp int64) WithTimestampRevision { + return HLCRevision{timestamp, logicalClockOffset} +} + +func (hlc HLCRevision) AsDecimal() (decimal.Decimal, error) { + return decimal.NewFromString(hlc.String()) +} + +var ( + _ datastore.Revision = HLCRevision{} + _ WithTimestampRevision = HLCRevision{} +) + +// HLCKeyFunc is used to convert a simple HLC for use in maps. +func HLCKeyFunc(r HLCRevision) HLCRevision { + return r +} + +// HLCKeyLessThanFunc is used to compare keys created by the HLCKeyFunc. +func HLCKeyLessThanFunc(lhs, rhs HLCRevision) bool { + return lhs.time < rhs.time || (lhs.time == rhs.time && lhs.logicalclock < rhs.logicalclock) +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go new file mode 100644 index 0000000..3a5a919 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go @@ -0,0 +1,118 @@ +package revisions + +import ( + "context" + "fmt" + "math/rand" + "sync" + "time" + + "github.com/benbjohnson/clock" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/singleflight" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" +) + +var tracer = otel.Tracer("spicedb/internal/datastore/common/revisions") + +// OptimizedRevisionFunction instructs the datastore to compute its own current +// optimized revision given the specific quantization, and return for how long +// it will remain valid. +type OptimizedRevisionFunction func(context.Context) (rev datastore.Revision, validFor time.Duration, err error) + +// NewCachedOptimizedRevisions returns a CachedOptimizedRevisions for the given configuration +func NewCachedOptimizedRevisions(maxRevisionStaleness time.Duration) *CachedOptimizedRevisions { + return &CachedOptimizedRevisions{ + maxRevisionStaleness: maxRevisionStaleness, + clockFn: clock.New(), + } +} + +// SetOptimizedRevisionFunc must be called after construction, and is the method +// by which one specializes this helper for a specific datastore. +func (cor *CachedOptimizedRevisions) SetOptimizedRevisionFunc(revisionFunc OptimizedRevisionFunction) { + cor.optimizedFunc = revisionFunc +} + +func (cor *CachedOptimizedRevisions) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { + span := trace.SpanFromContext(ctx) + localNow := cor.clockFn.Now() + + // Subtract a random amount of time from now, to let barely expired candidates get selected + adjustedNow := localNow + if cor.maxRevisionStaleness > 0 { + // nolint:gosec + // G404 use of non cryptographically secure random number generator is not a security concern here, + // as we are using it to introduce randomness to the accepted staleness of a revision and reduce the odds of + // a thundering herd to the datastore + adjustedNow = localNow.Add(-1 * time.Duration(rand.Int63n(cor.maxRevisionStaleness.Nanoseconds())) * time.Nanosecond) + } + + cor.RLock() + for _, candidate := range cor.candidates { + if candidate.validThrough.After(adjustedNow) { + cor.RUnlock() + log.Ctx(ctx).Debug().Time("now", localNow).Time("valid", candidate.validThrough).Msg("returning cached revision") + span.AddEvent("returning cached revision") + return candidate.revision, nil + } + } + cor.RUnlock() + + newQuantizedRevision, err, _ := cor.updateGroup.Do("", func() (interface{}, error) { + log.Ctx(ctx).Debug().Time("now", localNow).Msg("computing new revision") + span.AddEvent("computing new revision") + + optimized, validFor, err := cor.optimizedFunc(ctx) + if err != nil { + return nil, fmt.Errorf("unable to compute optimized revision: %w", err) + } + + rvt := localNow.Add(validFor) + + // Prune the candidates that have definitely expired + cor.Lock() + var numToDrop uint + for _, candidate := range cor.candidates { + if candidate.validThrough.Add(cor.maxRevisionStaleness).Before(localNow) { + numToDrop++ + } else { + break + } + } + + cor.candidates = cor.candidates[numToDrop:] + cor.candidates = append(cor.candidates, validRevision{optimized, rvt}) + cor.Unlock() + + log.Ctx(ctx).Debug().Time("now", localNow).Time("valid", rvt).Stringer("validFor", validFor).Msg("setting valid through") + return optimized, nil + }) + if err != nil { + return datastore.NoRevision, err + } + return newQuantizedRevision.(datastore.Revision), err +} + +// CachedOptimizedRevisions does caching and deduplication for requests for optimized revisions. +type CachedOptimizedRevisions struct { + sync.RWMutex + + maxRevisionStaleness time.Duration + optimizedFunc OptimizedRevisionFunction + clockFn clock.Clock + + // these values are read and set by multiple consumers + candidates []validRevision // GUARDED_BY(RWMutex) + + // the updategroup consolidates concurrent requests to the database into 1 + updateGroup singleflight.Group +} + +type validRevision struct { + revision datastore.Revision + validThrough time.Time +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go new file mode 100644 index 0000000..ef793c8 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go @@ -0,0 +1,125 @@ +package revisions + +import ( + "context" + "time" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// RemoteNowFunction queries the datastore to get a current revision. +type RemoteNowFunction func(context.Context) (datastore.Revision, error) + +// RemoteClockRevisions handles revision calculation for datastores that provide +// their own clocks. +type RemoteClockRevisions struct { + *CachedOptimizedRevisions + + gcWindowNanos int64 + nowFunc RemoteNowFunction + followerReadDelayNanos int64 + quantizationNanos int64 +} + +// NewRemoteClockRevisions returns a RemoteClockRevisions for the given configuration +func NewRemoteClockRevisions(gcWindow, maxRevisionStaleness, followerReadDelay, quantization time.Duration) *RemoteClockRevisions { + // Ensure the max revision staleness never exceeds the GC window. + if maxRevisionStaleness > gcWindow { + log.Warn(). + Dur("maxRevisionStaleness", maxRevisionStaleness). + Dur("gcWindow", gcWindow). + Msg("the configured maximum revision staleness exceeds the configured gc window, so capping to gcWindow") + maxRevisionStaleness = gcWindow - 1 + } + + revisions := &RemoteClockRevisions{ + CachedOptimizedRevisions: NewCachedOptimizedRevisions( + maxRevisionStaleness, + ), + gcWindowNanos: gcWindow.Nanoseconds(), + followerReadDelayNanos: followerReadDelay.Nanoseconds(), + quantizationNanos: quantization.Nanoseconds(), + } + + revisions.SetOptimizedRevisionFunc(revisions.optimizedRevisionFunc) + + return revisions +} + +func (rcr *RemoteClockRevisions) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, error) { + nowRev, err := rcr.nowFunc(ctx) + if err != nil { + return datastore.NoRevision, 0, err + } + + if nowRev == datastore.NoRevision { + return datastore.NoRevision, 0, datastore.NewInvalidRevisionErr(nowRev, datastore.CouldNotDetermineRevision) + } + + nowTS, ok := nowRev.(WithTimestampRevision) + if !ok { + return datastore.NoRevision, 0, spiceerrors.MustBugf("expected with-timestamp revision, got %T", nowRev) + } + + delayedNow := nowTS.TimestampNanoSec() - rcr.followerReadDelayNanos + quantized := delayedNow + validForNanos := int64(0) + if rcr.quantizationNanos > 0 { + afterLastQuantization := delayedNow % rcr.quantizationNanos + quantized -= afterLastQuantization + validForNanos = rcr.quantizationNanos - afterLastQuantization + } + log.Ctx(ctx).Debug(). + Time("quantized", time.Unix(0, quantized)). + Int64("readSkew", rcr.followerReadDelayNanos). + Int64("totalSkew", nowTS.TimestampNanoSec()-quantized). + Msg("revision skews") + + return nowTS.ConstructForTimestamp(quantized), time.Duration(validForNanos) * time.Nanosecond, nil +} + +// SetNowFunc sets the function used to determine the head revision +func (rcr *RemoteClockRevisions) SetNowFunc(nowFunc RemoteNowFunction) { + rcr.nowFunc = nowFunc +} + +func (rcr *RemoteClockRevisions) CheckRevision(ctx context.Context, dsRevision datastore.Revision) error { + if dsRevision == datastore.NoRevision { + return datastore.NewInvalidRevisionErr(dsRevision, datastore.CouldNotDetermineRevision) + } + + revision := dsRevision.(WithTimestampRevision) + + ctx, span := tracer.Start(ctx, "CheckRevision") + defer span.End() + + // Make sure the system time indicated is within the software GC window + now, err := rcr.nowFunc(ctx) + if err != nil { + return err + } + + nowTS, ok := now.(WithTimestampRevision) + if !ok { + return spiceerrors.MustBugf("expected HLC revision, got %T", now) + } + + nowNanos := nowTS.TimestampNanoSec() + revisionNanos := revision.TimestampNanoSec() + + isStale := revisionNanos < (nowNanos - rcr.gcWindowNanos) + if isStale { + log.Ctx(ctx).Debug().Stringer("now", now).Stringer("revision", revision).Msg("stale revision") + return datastore.NewInvalidRevisionErr(revision, datastore.RevisionStale) + } + + isUnknown := revisionNanos > nowNanos + if isUnknown { + log.Ctx(ctx).Debug().Stringer("now", now).Stringer("revision", revision).Msg("unknown revision") + return datastore.NewInvalidRevisionErr(revision, datastore.CouldNotDetermineRevision) + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go new file mode 100644 index 0000000..fc2a250 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go @@ -0,0 +1,97 @@ +package revisions + +import ( + "fmt" + "strconv" + "time" + + "github.com/authzed/spicedb/pkg/datastore" +) + +// TimestampRevision is a revision that is a timestamp. +type TimestampRevision int64 + +var zeroTimestampRevision = TimestampRevision(0) + +// NewForTime creates a new revision for the given time. +func NewForTime(time time.Time) TimestampRevision { + return TimestampRevision(time.UnixNano()) +} + +// NewForTimestamp creates a new revision for the given timestamp. +func NewForTimestamp(timestampNanosec int64) TimestampRevision { + return TimestampRevision(timestampNanosec) +} + +// parseTimestampRevisionString parses a string into a timestamp revision. +func parseTimestampRevisionString(revisionStr string) (rev datastore.Revision, err error) { + parsed, err := strconv.ParseInt(revisionStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid integer revision: %w", err) + } + + return TimestampRevision(parsed), nil +} + +func (ir TimestampRevision) ByteSortable() bool { + return true +} + +func (ir TimestampRevision) Equal(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTimestampRevision + } + + return int64(ir) == int64(rhs.(TimestampRevision)) +} + +func (ir TimestampRevision) GreaterThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTimestampRevision + } + + return int64(ir) > int64(rhs.(TimestampRevision)) +} + +func (ir TimestampRevision) LessThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTimestampRevision + } + + return int64(ir) < int64(rhs.(TimestampRevision)) +} + +func (ir TimestampRevision) TimestampNanoSec() int64 { + return int64(ir) +} + +func (ir TimestampRevision) String() string { + return strconv.FormatInt(int64(ir), 10) +} + +func (ir TimestampRevision) Time() time.Time { + return time.Unix(0, int64(ir)) +} + +func (ir TimestampRevision) WithInexactFloat64() float64 { + return float64(ir) +} + +func (ir TimestampRevision) ConstructForTimestamp(timestamp int64) WithTimestampRevision { + return TimestampRevision(timestamp) +} + +var ( + _ datastore.Revision = TimestampRevision(0) + _ WithTimestampRevision = TimestampRevision(0) +) + +// TimestampIDKeyFunc is used to create keys for timestamps. +func TimestampIDKeyFunc(r TimestampRevision) int64 { + return int64(r) +} + +// TimestampIDKeyLessThanFunc is used to create keys for timestamps. +func TimestampIDKeyLessThanFunc(l, r int64) bool { + return l < r +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go new file mode 100644 index 0000000..31d837f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go @@ -0,0 +1,80 @@ +package revisions + +import ( + "fmt" + "strconv" + + "github.com/authzed/spicedb/pkg/datastore" +) + +// TransactionIDRevision is a revision that is a transaction ID. +type TransactionIDRevision uint64 + +var zeroTransactionIDRevision = TransactionIDRevision(0) + +// NewForTransactionID creates a new revision for the given transaction ID. +func NewForTransactionID(transactionID uint64) TransactionIDRevision { + return TransactionIDRevision(transactionID) +} + +// parseTransactionIDRevisionString parses a string into a transaction ID revision. +func parseTransactionIDRevisionString(revisionStr string) (rev datastore.Revision, err error) { + parsed, err := strconv.ParseUint(revisionStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid integer revision: %w", err) + } + + return TransactionIDRevision(parsed), nil +} + +func (ir TransactionIDRevision) ByteSortable() bool { + return true +} + +func (ir TransactionIDRevision) Equal(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTransactionIDRevision + } + + return uint64(ir) == uint64(rhs.(TransactionIDRevision)) +} + +func (ir TransactionIDRevision) GreaterThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTransactionIDRevision + } + + return uint64(ir) > uint64(rhs.(TransactionIDRevision)) +} + +func (ir TransactionIDRevision) LessThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTransactionIDRevision + } + + return uint64(ir) < uint64(rhs.(TransactionIDRevision)) +} + +func (ir TransactionIDRevision) TransactionID() uint64 { + return uint64(ir) +} + +func (ir TransactionIDRevision) String() string { + return strconv.FormatUint(uint64(ir), 10) +} + +func (ir TransactionIDRevision) WithInexactFloat64() float64 { + return float64(ir) +} + +var _ datastore.Revision = TransactionIDRevision(0) + +// TransactionIDKeyFunc is used to create keys for transaction IDs. +func TransactionIDKeyFunc(r TransactionIDRevision) uint64 { + return uint64(r) +} + +// TransactionIDKeyLessThanFunc is used to create keys for transaction IDs. +func TransactionIDKeyLessThanFunc(l, r uint64) bool { + return l < r +} |
