summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/pkg/proto/dispatch/v1/01_codec.go
blob: d72e4e4feb59d46379d06e86053881deae124665 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// This file registers a gRPC codec that replaces the default gRPC proto codec
// with one that attempts to use protobuf codecs in the following order:
// - vtprotobuf (for messages that implement the generated *VT methods)
// - google.golang.org/grpc/encoding/proto (the default gRPC proto codec)

package dispatchv1

import (
	"google.golang.org/grpc/encoding"
	"google.golang.org/grpc/mem"

	// Guarantee that the built-in proto codec is registered before this one
	// so that it can be replaced (and reused as the fallback).
	_ "google.golang.org/grpc/encoding/proto"
)

// Name is the name under which this codec is registered with gRPC. It is the
// same name used by gRPC's built-in proto codec, so registering this codec
// replaces the built-in one.
const Name = "proto"

// vtprotoMessage is the set of vtprotobuf-generated methods that vtprotoCodec
// relies on. Values that expose all three are encoded and decoded directly by
// vtprotoCodec; anything else is handed to the fallback codec.
type vtprotoMessage interface {
	// SizeVT reports the encoded size of the message in bytes.
	SizeVT() int
	// MarshalToSizedBufferVT encodes the message into dst and reports the
	// number of bytes written.
	MarshalToSizedBufferVT(dst []byte) (int, error)
	// UnmarshalVT decodes src into the message.
	UnmarshalVT(src []byte) error
}

// vtprotoCodec is a gRPC encoding.CodecV2 that encodes and decodes values
// implementing vtprotoMessage with their generated vtprotobuf methods and
// delegates every other value to fallback.
type vtprotoCodec struct {
	// fallback handles values that do not implement vtprotoMessage. It is
	// the codec that was registered under Name before this one.
	fallback encoding.CodecV2
}

// Compile-time check that *vtprotoCodec satisfies gRPC's codec interface.
var _ encoding.CodecV2 = (*vtprotoCodec)(nil)

// Name returns the name this codec is registered under. It uses a pointer
// receiver, like Marshal and Unmarshal, so the type's method set is
// consistent; only *vtprotoCodec is ever constructed.
func (*vtprotoCodec) Name() string { return Name }

// Marshal encodes v into a mem.BufferSlice.
//
// Values implementing vtprotoMessage are encoded with their generated
// vtprotobuf methods; all other values are passed to the fallback codec.
// Payloads below gRPC's buffer-pooling threshold are written into a freshly
// allocated slice. Larger payloads are written into a buffer taken from
// mem.DefaultBufferPool, which is wrapped via mem.NewBuffer together with that
// pool; on an encoding error the buffer is returned to the pool immediately.
func (c *vtprotoCodec) Marshal(v any) (mem.BufferSlice, error) {
	msg, isVT := v.(vtprotoMessage)
	if !isVT {
		return c.fallback.Marshal(v)
	}

	size := msg.SizeVT()

	// NOTE(review): written is expected to equal size. vtprotobuf's
	// MarshalToSizedBufferVT fills the buffer from the end, so a smaller
	// count would require keeping the tail rather than the head — confirm.
	if mem.IsBelowBufferPoolingThreshold(size) {
		out := make([]byte, size)
		written, err := msg.MarshalToSizedBufferVT(out)
		if err != nil {
			return nil, err
		}
		return mem.BufferSlice{mem.SliceBuffer(out[:written])}, nil
	}

	pool := mem.DefaultBufferPool()
	pooled := pool.Get(size)
	written, err := msg.MarshalToSizedBufferVT(*pooled)
	if err != nil {
		// Nothing references the buffer yet, so hand it straight back.
		pool.Put(pooled)
		return nil, err
	}
	*pooled = (*pooled)[:written]
	return mem.BufferSlice{mem.NewBuffer(pooled, pool)}, nil
}

// Unmarshal decodes data into v.
//
// Values implementing vtprotoMessage are decoded with UnmarshalVT after the
// BufferSlice is materialized into a single buffer drawn from
// mem.DefaultBufferPool; that buffer is freed once decoding returns. All
// other values are passed to the fallback codec.
func (c *vtprotoCodec) Unmarshal(data mem.BufferSlice, v any) error {
	msg, isVT := v.(vtprotoMessage)
	if !isVT {
		return c.fallback.Unmarshal(data, v)
	}

	contiguous := data.MaterializeToBuffer(mem.DefaultBufferPool())
	// NOTE(review): freeing here is only safe if UnmarshalVT copies out of
	// its input rather than aliasing it (e.g. not generated with vtprotobuf's
	// unmarshal_unsafe feature) — confirm against the generated code.
	defer contiguous.Free()
	return msg.UnmarshalVT(contiguous.ReadOnlyData())
}

// init replaces the codec registered under "proto" with a vtprotoCodec. The
// codec currently registered under that name (gRPC's built-in proto codec,
// whose package is blank-imported by this file) is looked up first so it can
// serve as the fallback for values that are not vtprotobuf messages.
func init() {
	builtin := encoding.GetCodecV2("proto")
	encoding.RegisterCodecV2(&vtprotoCodec{fallback: builtin})
}