diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/hamba/avro/v2/protocol.go | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/hamba/avro/v2/protocol.go')
| -rw-r--r-- | vendor/github.com/hamba/avro/v2/protocol.go | 377 |
1 files changed, 377 insertions, 0 deletions
diff --git a/vendor/github.com/hamba/avro/v2/protocol.go b/vendor/github.com/hamba/avro/v2/protocol.go new file mode 100644 index 0000000..3a62baa --- /dev/null +++ b/vendor/github.com/hamba/avro/v2/protocol.go @@ -0,0 +1,377 @@ +package avro + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "fmt" + "os" + + jsoniter "github.com/json-iterator/go" + "github.com/mitchellh/mapstructure" +) + +var ( + protocolReserved = []string{"doc", "types", "messages", "protocol", "namespace"} + messageReserved = []string{"doc", "response", "request", "errors", "one-way"} +) + +type protocolConfig struct { + doc string + props map[string]any +} + +// ProtocolOption is a function that sets a protocol option. +type ProtocolOption func(*protocolConfig) + +// WithProtoDoc sets the doc on a protocol. +func WithProtoDoc(doc string) ProtocolOption { + return func(opts *protocolConfig) { + opts.doc = doc + } +} + +// WithProtoProps sets the properties on a protocol. +func WithProtoProps(props map[string]any) ProtocolOption { + return func(opts *protocolConfig) { + opts.props = props + } +} + +// Protocol is an Avro protocol. +type Protocol struct { + name + properties + + types []NamedSchema + messages map[string]*Message + + doc string + + hash string +} + +// NewProtocol creates a protocol instance. +func NewProtocol( + name, namepsace string, + types []NamedSchema, + messages map[string]*Message, + opts ...ProtocolOption, +) (*Protocol, error) { + var cfg protocolConfig + for _, opt := range opts { + opt(&cfg) + } + + n, err := newName(name, namepsace, nil) + if err != nil { + return nil, err + } + + p := &Protocol{ + name: n, + properties: newProperties(cfg.props, protocolReserved), + types: types, + messages: messages, + doc: cfg.doc, + } + + b := md5.Sum([]byte(p.String())) + p.hash = hex.EncodeToString(b[:]) + + return p, nil +} + +// Message returns a message with the given name or nil. +func (p *Protocol) Message(name string) *Message { + return p.messages[name] +} + +// Doc returns the protocol doc. +func (p *Protocol) Doc() string { + return p.doc +} + +// Hash returns the MD5 hash of the protocol. +func (p *Protocol) Hash() string { + return p.hash +} + +// Types returns the types of the protocol. +func (p *Protocol) Types() []NamedSchema { + return p.types +} + +// String returns the canonical form of the protocol. +func (p *Protocol) String() string { + types := "" + for _, f := range p.types { + types += f.String() + "," + } + if len(types) > 0 { + types = types[:len(types)-1] + } + + messages := "" + for k, m := range p.messages { + messages += `"` + k + `":` + m.String() + "," + } + if len(messages) > 0 { + messages = messages[:len(messages)-1] + } + + return `{"protocol":"` + p.Name() + + `","namespace":"` + p.Namespace() + + `","types":[` + types + `],"messages":{` + messages + `}}` +} + +// Message is an Avro protocol message. +type Message struct { + properties + + req *RecordSchema + resp Schema + errs *UnionSchema + oneWay bool + + doc string +} + +// NewMessage creates a protocol message instance. +func NewMessage(req *RecordSchema, resp Schema, errors *UnionSchema, oneWay bool, opts ...ProtocolOption) *Message { + var cfg protocolConfig + for _, opt := range opts { + opt(&cfg) + } + + return &Message{ + properties: newProperties(cfg.props, messageReserved), + req: req, + resp: resp, + errs: errors, + oneWay: oneWay, + doc: cfg.doc, + } +} + +// Request returns the message request schema. +func (m *Message) Request() *RecordSchema { + return m.req +} + +// Response returns the message response schema. +func (m *Message) Response() Schema { + return m.resp +} + +// Errors returns the message errors union schema. +func (m *Message) Errors() *UnionSchema { + return m.errs +} + +// OneWay determines of the message is a one way message. +func (m *Message) OneWay() bool { + return m.oneWay +} + +// Doc returns the message doc. +func (m *Message) Doc() string { + return m.doc +} + +// String returns the canonical form of the message. +func (m *Message) String() string { + fields := "" + for _, f := range m.req.fields { + fields += f.String() + "," + } + if len(fields) > 0 { + fields = fields[:len(fields)-1] + } + + str := `{"request":[` + fields + `]` + if m.resp != nil { + str += `,"response":` + m.resp.String() + } + if m.errs != nil && len(m.errs.Types()) > 1 { + errs, _ := NewUnionSchema(m.errs.Types()[1:]) + str += `,"errors":` + errs.String() + } + str += "}" + return str +} + +// ParseProtocolFile parses an Avro protocol from a file. +func ParseProtocolFile(path string) (*Protocol, error) { + s, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + return ParseProtocol(string(s)) +} + +// MustParseProtocol parses an Avro protocol, panicing if there is an error. +func MustParseProtocol(protocol string) *Protocol { + parsed, err := ParseProtocol(protocol) + if err != nil { + panic(err) + } + + return parsed +} + +// ParseProtocol parses an Avro protocol. +func ParseProtocol(protocol string) (*Protocol, error) { + cache := &SchemaCache{} + + var m map[string]any + if err := jsoniter.Unmarshal([]byte(protocol), &m); err != nil { + return nil, err + } + + seen := seenCache{} + return parseProtocol(m, seen, cache) +} + +type protocol struct { + Protocol string `mapstructure:"protocol"` + Namespace string `mapstructure:"namespace"` + Doc string `mapstructure:"doc"` + Types []any `mapstructure:"types"` + Messages map[string]map[string]any `mapstructure:"messages"` + Props map[string]any `mapstructure:",remain"` +} + +func parseProtocol(m map[string]any, seen seenCache, cache *SchemaCache) (*Protocol, error) { + var ( + p protocol + meta mapstructure.Metadata + ) + if err := decodeMap(m, &p, &meta); err != nil { + return nil, fmt.Errorf("avro: error decoding protocol: %w", err) + } + + if err := checkParsedName(p.Protocol); err != nil { + return nil, err + } + + var ( + types []NamedSchema + err error + ) + if len(p.Types) > 0 { + types, err = parseProtocolTypes(p.Namespace, p.Types, seen, cache) + if err != nil { + return nil, err + } + } + + messages := map[string]*Message{} + if len(p.Messages) > 0 { + for k, msg := range p.Messages { + message, err := parseMessage(p.Namespace, msg, seen, cache) + if err != nil { + return nil, err + } + + messages[k] = message + } + } + + return NewProtocol(p.Protocol, p.Namespace, types, messages, WithProtoDoc(p.Doc), WithProtoProps(p.Props)) +} + +func parseProtocolTypes(namespace string, types []any, seen seenCache, cache *SchemaCache) ([]NamedSchema, error) { + ts := make([]NamedSchema, len(types)) + for i, typ := range types { + schema, err := parseType(namespace, typ, seen, cache) + if err != nil { + return nil, err + } + + namedSchema, ok := schema.(NamedSchema) + if !ok { + return nil, errors.New("avro: protocol types must be named schemas") + } + + ts[i] = namedSchema + } + + return ts, nil +} + +type message struct { + Doc string `mapstructure:"doc"` + Request []map[string]any `mapstructure:"request"` + Response any `mapstructure:"response"` + Errors []any `mapstructure:"errors"` + OneWay bool `mapstructure:"one-way"` + Props map[string]any `mapstructure:",remain"` +} + +func parseMessage(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (*Message, error) { + var ( + msg message + meta mapstructure.Metadata + ) + if err := decodeMap(m, &msg, &meta); err != nil { + return nil, fmt.Errorf("avro: error decoding message: %w", err) + } + + fields := make([]*Field, len(msg.Request)) + for i, f := range msg.Request { + field, err := parseField(namespace, f, seen, cache) + if err != nil { + return nil, err + } + fields[i] = field + } + request := &RecordSchema{ + name: name{}, + properties: properties{}, + fields: fields, + } + + var response Schema + if msg.Response != nil { + schema, err := parseType(namespace, msg.Response, seen, cache) + if err != nil { + return nil, err + } + + if schema.Type() != Null { + response = schema + } + } + + types := []Schema{NewPrimitiveSchema(String, nil)} + if len(msg.Errors) > 0 { + for _, e := range msg.Errors { + schema, err := parseType(namespace, e, seen, cache) + if err != nil { + return nil, err + } + + if rec, ok := schema.(*RecordSchema); ok && !rec.IsError() { + return nil, errors.New("avro: errors record schema must be of type error") + } + + types = append(types, schema) + } + } + errs, err := NewUnionSchema(types) + if err != nil { + return nil, err + } + + oneWay := msg.OneWay + if hasKey(meta.Keys, "one-way") && oneWay && (len(errs.Types()) > 1 || response != nil) { + return nil, errors.New("avro: one-way messages cannot not have a response or errors") + } + if !oneWay && len(errs.Types()) <= 1 && response == nil { + oneWay = true + } + + return NewMessage(request, response, errs, oneWay, WithProtoDoc(msg.Doc), WithProtoProps(msg.Props)), nil +} |
