summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go
blob: a572216cab358f539355c7d4a39058bb5f6beb9a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package v1

import (
	"context"

	datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
	caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
	"github.com/authzed/spicedb/pkg/datastore"
	"github.com/authzed/spicedb/pkg/diff"
	"github.com/authzed/spicedb/pkg/middleware/consistency"
	core "github.com/authzed/spicedb/pkg/proto/core/v1"
	"github.com/authzed/spicedb/pkg/schemadsl/compiler"
	"github.com/authzed/spicedb/pkg/schemadsl/input"
)

// loadCurrentSchema reads every namespace (object) definition and every caveat
// definition stored in the datastore and bundles them into a DiffableSchema.
//
// The datastore is obtained from ctx via the datastore middleware, and the
// revision to read at via the consistency middleware. On success it returns
// the schema together with the revision it was read at.
//
// Error behavior:
//   - if the revision cannot be resolved, the revision result is nil;
//   - if listing namespaces or caveats fails, the schema is nil but the
//     resolved revision is still returned alongside the error.
//
// Errors are passed through unchanged.
func loadCurrentSchema(ctx context.Context) (*diff.DiffableSchema, datastore.Revision, error) {
	store := datastoremw.MustFromContext(ctx)

	revision, _, err := consistency.RevisionFromContext(ctx)
	if err != nil {
		return nil, nil, err
	}

	snapshot := store.SnapshotReader(revision)

	storedNamespaces, err := snapshot.ListAllNamespaces(ctx)
	if err != nil {
		return nil, revision, err
	}

	storedCaveats, err := snapshot.ListAllCaveats(ctx)
	if err != nil {
		return nil, revision, err
	}

	// Strip the per-definition revision metadata; only the definitions
	// themselves participate in the schema.
	schema := &diff.DiffableSchema{
		ObjectDefinitions: make([]*core.NamespaceDefinition, len(storedNamespaces)),
		CaveatDefinitions: make([]*core.CaveatDefinition, len(storedCaveats)),
	}
	for i := range storedNamespaces {
		schema.ObjectDefinitions[i] = storedNamespaces[i].Definition
	}
	for i := range storedCaveats {
		schema.CaveatDefinitions[i] = storedCaveats[i].Definition
	}

	return schema, revision, nil
}

// schemaDiff compares the schema currently stored in the datastore (loaded via
// loadCurrentSchema) against comparisonSchemaString.
//
// The comparison schema is compiled with unprefixed object types allowed, and
// its caveat types are resolved through caveatTypeSet. The same type set is
// also handed to the diff computation.
//
// It returns, in order:
//   - the computed schema diff;
//   - the existing (stored) schema;
//   - the compiled comparison schema.
//
// If loading, compiling or diffing fails, that error is returned unchanged
// and all other results are nil.
func schemaDiff(ctx context.Context, comparisonSchemaString string, caveatTypeSet *caveattypes.TypeSet) (*diff.SchemaDiff, *diff.DiffableSchema, *diff.DiffableSchema, error) {
	// The revision the stored schema was read at is not needed here.
	existingSchema, _, err := loadCurrentSchema(ctx)
	if err != nil {
		return nil, nil, nil, err
	}

	// Compile the comparison schema.
	compiled, err := compiler.Compile(compiler.InputSchema{
		Source:       input.Source("schema"),
		SchemaString: comparisonSchemaString,
	}, compiler.AllowUnprefixedObjectType(), compiler.CaveatTypeSet(caveatTypeSet))
	if err != nil {
		return nil, nil, nil, err
	}

	comparisonSchema := diff.NewDiffableSchemaFromCompiledSchema(compiled)

	// Use a distinct local name so the imported diff package is not shadowed
	// for the remainder of the function.
	schemaChanges, err := diff.DiffSchemas(*existingSchema, comparisonSchema, caveatTypeSet)
	if err != nil {
		return nil, nil, nil, err
	}

	return schemaChanges, existingSchema, &comparisonSchema, nil
}