summaryrefslogtreecommitdiff
path: root/vendor/github.com/hamba/avro/v2/protocol.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/hamba/avro/v2/protocol.go')
-rw-r--r--vendor/github.com/hamba/avro/v2/protocol.go377
1 files changed, 377 insertions, 0 deletions
diff --git a/vendor/github.com/hamba/avro/v2/protocol.go b/vendor/github.com/hamba/avro/v2/protocol.go
new file mode 100644
index 0000000..3a62baa
--- /dev/null
+++ b/vendor/github.com/hamba/avro/v2/protocol.go
@@ -0,0 +1,377 @@
+package avro
+
+import (
+ "crypto/md5"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "os"
+
+ jsoniter "github.com/json-iterator/go"
+ "github.com/mitchellh/mapstructure"
+)
+
+var (
+ protocolReserved = []string{"doc", "types", "messages", "protocol", "namespace"}
+ messageReserved = []string{"doc", "response", "request", "errors", "one-way"}
+)
+
+type protocolConfig struct {
+ doc string
+ props map[string]any
+}
+
+// ProtocolOption is a function that sets a protocol option.
+type ProtocolOption func(*protocolConfig)
+
+// WithProtoDoc sets the doc on a protocol.
+func WithProtoDoc(doc string) ProtocolOption {
+ return func(opts *protocolConfig) {
+ opts.doc = doc
+ }
+}
+
+// WithProtoProps sets the properties on a protocol.
+func WithProtoProps(props map[string]any) ProtocolOption {
+ return func(opts *protocolConfig) {
+ opts.props = props
+ }
+}
+
+// Protocol is an Avro protocol.
+type Protocol struct {
+ name
+ properties
+
+ types []NamedSchema
+ messages map[string]*Message
+
+ doc string
+
+ hash string
+}
+
+// NewProtocol creates a protocol instance.
+func NewProtocol(
+ name, namepsace string,
+ types []NamedSchema,
+ messages map[string]*Message,
+ opts ...ProtocolOption,
+) (*Protocol, error) {
+ var cfg protocolConfig
+ for _, opt := range opts {
+ opt(&cfg)
+ }
+
+ n, err := newName(name, namepsace, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ p := &Protocol{
+ name: n,
+ properties: newProperties(cfg.props, protocolReserved),
+ types: types,
+ messages: messages,
+ doc: cfg.doc,
+ }
+
+ b := md5.Sum([]byte(p.String()))
+ p.hash = hex.EncodeToString(b[:])
+
+ return p, nil
+}
+
+// Message returns a message with the given name or nil.
+func (p *Protocol) Message(name string) *Message {
+ return p.messages[name]
+}
+
+// Doc returns the protocol doc.
+func (p *Protocol) Doc() string {
+ return p.doc
+}
+
+// Hash returns the MD5 hash of the protocol.
+func (p *Protocol) Hash() string {
+ return p.hash
+}
+
+// Types returns the types of the protocol.
+func (p *Protocol) Types() []NamedSchema {
+ return p.types
+}
+
+// String returns the canonical form of the protocol.
+func (p *Protocol) String() string {
+ types := ""
+ for _, f := range p.types {
+ types += f.String() + ","
+ }
+ if len(types) > 0 {
+ types = types[:len(types)-1]
+ }
+
+ messages := ""
+ for k, m := range p.messages {
+ messages += `"` + k + `":` + m.String() + ","
+ }
+ if len(messages) > 0 {
+ messages = messages[:len(messages)-1]
+ }
+
+ return `{"protocol":"` + p.Name() +
+ `","namespace":"` + p.Namespace() +
+ `","types":[` + types + `],"messages":{` + messages + `}}`
+}
+
+// Message is an Avro protocol message.
+type Message struct {
+ properties
+
+ req *RecordSchema
+ resp Schema
+ errs *UnionSchema
+ oneWay bool
+
+ doc string
+}
+
+// NewMessage creates a protocol message instance.
+func NewMessage(req *RecordSchema, resp Schema, errors *UnionSchema, oneWay bool, opts ...ProtocolOption) *Message {
+ var cfg protocolConfig
+ for _, opt := range opts {
+ opt(&cfg)
+ }
+
+ return &Message{
+ properties: newProperties(cfg.props, messageReserved),
+ req: req,
+ resp: resp,
+ errs: errors,
+ oneWay: oneWay,
+ doc: cfg.doc,
+ }
+}
+
+// Request returns the message request schema.
+func (m *Message) Request() *RecordSchema {
+ return m.req
+}
+
+// Response returns the message response schema.
+func (m *Message) Response() Schema {
+ return m.resp
+}
+
+// Errors returns the message errors union schema.
+func (m *Message) Errors() *UnionSchema {
+ return m.errs
+}
+
+// OneWay determines of the message is a one way message.
+func (m *Message) OneWay() bool {
+ return m.oneWay
+}
+
+// Doc returns the message doc.
+func (m *Message) Doc() string {
+ return m.doc
+}
+
+// String returns the canonical form of the message.
+func (m *Message) String() string {
+ fields := ""
+ for _, f := range m.req.fields {
+ fields += f.String() + ","
+ }
+ if len(fields) > 0 {
+ fields = fields[:len(fields)-1]
+ }
+
+ str := `{"request":[` + fields + `]`
+ if m.resp != nil {
+ str += `,"response":` + m.resp.String()
+ }
+ if m.errs != nil && len(m.errs.Types()) > 1 {
+ errs, _ := NewUnionSchema(m.errs.Types()[1:])
+ str += `,"errors":` + errs.String()
+ }
+ str += "}"
+ return str
+}
+
+// ParseProtocolFile parses an Avro protocol from a file.
+func ParseProtocolFile(path string) (*Protocol, error) {
+ s, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+
+ return ParseProtocol(string(s))
+}
+
+// MustParseProtocol parses an Avro protocol, panicing if there is an error.
+func MustParseProtocol(protocol string) *Protocol {
+ parsed, err := ParseProtocol(protocol)
+ if err != nil {
+ panic(err)
+ }
+
+ return parsed
+}
+
+// ParseProtocol parses an Avro protocol.
+func ParseProtocol(protocol string) (*Protocol, error) {
+ cache := &SchemaCache{}
+
+ var m map[string]any
+ if err := jsoniter.Unmarshal([]byte(protocol), &m); err != nil {
+ return nil, err
+ }
+
+ seen := seenCache{}
+ return parseProtocol(m, seen, cache)
+}
+
+type protocol struct {
+ Protocol string `mapstructure:"protocol"`
+ Namespace string `mapstructure:"namespace"`
+ Doc string `mapstructure:"doc"`
+ Types []any `mapstructure:"types"`
+ Messages map[string]map[string]any `mapstructure:"messages"`
+ Props map[string]any `mapstructure:",remain"`
+}
+
+func parseProtocol(m map[string]any, seen seenCache, cache *SchemaCache) (*Protocol, error) {
+ var (
+ p protocol
+ meta mapstructure.Metadata
+ )
+ if err := decodeMap(m, &p, &meta); err != nil {
+ return nil, fmt.Errorf("avro: error decoding protocol: %w", err)
+ }
+
+ if err := checkParsedName(p.Protocol); err != nil {
+ return nil, err
+ }
+
+ var (
+ types []NamedSchema
+ err error
+ )
+ if len(p.Types) > 0 {
+ types, err = parseProtocolTypes(p.Namespace, p.Types, seen, cache)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ messages := map[string]*Message{}
+ if len(p.Messages) > 0 {
+ for k, msg := range p.Messages {
+ message, err := parseMessage(p.Namespace, msg, seen, cache)
+ if err != nil {
+ return nil, err
+ }
+
+ messages[k] = message
+ }
+ }
+
+ return NewProtocol(p.Protocol, p.Namespace, types, messages, WithProtoDoc(p.Doc), WithProtoProps(p.Props))
+}
+
+func parseProtocolTypes(namespace string, types []any, seen seenCache, cache *SchemaCache) ([]NamedSchema, error) {
+ ts := make([]NamedSchema, len(types))
+ for i, typ := range types {
+ schema, err := parseType(namespace, typ, seen, cache)
+ if err != nil {
+ return nil, err
+ }
+
+ namedSchema, ok := schema.(NamedSchema)
+ if !ok {
+ return nil, errors.New("avro: protocol types must be named schemas")
+ }
+
+ ts[i] = namedSchema
+ }
+
+ return ts, nil
+}
+
+type message struct {
+ Doc string `mapstructure:"doc"`
+ Request []map[string]any `mapstructure:"request"`
+ Response any `mapstructure:"response"`
+ Errors []any `mapstructure:"errors"`
+ OneWay bool `mapstructure:"one-way"`
+ Props map[string]any `mapstructure:",remain"`
+}
+
+func parseMessage(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (*Message, error) {
+ var (
+ msg message
+ meta mapstructure.Metadata
+ )
+ if err := decodeMap(m, &msg, &meta); err != nil {
+ return nil, fmt.Errorf("avro: error decoding message: %w", err)
+ }
+
+ fields := make([]*Field, len(msg.Request))
+ for i, f := range msg.Request {
+ field, err := parseField(namespace, f, seen, cache)
+ if err != nil {
+ return nil, err
+ }
+ fields[i] = field
+ }
+ request := &RecordSchema{
+ name: name{},
+ properties: properties{},
+ fields: fields,
+ }
+
+ var response Schema
+ if msg.Response != nil {
+ schema, err := parseType(namespace, msg.Response, seen, cache)
+ if err != nil {
+ return nil, err
+ }
+
+ if schema.Type() != Null {
+ response = schema
+ }
+ }
+
+ types := []Schema{NewPrimitiveSchema(String, nil)}
+ if len(msg.Errors) > 0 {
+ for _, e := range msg.Errors {
+ schema, err := parseType(namespace, e, seen, cache)
+ if err != nil {
+ return nil, err
+ }
+
+ if rec, ok := schema.(*RecordSchema); ok && !rec.IsError() {
+ return nil, errors.New("avro: errors record schema must be of type error")
+ }
+
+ types = append(types, schema)
+ }
+ }
+ errs, err := NewUnionSchema(types)
+ if err != nil {
+ return nil, err
+ }
+
+ oneWay := msg.OneWay
+ if hasKey(meta.Keys, "one-way") && oneWay && (len(errs.Types()) > 1 || response != nil) {
+ return nil, errors.New("avro: one-way messages cannot not have a response or errors")
+ }
+ if !oneWay && len(errs.Types()) <= 1 && response == nil {
+ oneWay = true
+ }
+
+ return NewMessage(request, response, errs, oneWay, WithProtoDoc(msg.Doc), WithProtoProps(msg.Props)), nil
+}