summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/pkg/cursor/cursor.go
blob: 4233efaa5ccced694cce1692adbd3988d48b2a3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package cursor

import (
	"encoding/base64"
	"errors"
	"fmt"

	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"

	"github.com/authzed/spicedb/pkg/datastore"
	dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
	impl "github.com/authzed/spicedb/pkg/proto/impl/v1"
	"github.com/authzed/spicedb/pkg/spiceerrors"
)

// Encode turns a decoded cursor into its opaque API form. The cursor is
// serialized with its vtproto marshaller, and the resulting bytes are carried
// as a standard-encoding base64 string in the returned v1.Cursor's Token.
// A marshalling failure is reported as an invalid-cursor error.
func Encode(decoded *impl.DecodedCursor) (*v1.Cursor, error) {
	raw, marshalErr := decoded.MarshalVT()
	if marshalErr != nil {
		wrapped := fmt.Errorf(errEncodeError, marshalErr)
		return nil, NewInvalidCursorErr(wrapped)
	}

	token := base64.StdEncoding.EncodeToString(raw)
	return &v1.Cursor{Token: token}, nil
}

// Decode is the inverse of Encode. It base64-decodes (standard encoding) the
// Token of the given opaque cursor and unmarshals the bytes into an
// impl.DecodedCursor. A nil cursor, malformed base64, or a malformed protobuf
// payload each yield an invalid-cursor error.
func Decode(encoded *v1.Cursor) (*impl.DecodedCursor, error) {
	if encoded == nil {
		return nil, NewInvalidCursorErr(errors.New("cursor pointer was nil"))
	}

	raw, decodeErr := base64.StdEncoding.DecodeString(encoded.Token)
	if decodeErr != nil {
		return nil, NewInvalidCursorErr(fmt.Errorf(errDecodeError, decodeErr))
	}

	result := new(impl.DecodedCursor)
	if unmarshalErr := result.UnmarshalVT(raw); unmarshalErr != nil {
		return nil, NewInvalidCursorErr(fmt.Errorf(errDecodeError, unmarshalErr))
	}
	return result, nil
}

// EncodeFromDispatchCursor encodes an internal dispatching cursor into a V1 cursor for external
// consumption, including the provided call context to ensure the API cursor reflects the calling
// API method. The call hash should contain all the parameters of the calling API function,
// as well as its revision and name.
//
// The revision is stored in its String() form. The dispatch version, sections and flags are
// copied into the cursor as-is; flags can later be read back with GetCursorFlag or
// DecodeToDispatchCursor. A nil dispatch cursor or a nil revision is treated as a programming
// error and reported through spiceerrors.MustBugf.
func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterHash string, revision datastore.Revision, flags map[string]string) (*v1.Cursor, error) {
	if dispatchCursor == nil {
		return nil, spiceerrors.MustBugf("got nil dispatch cursor")
	}

	// Without this guard, revision.String() below would panic with a nil-pointer
	// dereference instead of surfacing an error to the caller.
	if revision == nil {
		return nil, spiceerrors.MustBugf("got nil revision")
	}

	return Encode(&impl.DecodedCursor{
		VersionOneof: &impl.DecodedCursor_V1{
			V1: &impl.V1Cursor{
				Revision:              revision.String(),
				DispatchVersion:       dispatchCursor.DispatchVersion,
				Sections:              dispatchCursor.Sections,
				CallAndParametersHash: callAndParameterHash,
				Flags:                 flags,
			},
		},
	})
}

// GetCursorFlag decodes the given API cursor and looks up flagName among the
// flags stored in its V1 payload. It returns the flag's value and whether the
// flag was present. Errors from Decode are passed through unchanged. A cursor
// without a V1 payload yields an invalid-cursor error wrapping ErrNilCursor.
func GetCursorFlag(encoded *v1.Cursor, flagName string) (string, bool, error) {
	cursor, err := Decode(encoded)
	if err != nil {
		return "", false, err
	}

	payload := cursor.GetV1()
	if payload == nil {
		return "", false, NewInvalidCursorErr(ErrNilCursor)
	}

	if value, found := payload.Flags[flagName]; found {
		return value, true, nil
	}
	return "", false, nil
}

// DecodeToDispatchCursor decodes an encoded API cursor into an internal dispatching cursor,
// ensuring that the provided call context matches that encoded into the API cursor. The call
// hash should contain all the parameters of the calling API function, as well as its revision
// and name.
//
// On success it returns the dispatch cursor (dispatch version and sections) together with the
// flags stored in the API cursor. Errors from Decode are passed through unchanged. A cursor
// without a V1 payload, or one whose stored hash differs from callAndParameterHash, yields an
// invalid-cursor error.
func DecodeToDispatchCursor(encoded *v1.Cursor, callAndParameterHash string) (*dispatch.Cursor, map[string]string, error) {
	cursor, err := Decode(encoded)
	if err != nil {
		return nil, nil, err
	}

	payload := cursor.GetV1()
	switch {
	case payload == nil:
		return nil, nil, NewInvalidCursorErr(ErrNilCursor)
	case payload.CallAndParametersHash != callAndParameterHash:
		// Reject cursors that were minted for a different call or different parameters.
		return nil, nil, NewInvalidCursorErr(ErrHashMismatch)
	}

	dispatchCursor := &dispatch.Cursor{
		DispatchVersion: payload.DispatchVersion,
		Sections:        payload.Sections,
	}
	return dispatchCursor, payload.Flags, nil
}

// DecodeToDispatchRevision decodes an encoded API cursor into an internal dispatch revision,
// parsing the revision string stored in the cursor with ds.
// NOTE: this method does *not* verify the caller's method signature.
//
// Every failure path returns datastore.NoRevision together with an invalid-cursor error. The
// failure paths are: the cursor cannot be decoded, it has no V1 payload, or its stored revision
// cannot be parsed by ds. This matches how GetCursorFlag and DecodeToDispatchCursor report
// malformed cursors.
func DecodeToDispatchRevision(encoded *v1.Cursor, ds revisionDecoder) (datastore.Revision, error) {
	decoded, err := Decode(encoded)
	if err != nil {
		// Decode already wraps its failures with NewInvalidCursorErr.
		return datastore.NoRevision, err
	}

	v1decoded := decoded.GetV1()
	if v1decoded == nil {
		return datastore.NoRevision, NewInvalidCursorErr(ErrNilCursor)
	}

	// The revision string originates from a client-supplied cursor, so a parse
	// failure is an invalid cursor rather than an internal error.
	parsed, err := ds.RevisionFromString(v1decoded.Revision)
	if err != nil {
		return datastore.NoRevision, NewInvalidCursorErr(fmt.Errorf(errDecodeError, err))
	}

	return parsed, nil
}

// revisionDecoder is the single capability DecodeToDispatchRevision requires of
// its ds argument: converting a serialized revision string back into a
// datastore.Revision.
type revisionDecoder interface {
	RevisionFromString(serialized string) (datastore.Revision, error)
}