diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-24 17:58:01 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-24 17:58:01 -0600 |
| commit | 72296119fc9755774719f8f625ad03e0e0ec457a (patch) | |
| tree | ed236ddee12a20fb55b7cfecf13f62d3a000dcb5 /vendor/github.com/authzed/spicedb/pkg/development | |
| parent | a920a8cfe415858bb2777371a77018599ffed23f (diff) | |
| parent | eaa1bd3b8e12934aed06413d75e7482ac58d805a (diff) | |
Merge branch 'the-spice-must-flow' into 'main'
Add SpiceDB Authorization
See merge request gitlab-org/software-supply-chain-security/authorization/sparkled!19
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/development')
10 files changed, 1998 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/development/assertions.go b/vendor/github.com/authzed/spicedb/pkg/development/assertions.go new file mode 100644 index 0000000..7b08e6b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/assertions.go @@ -0,0 +1,97 @@ +package development + +import ( + "fmt" + + "github.com/ccoveille/go-safecast" + + log "github.com/authzed/spicedb/internal/logging" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/validationfile/blocks" +) + +const maxDispatchDepth = 25 + +// RunAllAssertions runs all assertions found in the given assertions block against the +// developer context, returning whether any errors occurred. +func RunAllAssertions(devContext *DevContext, assertions *blocks.Assertions) ([]*devinterface.DeveloperError, error) { + trueFailures, err := runAssertions(devContext, assertions.AssertTrue, v1.ResourceCheckResult_MEMBER, "Expected relation or permission %s to exist") + if err != nil { + return nil, err + } + + caveatedFailures, err := runAssertions(devContext, assertions.AssertCaveated, v1.ResourceCheckResult_CAVEATED_MEMBER, "Expected relation or permission %s to be caveated") + if err != nil { + return nil, err + } + + falseFailures, err := runAssertions(devContext, assertions.AssertFalse, v1.ResourceCheckResult_NOT_MEMBER, "Expected relation or permission %s to not exist") + if err != nil { + return nil, err + } + + failures := append(trueFailures, caveatedFailures...) + failures = append(failures, falseFailures...) + return failures, nil +} + +func runAssertions(devContext *DevContext, assertions []blocks.Assertion, expected v1.ResourceCheckResult_Membership, fmtString string) ([]*devinterface.DeveloperError, error) { + var failures []*devinterface.DeveloperError + + for _, assertion := range assertions { + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, err := safecast.ToUint32(assertion.SourcePosition.LineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToUint32(assertion.SourcePosition.ColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + + rel := assertion.Relationship + if rel.OptionalCaveat != nil { + failures = append(failures, &devinterface.DeveloperError{ + Message: fmt.Sprintf("cannot specify a caveat on an assertion: `%s`", assertion.RelationshipWithContextString), + Source: devinterface.DeveloperError_ASSERTION, + Kind: devinterface.DeveloperError_UNKNOWN_RELATION, + Context: assertion.RelationshipWithContextString, + Line: lineNumber, + Column: columnPosition, + }) + continue + } + + cr, err := RunCheck(devContext, rel.Resource, rel.Subject, assertion.CaveatContext) + if err != nil { + devErr, wireErr := DistinguishGraphError( + devContext, + err, + devinterface.DeveloperError_ASSERTION, + lineNumber, + columnPosition, + assertion.RelationshipWithContextString, + ) + if wireErr != nil { + return nil, wireErr + } + if devErr != nil { + failures = append(failures, devErr) + } + } else if cr.Permissionship != expected { + failures = append(failures, &devinterface.DeveloperError{ + Message: fmt.Sprintf(fmtString, assertion.RelationshipWithContextString), + Source: devinterface.DeveloperError_ASSERTION, + Kind: devinterface.DeveloperError_ASSERTION_FAILED, + Context: assertion.RelationshipWithContextString, + Line: lineNumber, + Column: columnPosition, + CheckDebugInformation: cr.DispatchDebugInfo, + CheckResolvedDebugInformation: cr.V1DebugInfo, + }) + } + } + + return failures, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/development/check.go b/vendor/github.com/authzed/spicedb/pkg/development/check.go new file mode 100644 index 0000000..292a1ea --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/check.go @@ -0,0 +1,53 @@ +package development + +import ( + v1api "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/internal/graph/computed" + v1 "github.com/authzed/spicedb/internal/services/v1" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + v1dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +const defaultWasmDispatchChunkSize = 100 + +// CheckResult is the result of a RunCheck operation. +type CheckResult struct { + Permissionship v1dispatch.ResourceCheckResult_Membership + MissingCaveatFields []string + DispatchDebugInfo *v1dispatch.DebugInformation + V1DebugInfo *v1api.DebugInformation +} + +// RunCheck performs a check against the data in the development context. +// +// Note that it is up to the caller to call DistinguishGraphError on the error +// if they want to distinguish between user errors and internal errors. +func RunCheck(devContext *DevContext, resource tuple.ObjectAndRelation, subject tuple.ObjectAndRelation, caveatContext map[string]any) (CheckResult, error) { + ctx := devContext.Ctx + cr, meta, err := computed.ComputeCheck(ctx, devContext.Dispatcher, + caveattypes.Default.TypeSet, + computed.CheckParameters{ + ResourceType: resource.RelationReference(), + Subject: subject, + CaveatContext: caveatContext, + AtRevision: devContext.Revision, + MaximumDepth: maxDispatchDepth, + DebugOption: computed.TraceDebuggingEnabled, + }, + resource.ObjectID, + defaultWasmDispatchChunkSize, + ) + if err != nil { + return CheckResult{v1dispatch.ResourceCheckResult_NOT_MEMBER, nil, nil, nil}, err + } + + reader := devContext.Datastore.SnapshotReader(devContext.Revision) + converted, err := v1.ConvertCheckDispatchDebugInformation(ctx, caveattypes.Default.TypeSet, caveatContext, meta.DebugInfo, reader) + if err != nil { + return CheckResult{v1dispatch.ResourceCheckResult_NOT_MEMBER, nil, nil, nil}, err + } + + return CheckResult{cr.Membership, cr.MissingExprFields, meta.DebugInfo, converted}, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/development/devcontext.go b/vendor/github.com/authzed/spicedb/pkg/development/devcontext.go new file mode 100644 index 0000000..f0f2e38 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/devcontext.go @@ -0,0 +1,470 @@ +package development + +import ( + "context" + "errors" + "net" + "time" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/ccoveille/go-safecast" + humanize "github.com/dustin/go-humanize" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" + + "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/dispatch/graph" + maingraph "github.com/authzed/spicedb/internal/graph" + "github.com/authzed/spicedb/internal/grpchelpers" + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/relationships" + v1svc "github.com/authzed/spicedb/internal/services/v1" + "github.com/authzed/spicedb/internal/sharederrors" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/middleware/consistency" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +const defaultConnBufferSize = humanize.MiByte + +// DevContext holds the various helper types for running the developer calls. +type DevContext struct { + Ctx context.Context + Datastore datastore.Datastore + Revision datastore.Revision + CompiledSchema *compiler.CompiledSchema + Dispatcher dispatch.Dispatcher +} + +// NewDevContext creates a new DevContext from the specified request context, parsing and populating +// the datastore as needed. +func NewDevContext(ctx context.Context, requestContext *devinterface.RequestContext) (*DevContext, *devinterface.DeveloperErrors, error) { + ds, err := memdb.NewMemdbDatastore(0, 0*time.Second, memdb.DisableGC) + if err != nil { + return nil, nil, err + } + ctx = datastoremw.ContextWithDatastore(ctx, ds) + + dctx, devErrs, nerr := newDevContextWithDatastore(ctx, requestContext, ds) + if nerr != nil || devErrs != nil { + // If any form of error occurred, immediately close the datastore + derr := ds.Close() + if derr != nil { + return nil, nil, derr + } + + return dctx, devErrs, nerr + } + + return dctx, nil, nil +} + +func newDevContextWithDatastore(ctx context.Context, requestContext *devinterface.RequestContext, ds datastore.Datastore) (*DevContext, *devinterface.DeveloperErrors, error) { + // Compile the schema and load its caveats and namespaces into the datastore. + compiled, devError, err := CompileSchema(requestContext.Schema) + if err != nil { + return nil, nil, err + } + + if devError != nil { + return nil, &devinterface.DeveloperErrors{InputErrors: []*devinterface.DeveloperError{devError}}, nil + } + + var inputErrors []*devinterface.DeveloperError + currentRevision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + inputErrors, err = loadCompiled(ctx, compiled, rwt) + if err != nil || len(inputErrors) > 0 { + return err + } + + // Load the test relationships into the datastore. + relationships := make([]tuple.Relationship, 0, len(requestContext.Relationships)) + for _, rel := range requestContext.Relationships { + if err := rel.Validate(); err != nil { + inputErrors = append(inputErrors, &devinterface.DeveloperError{ + Message: err.Error(), + Source: devinterface.DeveloperError_RELATIONSHIP, + Kind: devinterface.DeveloperError_PARSE_ERROR, + Context: tuple.CoreRelationToStringWithoutCaveatOrExpiration(rel), + }) + } + + convertedRel := tuple.FromCoreRelationTuple(rel) + if err := convertedRel.Validate(); err != nil { + tplString, serr := tuple.String(convertedRel) + if serr != nil { + return serr + } + + inputErrors = append(inputErrors, &devinterface.DeveloperError{ + Message: err.Error(), + Source: devinterface.DeveloperError_RELATIONSHIP, + Kind: devinterface.DeveloperError_PARSE_ERROR, + Context: tplString, + }) + } + + relationships = append(relationships, convertedRel) + } + + ie, lerr := loadsRels(ctx, relationships, rwt) + if len(ie) > 0 { + inputErrors = append(inputErrors, ie...) + } + + return lerr + }) + + if err != nil || len(inputErrors) > 0 { + return nil, &devinterface.DeveloperErrors{InputErrors: inputErrors}, err + } + + // Sanity check: Make sure the request context for the developer is fully valid. We do this after + // the loading to ensure that any user-created errors are reported as developer errors, + // rather than internal errors. + verr := requestContext.Validate() + if verr != nil { + return nil, nil, verr + } + + return &DevContext{ + Ctx: ctx, + Datastore: ds, + CompiledSchema: compiled, + Revision: currentRevision, + Dispatcher: graph.NewLocalOnlyDispatcher(caveattypes.Default.TypeSet, 10, 100), + }, nil, nil +} + +// RunV1InMemoryService runs a V1 server in-memory on a buffconn over the given +// development context and returns a client connection and a function to shutdown +// the server. It is the responsibility of the caller to call the function to close +// the server. +func (dc *DevContext) RunV1InMemoryService() (*grpc.ClientConn, func(), error) { + listener := bufconn.Listen(defaultConnBufferSize) + + s := grpc.NewServer( + grpc.ChainUnaryInterceptor( + datastoremw.UnaryServerInterceptor(dc.Datastore), + consistency.UnaryServerInterceptor("development"), + ), + grpc.ChainStreamInterceptor( + datastoremw.StreamServerInterceptor(dc.Datastore), + consistency.StreamServerInterceptor("development"), + ), + ) + ps := v1svc.NewPermissionsServer(dc.Dispatcher, v1svc.PermissionsServerConfig{ + MaxUpdatesPerWrite: 50, + MaxPreconditionsCount: 50, + MaximumAPIDepth: 50, + MaxCaveatContextSize: 0, + ExpiringRelationshipsEnabled: true, + CaveatTypeSet: caveattypes.Default.TypeSet, + }) + ss := v1svc.NewSchemaServer(caveattypes.Default.TypeSet, false, true) + + v1.RegisterPermissionsServiceServer(s, ps) + v1.RegisterSchemaServiceServer(s, ss) + + go func() { + if err := s.Serve(listener); err != nil { + log.Err(err).Msg("error when serving in-memory server") + } + }() + + conn, err := grpchelpers.DialAndWait( + context.Background(), + "", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + return conn, func() { + conn.Close() + listener.Close() + s.Stop() + }, err +} + +// Dispose disposes of the DevContext and its underlying datastore. +func (dc *DevContext) Dispose() { + if dc.Dispatcher == nil { + return + } + if err := dc.Dispatcher.Close(); err != nil { + log.Ctx(dc.Ctx).Err(err).Msg("error when disposing of dispatcher in devcontext") + } + + if dc.Datastore == nil { + return + } + + if err := dc.Datastore.Close(); err != nil { + log.Ctx(dc.Ctx).Err(err).Msg("error when disposing of datastore in devcontext") + } +} + +func loadsRels(ctx context.Context, rels []tuple.Relationship, rwt datastore.ReadWriteTransaction) ([]*devinterface.DeveloperError, error) { + devErrors := make([]*devinterface.DeveloperError, 0, len(rels)) + updates := make([]tuple.RelationshipUpdate, 0, len(rels)) + for _, rel := range rels { + if err := relationships.ValidateRelationshipsForCreateOrTouch(ctx, rwt, caveattypes.Default.TypeSet, rel); err != nil { + relString, serr := tuple.String(rel) + if serr != nil { + return nil, serr + } + + devErr, wireErr := distinguishGraphError(ctx, err, devinterface.DeveloperError_RELATIONSHIP, 0, 0, relString) + if wireErr != nil { + return devErrors, wireErr + } + + if devErr != nil { + devErrors = append(devErrors, devErr) + } + } + + updates = append(updates, tuple.Touch(rel)) + } + + err := rwt.WriteRelationships(ctx, updates) + return devErrors, err +} + +func loadCompiled( + ctx context.Context, + compiled *compiler.CompiledSchema, + rwt datastore.ReadWriteTransaction, +) ([]*devinterface.DeveloperError, error) { + errors := make([]*devinterface.DeveloperError, 0, len(compiled.OrderedDefinitions)) + ts := schema.NewTypeSystem(schema.ResolverForCompiledSchema(*compiled)) + + for _, caveatDef := range compiled.CaveatDefinitions { + cverr := namespace.ValidateCaveatDefinition(caveattypes.Default.TypeSet, caveatDef) + if cverr == nil { + if err := rwt.WriteCaveats(ctx, []*core.CaveatDefinition{caveatDef}); err != nil { + return errors, err + } + continue + } + + errWithSource, ok := spiceerrors.AsWithSourceError(cverr) + if ok { + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, err := safecast.ToUint32(errWithSource.LineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToUint32(errWithSource.ColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + errors = append(errors, &devinterface.DeveloperError{ + Message: cverr.Error(), + Kind: devinterface.DeveloperError_SCHEMA_ISSUE, + Source: devinterface.DeveloperError_SCHEMA, + Context: errWithSource.SourceCodeString, + Line: lineNumber, + Column: columnPosition, + }) + } else { + errors = append(errors, &devinterface.DeveloperError{ + Message: cverr.Error(), + Kind: devinterface.DeveloperError_SCHEMA_ISSUE, + Source: devinterface.DeveloperError_SCHEMA, + Context: caveatDef.Name, + }) + } + } + + for _, nsDef := range compiled.ObjectDefinitions { + def, terr := schema.NewDefinition(ts, nsDef) + if terr != nil { + errWithSource, ok := spiceerrors.AsWithSourceError(terr) + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, err := safecast.ToUint32(errWithSource.LineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToUint32(errWithSource.ColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + if ok { + errors = append(errors, &devinterface.DeveloperError{ + Message: terr.Error(), + Kind: devinterface.DeveloperError_SCHEMA_ISSUE, + Source: devinterface.DeveloperError_SCHEMA, + Context: errWithSource.SourceCodeString, + Line: lineNumber, + Column: columnPosition, + }) + continue + } + + errors = append(errors, &devinterface.DeveloperError{ + Message: terr.Error(), + Kind: devinterface.DeveloperError_SCHEMA_ISSUE, + Source: devinterface.DeveloperError_SCHEMA, + Context: nsDef.Name, + }) + continue + } + + _, tverr := def.Validate(ctx) + if tverr == nil { + if err := rwt.WriteNamespaces(ctx, nsDef); err != nil { + return errors, err + } + continue + } + + errWithSource, ok := spiceerrors.AsWithSourceError(tverr) + if ok { + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, err := safecast.ToUint32(errWithSource.LineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToUint32(errWithSource.ColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + errors = append(errors, &devinterface.DeveloperError{ + Message: tverr.Error(), + Kind: devinterface.DeveloperError_SCHEMA_ISSUE, + Source: devinterface.DeveloperError_SCHEMA, + Context: errWithSource.SourceCodeString, + Line: lineNumber, + Column: columnPosition, + }) + } else { + errors = append(errors, &devinterface.DeveloperError{ + Message: tverr.Error(), + Kind: devinterface.DeveloperError_SCHEMA_ISSUE, + Source: devinterface.DeveloperError_SCHEMA, + Context: nsDef.Name, + }) + } + } + + return errors, nil +} + +// DistinguishGraphError turns an error from a dispatch call into either a user-facing +// DeveloperError or an internal error, based on the error raised by the dispatcher. +func DistinguishGraphError(devContext *DevContext, dispatchError error, source devinterface.DeveloperError_Source, line uint32, column uint32, context string) (*devinterface.DeveloperError, error) { + return distinguishGraphError(devContext.Ctx, dispatchError, source, line, column, context) +} + +func distinguishGraphError(ctx context.Context, dispatchError error, source devinterface.DeveloperError_Source, line uint32, column uint32, context string) (*devinterface.DeveloperError, error) { + var nsNotFoundError sharederrors.UnknownNamespaceError + var relNotFoundError sharederrors.UnknownRelationError + var invalidRelError relationships.InvalidSubjectTypeError + var maxDepthErr dispatch.MaxDepthExceededError + + if errors.As(dispatchError, &maxDepthErr) { + return &devinterface.DeveloperError{ + Message: dispatchError.Error(), + Source: source, + Kind: devinterface.DeveloperError_MAXIMUM_RECURSION, + Line: line, + Column: column, + Context: context, + }, nil + } + + details, ok := spiceerrors.GetDetails[*errdetails.ErrorInfo](dispatchError) + if ok && details.Reason == "ERROR_REASON_MAXIMUM_DEPTH_EXCEEDED" { + status, _ := status.FromError(dispatchError) + return &devinterface.DeveloperError{ + Message: status.Message(), + Source: source, + Kind: devinterface.DeveloperError_MAXIMUM_RECURSION, + Line: line, + Column: column, + Context: context, + }, nil + } + + if errors.As(dispatchError, &invalidRelError) { + return &devinterface.DeveloperError{ + Message: dispatchError.Error(), + Source: source, + Kind: devinterface.DeveloperError_INVALID_SUBJECT_TYPE, + Line: line, + Column: column, + Context: context, + }, nil + } + + if errors.As(dispatchError, &nsNotFoundError) { + return &devinterface.DeveloperError{ + Message: dispatchError.Error(), + Source: source, + Kind: devinterface.DeveloperError_UNKNOWN_OBJECT_TYPE, + Line: line, + Column: column, + Context: context, + }, nil + } + + if errors.As(dispatchError, &relNotFoundError) { + return &devinterface.DeveloperError{ + Message: dispatchError.Error(), + Source: source, + Kind: devinterface.DeveloperError_UNKNOWN_RELATION, + Line: line, + Column: column, + Context: context, + }, nil + } + + return nil, rewriteACLError(ctx, dispatchError) +} + +func rewriteACLError(ctx context.Context, err error) error { + var nsNotFoundError sharederrors.UnknownNamespaceError + var relNotFoundError sharederrors.UnknownRelationError + + switch { + case errors.As(err, &nsNotFoundError): + fallthrough + case errors.As(err, &relNotFoundError): + fallthrough + + case errors.As(err, &datastore.InvalidRevisionError{}): + return status.Errorf(codes.OutOfRange, "invalid zookie: %s", err) + + case errors.As(err, &maingraph.RelationMissingTypeInfoError{}): + return status.Errorf(codes.FailedPrecondition, "failed precondition: %s", err) + + case errors.As(err, &maingraph.AlwaysFailError{}): + log.Ctx(ctx).Err(err).Msg("internal graph error in devcontext") + return status.Errorf(codes.Internal, "internal error: %s", err) + + case errors.Is(err, context.DeadlineExceeded): + return status.Errorf(codes.DeadlineExceeded, "%s", err) + + case errors.Is(err, context.Canceled): + return status.Errorf(codes.Canceled, "%s", err) + + default: + log.Ctx(ctx).Err(err).Msg("unexpected graph error in devcontext") + return err + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/development/doc.go b/vendor/github.com/authzed/spicedb/pkg/development/doc.go new file mode 100644 index 0000000..a4c106a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/doc.go @@ -0,0 +1,2 @@ +// Package development contains code that runs in the Playground. +package development diff --git a/vendor/github.com/authzed/spicedb/pkg/development/parsing.go b/vendor/github.com/authzed/spicedb/pkg/development/parsing.go new file mode 100644 index 0000000..75e27c8 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/parsing.go @@ -0,0 +1,69 @@ +package development + +import ( + "github.com/ccoveille/go-safecast" + + log "github.com/authzed/spicedb/internal/logging" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/validationfile" + "github.com/authzed/spicedb/pkg/validationfile/blocks" +) + +// ParseAssertionsYAML parses the YAML form of an assertions block. +func ParseAssertionsYAML(assertionsYaml string) (*blocks.Assertions, *devinterface.DeveloperError) { + assertions, err := validationfile.ParseAssertionsBlock([]byte(assertionsYaml)) + if err != nil { + serr, ok := spiceerrors.AsWithSourceError(err) + if ok { + return nil, convertSourceError(devinterface.DeveloperError_ASSERTION, serr) + } + } + + return assertions, convertError(devinterface.DeveloperError_ASSERTION, err) +} + +// ParseExpectedRelationsYAML parses the YAML form of an expected relations block. +func ParseExpectedRelationsYAML(expectedRelationsYaml string) (*blocks.ParsedExpectedRelations, *devinterface.DeveloperError) { + block, err := validationfile.ParseExpectedRelationsBlock([]byte(expectedRelationsYaml)) + if err != nil { + serr, ok := spiceerrors.AsWithSourceError(err) + if ok { + return nil, convertSourceError(devinterface.DeveloperError_VALIDATION_YAML, serr) + } + } + return block, convertError(devinterface.DeveloperError_VALIDATION_YAML, err) +} + +func convertError(source devinterface.DeveloperError_Source, err error) *devinterface.DeveloperError { + if err == nil { + return nil + } + + return &devinterface.DeveloperError{ + Message: err.Error(), + Kind: devinterface.DeveloperError_PARSE_ERROR, + Source: source, + Line: 0, + } +} + +func convertSourceError(source devinterface.DeveloperError_Source, err *spiceerrors.WithSourceError) *devinterface.DeveloperError { + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, castErr := safecast.ToUint32(err.LineNumber) + if castErr != nil { + log.Err(castErr).Msg("could not cast lineNumber to uint32") + } + columnPosition, castErr := safecast.ToUint32(err.ColumnPosition) + if castErr != nil { + log.Err(castErr).Msg("could not cast columnPosition to uint32") + } + return &devinterface.DeveloperError{ + Message: err.Error(), + Kind: devinterface.DeveloperError_PARSE_ERROR, + Source: source, + Line: lineNumber, + Column: columnPosition, + Context: err.SourceCodeString, + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/development/resolver.go b/vendor/github.com/authzed/spicedb/pkg/development/resolver.go new file mode 100644 index 0000000..ddb8a67 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/resolver.go @@ -0,0 +1,426 @@ +package development + +import ( + "context" + "fmt" + "strings" + + "github.com/ccoveille/go-safecast" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +// ReferenceType is the type of reference. +type ReferenceType int + +const ( + ReferenceTypeUnknown ReferenceType = iota + ReferenceTypeDefinition + ReferenceTypeCaveat + ReferenceTypeRelation + ReferenceTypePermission + ReferenceTypeCaveatParameter +) + +// SchemaReference represents a reference to a schema node. +type SchemaReference struct { + // Source is the source of the reference. + Source input.Source + + // Position is the position of the reference in the source. + Position input.Position + + // Text is the text of the reference. + Text string + + // ReferenceType is the type of reference. + ReferenceType ReferenceType + + // ReferenceMarkdown is the markdown representation of the reference. + ReferenceMarkdown string + + // TargetSource is the source of the target node, if any. + TargetSource *input.Source + + // TargetPosition is the position of the target node, if any. + TargetPosition *input.Position + + // TargetSourceCode is the source code representation of the target, if any. + TargetSourceCode string + + // TargetNamePositionOffset is the offset from the target position from where the + // *name* of the target is found. + TargetNamePositionOffset int +} + +// Resolver resolves references to schema nodes from source positions. +type Resolver struct { + schema *compiler.CompiledSchema + typeSystem *schema.TypeSystem +} + +// NewResolver creates a new resolver for the given schema. +func NewResolver(compiledSchema *compiler.CompiledSchema) (*Resolver, error) { + resolver := schema.ResolverForCompiledSchema(*compiledSchema) + ts := schema.NewTypeSystem(resolver) + return &Resolver{schema: compiledSchema, typeSystem: ts}, nil +} + +// ReferenceAtPosition returns the reference to the schema node at the given position in the source, if any. +func (r *Resolver) ReferenceAtPosition(source input.Source, position input.Position) (*SchemaReference, error) { + nodeChain, err := compiler.PositionToAstNodeChain(r.schema, source, position) + if err != nil { + return nil, err + } + + if nodeChain == nil { + return nil, nil + } + + relationReference := func(relation *core.Relation, def *schema.Definition) (*SchemaReference, error) { + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, err := safecast.ToInt(relation.SourcePosition.ZeroIndexedLineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToInt(relation.SourcePosition.ZeroIndexedColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + relationPosition := input.Position{ + LineNumber: lineNumber, + ColumnPosition: columnPosition, + } + + targetSourceCode, err := generator.GenerateRelationSource(relation, caveattypes.Default.TypeSet) + if err != nil { + return nil, err + } + + if def.IsPermission(relation.Name) { + return &SchemaReference{ + Source: source, + Position: position, + Text: relation.Name, + + ReferenceType: ReferenceTypePermission, + ReferenceMarkdown: fmt.Sprintf("permission %s", relation.Name), + + TargetSource: &source, + TargetPosition: &relationPosition, + TargetSourceCode: targetSourceCode, + TargetNamePositionOffset: len("permission "), + }, nil + } + + return &SchemaReference{ + Source: source, + Position: position, + Text: relation.Name, + + ReferenceType: ReferenceTypeRelation, + ReferenceMarkdown: fmt.Sprintf("relation %s", relation.Name), + + TargetSource: &source, + TargetPosition: &relationPosition, + TargetSourceCode: targetSourceCode, + TargetNamePositionOffset: len("relation "), + }, nil + } + + // Type reference. + if ts, relation, ok := r.typeReferenceChain(nodeChain); ok { + if relation != nil { + return relationReference(relation, ts) + } + + def := ts.Namespace() + + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, err := safecast.ToInt(def.SourcePosition.ZeroIndexedLineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToInt(def.SourcePosition.ZeroIndexedColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + + defPosition := input.Position{ + LineNumber: lineNumber, + ColumnPosition: columnPosition, + } + + docComment := "" + comments := namespace.GetComments(def.Metadata) + if len(comments) > 0 { + docComment = strings.Join(comments, "\n") + "\n" + } + + targetSourceCode := fmt.Sprintf("%sdefinition %s {\n\t// ...\n}", docComment, def.Name) + if len(def.Relation) == 0 { + targetSourceCode = fmt.Sprintf("%sdefinition %s {}", docComment, def.Name) + } + + return &SchemaReference{ + Source: source, + Position: position, + Text: def.Name, + + ReferenceType: ReferenceTypeDefinition, + ReferenceMarkdown: fmt.Sprintf("definition %s", def.Name), + + TargetSource: &source, + TargetPosition: &defPosition, + TargetSourceCode: targetSourceCode, + TargetNamePositionOffset: len("definition "), + }, nil + } + + // Caveat Type reference. + if caveatDef, ok := r.caveatTypeReferenceChain(nodeChain); ok { + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, err := safecast.ToInt(caveatDef.SourcePosition.ZeroIndexedLineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToInt(caveatDef.SourcePosition.ZeroIndexedColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + + defPosition := input.Position{ + LineNumber: lineNumber, + ColumnPosition: columnPosition, + } + + var caveatSourceCode strings.Builder + caveatSourceCode.WriteString(fmt.Sprintf("caveat %s(", caveatDef.Name)) + index := 0 + for paramName, paramType := range caveatDef.ParameterTypes { + if index > 0 { + caveatSourceCode.WriteString(", ") + } + + caveatSourceCode.WriteString(fmt.Sprintf("%s %s", paramName, caveats.ParameterTypeString(paramType))) + index++ + } + caveatSourceCode.WriteString(") {\n\t// ...\n}") + + return &SchemaReference{ + Source: source, + Position: position, + Text: caveatDef.Name, + + ReferenceType: ReferenceTypeCaveat, + ReferenceMarkdown: fmt.Sprintf("caveat %s", caveatDef.Name), + + TargetSource: &source, + TargetPosition: &defPosition, + TargetSourceCode: caveatSourceCode.String(), + TargetNamePositionOffset: len("caveat "), + }, nil + } + + // Relation reference. + if relation, ts, ok := r.relationReferenceChain(nodeChain); ok { + return relationReference(relation, ts) + } + + // Caveat parameter used in expression. + if caveatParamName, caveatDef, ok := r.caveatParamChain(nodeChain, source, position); ok { + targetSourceCode := fmt.Sprintf("%s %s", caveatParamName, caveats.ParameterTypeString(caveatDef.ParameterTypes[caveatParamName])) + + return &SchemaReference{ + Source: source, + Position: position, + Text: caveatParamName, + + ReferenceType: ReferenceTypeCaveatParameter, + ReferenceMarkdown: targetSourceCode, + + TargetSource: &source, + TargetSourceCode: targetSourceCode, + }, nil + } + + return nil, nil +} + +func (r *Resolver) lookupCaveat(caveatName string) (*core.CaveatDefinition, bool) { + for _, caveatDef := range r.schema.CaveatDefinitions { + if caveatDef.Name == caveatName { + return caveatDef, true + } + } + + return nil, false +} + +func (r *Resolver) lookupRelation(defName, relationName string) (*core.Relation, *schema.Definition, bool) { + ts, err := r.typeSystem.GetDefinition(context.Background(), defName) + if err != nil { + return nil, nil, false + } + + rel, ok := ts.GetRelation(relationName) + if !ok { + return nil, nil, false + } + + return rel, ts, true +} + +func (r *Resolver) caveatParamChain(nodeChain *compiler.NodeChain, source input.Source, position input.Position) (string, *core.CaveatDefinition, bool) { + if !nodeChain.HasHeadType(dslshape.NodeTypeCaveatExpression) { + return "", nil, false + } + + caveatDefNode := nodeChain.FindNodeOfType(dslshape.NodeTypeCaveatDefinition) + if caveatDefNode == nil { + return "", nil, false + } + + caveatName, err := caveatDefNode.GetString(dslshape.NodeCaveatDefinitionPredicateName) + if err != nil { + return "", nil, false + } + + caveatDef, ok := r.lookupCaveat(caveatName) + if !ok { + return "", nil, false + } + + runePosition, err := r.schema.SourcePositionToRunePosition(source, position) + if err != nil { + return "", nil, false + } + + exprRunePosition, err := nodeChain.Head().GetInt(dslshape.NodePredicateStartRune) + if err != nil { + return "", nil, false + } + + if exprRunePosition > runePosition { + return "", nil, false + } + + relationRunePosition := runePosition - exprRunePosition + + caveatExpr, err := nodeChain.Head().GetString(dslshape.NodeCaveatExpressionPredicateExpression) + if err != nil { + return "", nil, false + } + + // Split the expression into tokens and find the associated token. + tokens := strings.FieldsFunc(caveatExpr, splitCELToken) + currentIndex := 0 + for _, token := range tokens { + if currentIndex <= relationRunePosition && currentIndex+len(token) >= relationRunePosition { + if _, ok := caveatDef.ParameterTypes[token]; ok { + return token, caveatDef, true + } + } + } + + return "", caveatDef, true +} + +func splitCELToken(r rune) bool { + return r == ' ' || r == '(' || r == ')' || r == '.' || r == ',' || r == '[' || r == ']' || r == '{' || r == '}' || r == ':' || r == '=' +} + +func (r *Resolver) caveatTypeReferenceChain(nodeChain *compiler.NodeChain) (*core.CaveatDefinition, bool) { + if !nodeChain.HasHeadType(dslshape.NodeTypeCaveatReference) { + return nil, false + } + + caveatName, err := nodeChain.Head().GetString(dslshape.NodeCaveatPredicateCaveat) + if err != nil { + return nil, false + } + + return r.lookupCaveat(caveatName) +} + +func (r *Resolver) typeReferenceChain(nodeChain *compiler.NodeChain) (*schema.Definition, *core.Relation, bool) { + if !nodeChain.HasHeadType(dslshape.NodeTypeSpecificTypeReference) { + return nil, nil, false + } + + defName, err := nodeChain.Head().GetString(dslshape.NodeSpecificReferencePredicateType) + if err != nil { + return nil, nil, false + } + + def, err := r.typeSystem.GetDefinition(context.Background(), defName) + if err != nil { + return nil, nil, false + } + + relationName, err := nodeChain.Head().GetString(dslshape.NodeSpecificReferencePredicateRelation) + if err != nil { + return def, nil, true + } + + startingRune, err := nodeChain.Head().GetInt(dslshape.NodePredicateStartRune) + if err != nil { + return def, nil, true + } + + // If hover over the definition name, return the definition. + if nodeChain.ForRunePosition() < startingRune+len(defName) { + return def, nil, true + } + + relation, ok := def.GetRelation(relationName) + if !ok { + return nil, nil, false + } + + return def, relation, true +} + +func (r *Resolver) relationReferenceChain(nodeChain *compiler.NodeChain) (*core.Relation, *schema.Definition, bool) { + if !nodeChain.HasHeadType(dslshape.NodeTypeIdentifier) { + return nil, nil, false + } + + if arrowExpr := nodeChain.FindNodeOfType(dslshape.NodeTypeArrowExpression); arrowExpr != nil { + // Ensure this on the left side of the arrow. + rightExpr, err := arrowExpr.Lookup(dslshape.NodeExpressionPredicateRightExpr) + if err != nil { + return nil, nil, false + } + + if rightExpr == nodeChain.Head() { + return nil, nil, false + } + } + + relationName, err := nodeChain.Head().GetString(dslshape.NodeIdentiferPredicateValue) + if err != nil { + return nil, nil, false + } + + parentDefNode := nodeChain.FindNodeOfType(dslshape.NodeTypeDefinition) + if parentDefNode == nil { + return nil, nil, false + } + + defName, err := parentDefNode.GetString(dslshape.NodeDefinitionPredicateName) + if err != nil { + return nil, nil, false + } + + return r.lookupRelation(defName, relationName) +} diff --git a/vendor/github.com/authzed/spicedb/pkg/development/schema.go b/vendor/github.com/authzed/spicedb/pkg/development/schema.go new file mode 100644 index 0000000..0e50451 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/schema.go @@ -0,0 +1,54 @@ +package development + +import ( + "errors" + + "github.com/ccoveille/go-safecast" + + log "github.com/authzed/spicedb/internal/logging" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +// CompileSchema compiles a schema into its caveat and namespace definition(s), returning a developer +// error if the schema could not be compiled. The non-developer error is returned only if an +// internal errors occurred. +func CompileSchema(schema string) (*compiler.CompiledSchema, *devinterface.DeveloperError, error) { + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: schema, + }, compiler.AllowUnprefixedObjectType()) + + var contextError compiler.WithContextError + if errors.As(err, &contextError) { + line, col, lerr := contextError.SourceRange.Start().LineAndColumn() + if lerr != nil { + return nil, nil, lerr + } + + // NOTE: zeroes are fine here on failure. + uintLine, err := safecast.ToUint32(line) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + uintColumn, err := safecast.ToUint32(col) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + return nil, &devinterface.DeveloperError{ + Message: contextError.BaseCompilerError.BaseMessage, + Kind: devinterface.DeveloperError_SCHEMA_ISSUE, + Source: devinterface.DeveloperError_SCHEMA, + Line: uintLine + 1, // 0-indexed in parser. + Column: uintColumn + 1, // 0-indexed in parser. + Context: contextError.ErrorSourceCode, + }, nil + } + + if err != nil { + return nil, nil, err + } + + return compiled, nil, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/development/validation.go b/vendor/github.com/authzed/spicedb/pkg/development/validation.go new file mode 100644 index 0000000..b3da6cf --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/validation.go @@ -0,0 +1,286 @@ +package development + +import ( + "fmt" + "sort" + "strings" + + "github.com/ccoveille/go-safecast" + "github.com/google/go-cmp/cmp" + yaml "gopkg.in/yaml.v2" + + "github.com/authzed/spicedb/internal/developmentmembership" + log "github.com/authzed/spicedb/internal/logging" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/validationfile/blocks" +) + +// RunValidation runs the parsed validation block against the data in the dev context. +func RunValidation(devContext *DevContext, validation *blocks.ParsedExpectedRelations) (*developmentmembership.Set, []*devinterface.DeveloperError, error) { + var failures []*devinterface.DeveloperError + membershipSet := developmentmembership.NewMembershipSet() + ctx := devContext.Ctx + + for onrKey, expectedSubjects := range validation.ValidationMap { + // Run a full recursive expansion over the ONR. + er, derr := devContext.Dispatcher.DispatchExpand(ctx, &v1.DispatchExpandRequest{ + ResourceAndRelation: onrKey.ObjectAndRelation.ToCoreONR(), + Metadata: &v1.ResolverMeta{ + AtRevision: devContext.Revision.String(), + DepthRemaining: maxDispatchDepth, + TraversalBloom: v1.MustNewTraversalBloomFilter(uint(maxDispatchDepth)), + }, + ExpansionMode: v1.DispatchExpandRequest_RECURSIVE, + }) + if derr != nil { + devErr, wireErr := DistinguishGraphError(devContext, derr, devinterface.DeveloperError_VALIDATION_YAML, 0, 0, onrKey.ObjectRelationString) + if wireErr != nil { + return nil, nil, wireErr + } + + failures = append(failures, devErr) + continue + } + + // Add the ONR and its expansion to the membership set. + foundSubjects, _, aerr := membershipSet.AddExpansion(onrKey.ObjectAndRelation, er.TreeNode) + if aerr != nil { + devErr, wireErr := DistinguishGraphError(devContext, aerr, devinterface.DeveloperError_VALIDATION_YAML, 0, 0, onrKey.ObjectRelationString) + if wireErr != nil { + return nil, nil, wireErr + } + + failures = append(failures, devErr) + continue + } + + // Compare the terminal subjects found to those specified. + errs := validateSubjects(onrKey, foundSubjects, expectedSubjects) + failures = append(failures, errs...) + } + + if len(failures) > 0 { + return membershipSet, failures, nil + } + + return membershipSet, nil, nil +} + +func wrapResources(onrStrings []string) []string { + wrapped := make([]string, 0, len(onrStrings)) + for _, str := range onrStrings { + wrapped = append(wrapped, "<"+str+">") + } + + // Sort to ensure stability. + sort.Strings(wrapped) + return wrapped +} + +func validateSubjects(onrKey blocks.ObjectRelation, fs developmentmembership.FoundSubjects, expectedSubjects []blocks.ExpectedSubject) []*devinterface.DeveloperError { + onr := onrKey.ObjectAndRelation + + var failures []*devinterface.DeveloperError + + // Verify that every referenced subject is found in the membership. + encounteredSubjects := map[string]struct{}{} + for _, expectedSubject := range expectedSubjects { + subjectWithExceptions := expectedSubject.SubjectWithExceptions + // NOTE: zeroes are fine here on failure. + lineNumber, err := safecast.ToUint32(expectedSubject.SourcePosition.LineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToUint32(expectedSubject.SourcePosition.ColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + if subjectWithExceptions == nil { + failures = append(failures, &devinterface.DeveloperError{ + Message: fmt.Sprintf("For object and permission/relation `%s`, no expected subject specified in `%s`", tuple.StringONR(onr), expectedSubject.ValidationString), + Source: devinterface.DeveloperError_VALIDATION_YAML, + Kind: devinterface.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, + Context: string(expectedSubject.ValidationString), + Line: lineNumber, + Column: columnPosition, + }) + continue + } + + encounteredSubjects[tuple.StringONR(subjectWithExceptions.Subject.Subject)] = struct{}{} + + subject, ok := fs.LookupSubject(subjectWithExceptions.Subject.Subject) + if !ok { + failures = append(failures, &devinterface.DeveloperError{ + Message: fmt.Sprintf("For object and permission/relation `%s`, missing expected subject `%s`", tuple.StringONR(onr), tuple.StringONR(subjectWithExceptions.Subject.Subject)), + Source: devinterface.DeveloperError_VALIDATION_YAML, + Kind: devinterface.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, + Context: string(expectedSubject.ValidationString), + Line: lineNumber, + Column: columnPosition, + }) + continue + } + + // Verify that the relationships are the same. + foundParentResources := subject.ParentResources() + expectedONRStrings := tuple.StringsONRs(expectedSubject.Resources) + foundONRStrings := tuple.StringsONRs(foundParentResources) + if !cmp.Equal(expectedONRStrings, foundONRStrings) { + failures = append(failures, &devinterface.DeveloperError{ + Message: fmt.Sprintf("For object and permission/relation `%s`, found different relationships for subject `%s`: Specified: `%s`, Computed: `%s`", + tuple.StringONR(onr), + tuple.StringONR(subjectWithExceptions.Subject.Subject), + strings.Join(wrapResources(expectedONRStrings), "/"), + strings.Join(wrapResources(foundONRStrings), "/"), + ), + Source: devinterface.DeveloperError_VALIDATION_YAML, + Kind: devinterface.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, + Context: string(expectedSubject.ValidationString), + Line: lineNumber, + Column: columnPosition, + }) + } + + // Verify exclusions are the same, if any. + foundExcludedSubjects, isWildcard := subject.ExcludedSubjectsFromWildcard() + expectedExcludedSubjects := subjectWithExceptions.Exceptions + if isWildcard { + expectedExcludedStrings := toExpectedRelationshipsStrings(expectedExcludedSubjects) + foundExcludedONRStrings := toFoundRelationshipsStrings(foundExcludedSubjects) + + sort.Strings(expectedExcludedStrings) + sort.Strings(foundExcludedONRStrings) + + if !cmp.Equal(expectedExcludedStrings, foundExcludedONRStrings) { + failures = append(failures, &devinterface.DeveloperError{ + Message: fmt.Sprintf("For object and permission/relation `%s`, found different excluded subjects for subject `%s`: Specified: `%s`, Computed: `%s`", + tuple.StringONR(onr), + tuple.StringONR(subjectWithExceptions.Subject.Subject), + strings.Join(wrapResources(expectedExcludedStrings), ", "), + strings.Join(wrapResources(foundExcludedONRStrings), ", "), + ), + Source: devinterface.DeveloperError_VALIDATION_YAML, + Kind: devinterface.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, + Context: string(expectedSubject.ValidationString), + Line: lineNumber, + Column: columnPosition, + }) + } + } else { + if len(expectedExcludedSubjects) > 0 { + failures = append(failures, &devinterface.DeveloperError{ + Message: fmt.Sprintf("For object and permission/relation `%s`, found unexpected excluded subjects", + tuple.StringONR(onr), + ), + Source: devinterface.DeveloperError_VALIDATION_YAML, + Kind: devinterface.DeveloperError_EXTRA_RELATIONSHIP_FOUND, + Context: string(expectedSubject.ValidationString), + Line: lineNumber, + Column: columnPosition, + }) + } + } + + // Verify caveats. + if (subject.GetCaveatExpression() != nil) != subjectWithExceptions.Subject.IsCaveated { + failures = append(failures, &devinterface.DeveloperError{ + Message: fmt.Sprintf("For object and permission/relation `%s`, found caveat mismatch", + tuple.StringONR(onr), + ), + Source: devinterface.DeveloperError_VALIDATION_YAML, + Kind: devinterface.DeveloperError_MISSING_EXPECTED_RELATIONSHIP, + Context: string(expectedSubject.ValidationString), + Line: lineNumber, + Column: columnPosition, + }) + } + } + + // Verify that every subject found was referenced. + for _, foundSubject := range fs.ListFound() { + _, ok := encounteredSubjects[tuple.StringONR(foundSubject.Subject())] + if !ok { + onrLineNumber, err := safecast.ToUint32(onrKey.SourcePosition.LineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + onrColumnPosition, err := safecast.ToUint32(onrKey.SourcePosition.ColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + failures = append(failures, &devinterface.DeveloperError{ + Message: fmt.Sprintf("For object and permission/relation `%s`, subject `%s` found but not listed in expected subjects", + tuple.StringONR(onr), + tuple.StringONR(foundSubject.Subject()), + ), + Source: devinterface.DeveloperError_VALIDATION_YAML, + Kind: devinterface.DeveloperError_EXTRA_RELATIONSHIP_FOUND, + Context: tuple.StringONR(onr), + Line: onrLineNumber, + Column: onrColumnPosition, + }) + } + } + + return failures +} + +// GenerateValidation generates the validation block based on a membership set. +func GenerateValidation(membershipSet *developmentmembership.Set) (string, error) { + validationMap := map[string][]string{} + subjectsByONR := membershipSet.SubjectsByONR() + + onrStrings := make([]string, 0, len(subjectsByONR)) + for onrString := range subjectsByONR { + onrStrings = append(onrStrings, onrString) + } + + // Sort to ensure stability of output. + sort.Strings(onrStrings) + + for _, onrString := range onrStrings { + foundSubjects := subjectsByONR[onrString] + var strs []string + for _, fs := range foundSubjects.ListFound() { + strs = append(strs, + fmt.Sprintf("[%s] is %s", + fs.ToValidationString(), + strings.Join(wrapResources(tuple.StringsONRs(fs.ParentResources())), "/"), + )) + } + + // Sort to ensure stability of output. + sort.Strings(strs) + validationMap[onrString] = strs + } + + contents, err := yaml.Marshal(validationMap) + if err != nil { + return "", err + } + + return string(contents), nil +} + +func toExpectedRelationshipsStrings(subs []blocks.SubjectAndCaveat) []string { + mapped := make([]string, 0, len(subs)) + for _, sub := range subs { + if sub.IsCaveated { + mapped = append(mapped, tuple.StringONR(sub.Subject)+"[...]") + } else { + mapped = append(mapped, tuple.StringONR(sub.Subject)) + } + } + return mapped +} + +func toFoundRelationshipsStrings(subs []developmentmembership.FoundSubject) []string { + mapped := make([]string, 0, len(subs)) + for _, sub := range subs { + mapped = append(mapped, sub.ToValidationString()) + } + return mapped +} diff --git a/vendor/github.com/authzed/spicedb/pkg/development/warningdefs.go b/vendor/github.com/authzed/spicedb/pkg/development/warningdefs.go new file mode 100644 index 0000000..b995933 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/warningdefs.go @@ -0,0 +1,232 @@ +package development + +import ( + "context" + "fmt" + "strings" + + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/tuple" +) + +var lintRelationReferencesParentType = relationCheck{ + "relation-name-references-parent", + func( + ctx context.Context, + relation *corev1.Relation, + def *schema.Definition, + ) (*devinterface.DeveloperWarning, error) { + parentDef := def.Namespace() + if strings.HasSuffix(relation.Name, parentDef.Name) { + if def.IsPermission(relation.Name) { + return warningForMetadata( + "relation-name-references-parent", + fmt.Sprintf("Permission %q references parent type %q in its name; it is recommended to drop the suffix", relation.Name, parentDef.Name), + relation.Name, + relation, + ), nil + } + + return warningForMetadata( + "relation-name-references-parent", + fmt.Sprintf("Relation %q references parent type %q in its name; it is recommended to drop the suffix", relation.Name, parentDef.Name), + relation.Name, + relation, + ), nil + } + + return nil, nil + }, +} + +var lintPermissionReferencingItself = computedUsersetCheck{ + "permission-references-itself", + func( + ctx context.Context, + computedUserset *corev1.ComputedUserset, + sourcePosition *corev1.SourcePosition, + def *schema.Definition, + ) (*devinterface.DeveloperWarning, error) { + parentRelation := ctx.Value(relationKey).(*corev1.Relation) + permName := parentRelation.Name + if computedUserset.GetRelation() == permName { + return warningForPosition( + "permission-references-itself", + fmt.Sprintf("Permission %q references itself, which will cause an error to be raised due to infinite recursion", permName), + permName, + sourcePosition, + ), nil + } + + return nil, nil + }, +} + +var lintArrowReferencingUnreachable = ttuCheck{ + "arrow-references-unreachable-relation", + func( + ctx context.Context, + ttu ttu, + sourcePosition *corev1.SourcePosition, + def *schema.Definition, + ) (*devinterface.DeveloperWarning, error) { + parentRelation := ctx.Value(relationKey).(*corev1.Relation) + + referencedRelation, ok := def.GetRelation(ttu.GetTupleset().GetRelation()) + if !ok { + return nil, nil + } + + allowedSubjectTypes, err := def.AllowedSubjectRelations(referencedRelation.Name) + if err != nil { + return nil, err + } + + wasFound := false + for _, subjectType := range allowedSubjectTypes { + nts, err := def.TypeSystem().GetDefinition(ctx, subjectType.Namespace) + if err != nil { + return nil, err + } + + _, ok := nts.GetRelation(ttu.GetComputedUserset().GetRelation()) + if ok { + wasFound = true + } + } + + if !wasFound { + arrowString, err := ttu.GetArrowString() + if err != nil { + return nil, err + } + + return warningForPosition( + "arrow-references-unreachable-relation", + fmt.Sprintf( + "Arrow `%s` under permission %q references relation/permission %q that does not exist on any subject types of relation %q", + arrowString, + parentRelation.Name, + ttu.GetComputedUserset().GetRelation(), + ttu.GetTupleset().GetRelation(), + ), + arrowString, + sourcePosition, + ), nil + } + + return nil, nil + }, +} + +var lintArrowOverSubRelation = ttuCheck{ + "arrow-walks-subject-relation", + func( + ctx context.Context, + ttu ttu, + sourcePosition *corev1.SourcePosition, + def *schema.Definition, + ) (*devinterface.DeveloperWarning, error) { + parentRelation := ctx.Value(relationKey).(*corev1.Relation) + + referencedRelation, ok := def.GetRelation(ttu.GetTupleset().GetRelation()) + if !ok { + return nil, nil + } + + allowedSubjectTypes, err := def.AllowedSubjectRelations(referencedRelation.Name) + if err != nil { + return nil, err + } + + arrowString, err := ttu.GetArrowString() + if err != nil { + return nil, err + } + + for _, subjectType := range allowedSubjectTypes { + if subjectType.GetRelation() != tuple.Ellipsis { + return warningForPosition( + "arrow-walks-subject-relation", + fmt.Sprintf( + "Arrow `%s` under permission %q references relation %q that has relation %q on subject %q: *the subject relation will be ignored for the arrow*", + arrowString, + parentRelation.Name, + ttu.GetTupleset().GetRelation(), + subjectType.GetRelation(), + subjectType.Namespace, + ), + arrowString, + sourcePosition, + ), nil + } + } + + return nil, nil + }, +} + +var lintArrowReferencingRelation = ttuCheck{ + "arrow-references-relation", + func( + ctx context.Context, + ttu ttu, + sourcePosition *corev1.SourcePosition, + def *schema.Definition, + ) (*devinterface.DeveloperWarning, error) { + parentRelation := ctx.Value(relationKey).(*corev1.Relation) + + referencedRelation, ok := def.GetRelation(ttu.GetTupleset().GetRelation()) + if !ok { + return nil, nil + } + + // For each subject type of the referenced relation, check if the referenced permission + // is, in fact, a relation. + allowedSubjectTypes, err := def.AllowedSubjectRelations(referencedRelation.Name) + if err != nil { + return nil, err + } + + arrowString, err := ttu.GetArrowString() + if err != nil { + return nil, err + } + + for _, subjectType := range allowedSubjectTypes { + // Skip for arrow referencing relations in the same namespace. + if subjectType.Namespace == def.Namespace().Name { + continue + } + + nts, err := def.TypeSystem().GetDefinition(ctx, subjectType.Namespace) + if err != nil { + return nil, err + } + + targetRelation, ok := nts.GetRelation(ttu.GetComputedUserset().GetRelation()) + if !ok { + continue + } + + if !nts.IsPermission(targetRelation.Name) { + return warningForPosition( + "arrow-references-relation", + fmt.Sprintf( + "Arrow `%s` under permission %q references relation %q on definition %q; it is recommended to point to a permission", + arrowString, + parentRelation.Name, + targetRelation.Name, + subjectType.Namespace, + ), + arrowString, + sourcePosition, + ), nil + } + } + + return nil, nil + }, +} diff --git a/vendor/github.com/authzed/spicedb/pkg/development/warnings.go b/vendor/github.com/authzed/spicedb/pkg/development/warnings.go new file mode 100644 index 0000000..6be9a1c --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/warnings.go @@ -0,0 +1,309 @@ +package development + +import ( + "context" + "fmt" + + "github.com/ccoveille/go-safecast" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/namespace" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +var allChecks = checks{ + relationChecks: []relationCheck{ + lintRelationReferencesParentType, + }, + computedUsersetChecks: []computedUsersetCheck{ + lintPermissionReferencingItself, + }, + ttuChecks: []ttuCheck{ + lintArrowReferencingRelation, + lintArrowReferencingUnreachable, + lintArrowOverSubRelation, + }, +} + +func warningForMetadata(warningName string, message string, sourceCode string, metadata namespace.WithSourcePosition) *devinterface.DeveloperWarning { + return warningForPosition(warningName, message, sourceCode, metadata.GetSourcePosition()) +} + +func warningForPosition(warningName string, message string, sourceCode string, sourcePosition *corev1.SourcePosition) *devinterface.DeveloperWarning { + if sourcePosition == nil { + return &devinterface.DeveloperWarning{ + Message: message, + SourceCode: sourceCode, + } + } + + // NOTE: zeroes on failure are fine here. + lineNumber, err := safecast.ToUint32(sourcePosition.ZeroIndexedLineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnNumber, err := safecast.ToUint32(sourcePosition.ZeroIndexedColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + + return &devinterface.DeveloperWarning{ + Message: message + " (" + warningName + ")", + Line: lineNumber + 1, + Column: columnNumber + 1, + SourceCode: sourceCode, + } +} + +// GetWarnings returns a list of warnings for the given developer context. +func GetWarnings(ctx context.Context, devCtx *DevContext) ([]*devinterface.DeveloperWarning, error) { + warnings := []*devinterface.DeveloperWarning{} + res := schema.ResolverForCompiledSchema(*devCtx.CompiledSchema) + ts := schema.NewTypeSystem(res) + + for _, def := range devCtx.CompiledSchema.ObjectDefinitions { + found, err := addDefinitionWarnings(ctx, def, ts) + if err != nil { + return nil, err + } + warnings = append(warnings, found...) + } + + return warnings, nil +} + +type contextKey string + +var relationKey = contextKey("relation") + +func addDefinitionWarnings(ctx context.Context, nsDef *corev1.NamespaceDefinition, ts *schema.TypeSystem) ([]*devinterface.DeveloperWarning, error) { + def, err := schema.NewDefinition(ts, nsDef) + if err != nil { + return nil, err + } + + warnings := []*devinterface.DeveloperWarning{} + for _, rel := range nsDef.Relation { + ctx = context.WithValue(ctx, relationKey, rel) + + for _, check := range allChecks.relationChecks { + if shouldSkipCheck(rel.Metadata, check.name) { + continue + } + + checkerWarning, err := check.fn(ctx, rel, def) + if err != nil { + return nil, err + } + + if checkerWarning != nil { + warnings = append(warnings, checkerWarning) + } + } + + if def.IsPermission(rel.Name) { + found, err := walkUsersetRewrite(ctx, rel.UsersetRewrite, rel, allChecks, def) + if err != nil { + return nil, err + } + + warnings = append(warnings, found...) + } + } + + return warnings, nil +} + +func shouldSkipCheck(metadata *corev1.Metadata, name string) bool { + if metadata == nil { + return false + } + + comments := namespace.GetComments(metadata) + for _, comment := range comments { + if comment == "// spicedb-ignore-warning: "+name { + return true + } + } + + return false +} + +type tupleset interface { + GetRelation() string +} + +type ttu interface { + GetTupleset() tupleset + GetComputedUserset() *corev1.ComputedUserset + GetArrowString() (string, error) +} + +type ( + relationChecker func(ctx context.Context, relation *corev1.Relation, def *schema.Definition) (*devinterface.DeveloperWarning, error) + computedUsersetChecker func(ctx context.Context, computedUserset *corev1.ComputedUserset, sourcePosition *corev1.SourcePosition, def *schema.Definition) (*devinterface.DeveloperWarning, error) + ttuChecker func(ctx context.Context, ttu ttu, sourcePosition *corev1.SourcePosition, def *schema.Definition) (*devinterface.DeveloperWarning, error) +) + +type relationCheck struct { + name string + fn relationChecker +} + +type computedUsersetCheck struct { + name string + fn computedUsersetChecker +} + +type ttuCheck struct { + name string + fn ttuChecker +} + +type checks struct { + relationChecks []relationCheck + computedUsersetChecks []computedUsersetCheck + ttuChecks []ttuCheck +} + +func walkUsersetRewrite(ctx context.Context, rewrite *corev1.UsersetRewrite, relation *corev1.Relation, checks checks, def *schema.Definition) ([]*devinterface.DeveloperWarning, error) { + if rewrite == nil { + return nil, nil + } + + switch t := (rewrite.RewriteOperation).(type) { + case *corev1.UsersetRewrite_Union: + return walkUsersetOperations(ctx, t.Union.Child, relation, checks, def) + + case *corev1.UsersetRewrite_Intersection: + return walkUsersetOperations(ctx, t.Intersection.Child, relation, checks, def) + + case *corev1.UsersetRewrite_Exclusion: + return walkUsersetOperations(ctx, t.Exclusion.Child, relation, checks, def) + + default: + return nil, spiceerrors.MustBugf("unexpected rewrite operation type %T", t) + } +} + +func walkUsersetOperations(ctx context.Context, ops []*corev1.SetOperation_Child, relation *corev1.Relation, checks checks, def *schema.Definition) ([]*devinterface.DeveloperWarning, error) { + warnings := []*devinterface.DeveloperWarning{} + for _, op := range ops { + switch t := op.ChildType.(type) { + case *corev1.SetOperation_Child_XThis: + continue + + case *corev1.SetOperation_Child_ComputedUserset: + for _, check := range checks.computedUsersetChecks { + if shouldSkipCheck(relation.Metadata, check.name) { + continue + } + + checkerWarning, err := check.fn(ctx, t.ComputedUserset, op.SourcePosition, def) + if err != nil { + return nil, err + } + + if checkerWarning != nil { + warnings = append(warnings, checkerWarning) + } + } + + case *corev1.SetOperation_Child_UsersetRewrite: + found, err := walkUsersetRewrite(ctx, t.UsersetRewrite, relation, checks, def) + if err != nil { + return nil, err + } + + warnings = append(warnings, found...) + + case *corev1.SetOperation_Child_FunctionedTupleToUserset: + for _, check := range checks.ttuChecks { + if shouldSkipCheck(relation.Metadata, check.name) { + continue + } + + checkerWarning, err := check.fn(ctx, wrappedFunctionedTTU{t.FunctionedTupleToUserset}, op.SourcePosition, def) + if err != nil { + return nil, err + } + + if checkerWarning != nil { + warnings = append(warnings, checkerWarning) + } + } + + case *corev1.SetOperation_Child_TupleToUserset: + for _, check := range checks.ttuChecks { + if shouldSkipCheck(relation.Metadata, check.name) { + continue + } + + checkerWarning, err := check.fn(ctx, wrappedTTU{t.TupleToUserset}, op.SourcePosition, def) + if err != nil { + return nil, err + } + + if checkerWarning != nil { + warnings = append(warnings, checkerWarning) + } + } + + case *corev1.SetOperation_Child_XNil: + continue + + default: + return nil, spiceerrors.MustBugf("unexpected set operation type %T", t) + } + } + + return warnings, nil +} + +type wrappedFunctionedTTU struct { + *corev1.FunctionedTupleToUserset +} + +func (wfttu wrappedFunctionedTTU) GetTupleset() tupleset { + return wfttu.FunctionedTupleToUserset.GetTupleset() +} + +func (wfttu wrappedFunctionedTTU) GetComputedUserset() *corev1.ComputedUserset { + return wfttu.FunctionedTupleToUserset.GetComputedUserset() +} + +func (wfttu wrappedFunctionedTTU) GetArrowString() (string, error) { + var functionName string + switch wfttu.Function { + case corev1.FunctionedTupleToUserset_FUNCTION_ANY: + functionName = "any" + + case corev1.FunctionedTupleToUserset_FUNCTION_ALL: + functionName = "all" + + default: + return "", spiceerrors.MustBugf("unknown function type %T", wfttu.Function) + } + + return fmt.Sprintf("%s.%s(%s)", wfttu.GetTupleset().GetRelation(), functionName, wfttu.GetComputedUserset().GetRelation()), nil +} + +type wrappedTTU struct { + *corev1.TupleToUserset +} + +func (wtu wrappedTTU) GetTupleset() tupleset { + return wtu.TupleToUserset.GetTupleset() +} + +func (wtu wrappedTTU) GetComputedUserset() *corev1.ComputedUserset { + return wtu.TupleToUserset.GetComputedUserset() +} + +func (wtu wrappedTTU) GetArrowString() (string, error) { + arrowString := fmt.Sprintf("%s->%s", wtu.GetTupleset().GetRelation(), wtu.GetComputedUserset().GetRelation()) + return arrowString, nil +} |
