summaryrefslogtreecommitdiff
path: root/vendor/github.com/hamba/avro/v2/schema_compatibility.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/hamba/avro/v2/schema_compatibility.go')
-rw-r--r--vendor/github.com/hamba/avro/v2/schema_compatibility.go487
1 files changed, 487 insertions, 0 deletions
diff --git a/vendor/github.com/hamba/avro/v2/schema_compatibility.go b/vendor/github.com/hamba/avro/v2/schema_compatibility.go
new file mode 100644
index 0000000..0b1d9ac
--- /dev/null
+++ b/vendor/github.com/hamba/avro/v2/schema_compatibility.go
@@ -0,0 +1,487 @@
+package avro
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+)
+
+type recursionError struct{}
+
+func (e recursionError) Error() string {
+ return ""
+}
+
+type compatKey struct {
+ reader [32]byte
+ writer [32]byte
+}
+
+// SchemaCompatibility determines the compatibility of schemas.
+type SchemaCompatibility struct {
+ cache sync.Map // map[compatKey]error
+}
+
+// NewSchemaCompatibility creates a new schema compatibility instance.
+func NewSchemaCompatibility() *SchemaCompatibility {
+ return &SchemaCompatibility{}
+}
+
+// Compatible determines the compatibility if the reader and writer schemas.
+func (c *SchemaCompatibility) Compatible(reader, writer Schema) error {
+ return c.compatible(reader, writer)
+}
+
+func (c *SchemaCompatibility) compatible(reader, writer Schema) error {
+ key := compatKey{reader: reader.Fingerprint(), writer: writer.Fingerprint()}
+ if err, ok := c.cache.Load(key); ok {
+ if _, ok := err.(recursionError); ok {
+ // Break the recursion here.
+ return nil
+ }
+
+ if err == nil {
+ return nil
+ }
+
+ return err.(error)
+ }
+
+ c.cache.Store(key, recursionError{})
+ err := c.match(reader, writer)
+ if err != nil {
+ // We dont want to pay the cost of fmt.Errorf every time
+ err = errors.New(err.Error())
+ }
+ c.cache.Store(key, err)
+ return err
+}
+
+func (c *SchemaCompatibility) match(reader, writer Schema) error {
+ // If the schema is a reference, get the actual schema
+ if reader.Type() == Ref {
+ reader = reader.(*RefSchema).Schema()
+ }
+ if writer.Type() == Ref {
+ writer = writer.(*RefSchema).Schema()
+ }
+
+ if reader.Type() != writer.Type() {
+ if writer.Type() == Union {
+ // Reader must be compatible with all types in writer
+ for _, schema := range writer.(*UnionSchema).Types() {
+ if err := c.compatible(reader, schema); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }
+
+ if reader.Type() == Union {
+ // Writer must be compatible with at least one reader schema
+ var err error
+ for _, schema := range reader.(*UnionSchema).Types() {
+ err = c.compatible(schema, writer)
+ if err == nil {
+ return nil
+ }
+ }
+
+ return fmt.Errorf("reader union lacking writer schema %s", writer.Type())
+ }
+
+ switch writer.Type() {
+ case Int:
+ if reader.Type() == Long || reader.Type() == Float || reader.Type() == Double {
+ return nil
+ }
+
+ case Long:
+ if reader.Type() == Float || reader.Type() == Double {
+ return nil
+ }
+
+ case Float:
+ if reader.Type() == Double {
+ return nil
+ }
+
+ case String:
+ if reader.Type() == Bytes {
+ return nil
+ }
+
+ case Bytes:
+ if reader.Type() == String {
+ return nil
+ }
+ }
+
+ return fmt.Errorf("reader schema %s not compatible with writer schema %s", reader.Type(), writer.Type())
+ }
+
+ switch reader.Type() {
+ case Array:
+ return c.compatible(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items())
+
+ case Map:
+ return c.compatible(reader.(*MapSchema).Values(), writer.(*MapSchema).Values())
+
+ case Fixed:
+ r := reader.(*FixedSchema)
+ w := writer.(*FixedSchema)
+
+ if err := c.checkSchemaName(r, w); err != nil {
+ return err
+ }
+
+ if err := c.checkFixedSize(r, w); err != nil {
+ return err
+ }
+
+ case Enum:
+ r := reader.(*EnumSchema)
+ w := writer.(*EnumSchema)
+
+ if err := c.checkSchemaName(r, w); err != nil {
+ return err
+ }
+
+ if err := c.checkEnumSymbols(r, w); err != nil {
+ if r.HasDefault() {
+ return nil
+ }
+ return err
+ }
+
+ case Record:
+ r := reader.(*RecordSchema)
+ w := writer.(*RecordSchema)
+
+ if err := c.checkSchemaName(r, w); err != nil {
+ return err
+ }
+
+ if err := c.checkRecordFields(r, w); err != nil {
+ return err
+ }
+
+ case Union:
+ for _, schema := range writer.(*UnionSchema).Types() {
+ if err := c.compatible(reader, schema); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func (c *SchemaCompatibility) checkSchemaName(reader, writer NamedSchema) error {
+ if reader.Name() != writer.Name() {
+ if c.contains(reader.Aliases(), writer.FullName()) {
+ return nil
+ }
+ return fmt.Errorf("reader schema %s and writer schema %s names do not match", reader.FullName(), writer.FullName())
+ }
+
+ return nil
+}
+
+func (c *SchemaCompatibility) checkFixedSize(reader, writer *FixedSchema) error {
+ if reader.Size() != writer.Size() {
+ return fmt.Errorf("%s reader and writer fixed sizes do not match", reader.FullName())
+ }
+
+ return nil
+}
+
+func (c *SchemaCompatibility) checkEnumSymbols(reader, writer *EnumSchema) error {
+ for _, symbol := range writer.Symbols() {
+ if !c.contains(reader.Symbols(), symbol) {
+ return fmt.Errorf("reader %s is missing symbol %s", reader.FullName(), symbol)
+ }
+ }
+
+ return nil
+}
+
+func (c *SchemaCompatibility) checkRecordFields(reader, writer *RecordSchema) error {
+ for _, field := range reader.Fields() {
+ f, ok := c.getField(writer.Fields(), field, func(gfo *getFieldOptions) {
+ gfo.fieldAlias = true
+ })
+ if !ok {
+ if field.HasDefault() {
+ continue
+ }
+
+ return fmt.Errorf("reader field %s is missing in writer schema and has no default", field.Name())
+ }
+
+ if err := c.compatible(field.Type(), f.Type()); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (c *SchemaCompatibility) contains(a []string, s string) bool {
+ for _, str := range a {
+ if str == s {
+ return true
+ }
+ }
+
+ return false
+}
+
+type getFieldOptions struct {
+ fieldAlias bool
+ elemAlias bool
+}
+
+func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*getFieldOptions)) (*Field, bool) {
+ opt := getFieldOptions{}
+ for _, fn := range optFns {
+ fn(&opt)
+ }
+ for _, field := range a {
+ if field.Name() == f.Name() {
+ return field, true
+ }
+ if opt.fieldAlias {
+ if c.contains(f.Aliases(), field.Name()) {
+ return field, true
+ }
+ }
+ if opt.elemAlias {
+ if c.contains(field.Aliases(), f.Name()) {
+ return field, true
+ }
+ }
+ }
+
+ return nil, false
+}
+
+// Resolve returns a composite schema that allows decoding data written by the writer schema,
+// and makes necessary adjustments to support the reader schema.
+//
+// It fails if the writer and reader schemas are not compatible.
+func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) {
+ if err := c.compatible(reader, writer); err != nil {
+ return nil, err
+ }
+
+ schema, _, err := c.resolve(reader, writer)
+ return schema, err
+}
+
+// resolve requires the reader's schema to be already compatible with the writer's.
+func (c *SchemaCompatibility) resolve(reader, writer Schema) (schema Schema, resolved bool, err error) {
+ if reader.Type() == Ref {
+ reader = reader.(*RefSchema).Schema()
+ }
+ if writer.Type() == Ref {
+ writer = writer.(*RefSchema).Schema()
+ }
+
+ if writer.Type() != reader.Type() {
+ if reader.Type() == Union {
+ for _, schema := range reader.(*UnionSchema).Types() {
+ // Compatibility is not guaranteed for every Union reader schema.
+ // Therefore, we need to check compatibility in every iteration.
+ if err := c.compatible(schema, writer); err != nil {
+ continue
+ }
+ sch, _, err := c.resolve(schema, writer)
+ if err != nil {
+ continue
+ }
+ return sch, true, nil
+ }
+
+ return nil, false, fmt.Errorf("reader union lacking writer schema %s", writer.Type())
+ }
+
+ if writer.Type() == Union {
+ schemas := make([]Schema, 0)
+ for _, schema := range writer.(*UnionSchema).Types() {
+ sch, _, err := c.resolve(reader, schema)
+ if err != nil {
+ return nil, false, err
+ }
+ schemas = append(schemas, sch)
+ }
+ s, err := NewUnionSchema(schemas, withWriterFingerprint(writer.Fingerprint()))
+ return s, true, err
+ }
+
+ if isPromotable(writer.Type(), reader.Type()) {
+ r := NewPrimitiveSchema(reader.Type(), reader.(*PrimitiveSchema).Logical(),
+ withWriterFingerprint(writer.Fingerprint()),
+ )
+ r.encodedType = writer.Type()
+ return r, true, nil
+ }
+
+ return nil, false, fmt.Errorf("failed to resolve composite schema for %s and %s", reader.Type(), writer.Type())
+ }
+
+ if isNative(writer.Type()) {
+ return reader, false, nil
+ }
+
+ if writer.Type() == Enum {
+ r := reader.(*EnumSchema)
+ w := writer.(*EnumSchema)
+ if err = c.checkEnumSymbols(r, w); err != nil {
+ if r.HasDefault() {
+ enum, _ := NewEnumSchema(r.Name(), r.Namespace(), r.Symbols(),
+ WithAliases(r.Aliases()),
+ WithDefault(r.Default()),
+ withWriterFingerprint(w.Fingerprint()),
+ )
+ enum.encodedSymbols = w.Symbols()
+ return enum, true, nil
+ }
+
+ return nil, false, err
+ }
+ return reader, false, nil
+ }
+
+ if writer.Type() == Fixed {
+ return reader, false, nil
+ }
+
+ if writer.Type() == Union {
+ schemas := make([]Schema, 0)
+ for _, s := range writer.(*UnionSchema).Types() {
+ sch, resolv, err := c.resolve(reader, s)
+ if err != nil {
+ return nil, false, err
+ }
+ schemas = append(schemas, sch)
+ resolved = resolv || resolved
+ }
+ s, err := NewUnionSchema(schemas, withWriterFingerprintIfResolved(writer.Fingerprint(), resolved))
+ if err != nil {
+ return nil, false, err
+ }
+ return s, resolved, nil
+ }
+
+ if writer.Type() == Array {
+ schema, resolved, err = c.resolve(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items())
+ if err != nil {
+ return nil, false, err
+ }
+ return NewArraySchema(schema, withWriterFingerprintIfResolved(writer.Fingerprint(), resolved)), resolved, nil
+ }
+
+ if writer.Type() == Map {
+ schema, resolved, err = c.resolve(reader.(*MapSchema).Values(), writer.(*MapSchema).Values())
+ if err != nil {
+ return nil, false, err
+ }
+ return NewMapSchema(schema, withWriterFingerprintIfResolved(writer.Fingerprint(), resolved)), resolved, nil
+ }
+
+ if writer.Type() == Record {
+ return c.resolveRecord(reader, writer)
+ }
+
+ return nil, false, fmt.Errorf("failed to resolve composite schema for %s and %s", reader.Type(), writer.Type())
+}
+
+func (c *SchemaCompatibility) resolveRecord(reader, writer Schema) (Schema, bool, error) {
+ w := writer.(*RecordSchema)
+ r := reader.(*RecordSchema)
+
+ fields := make([]*Field, 0)
+ seen := make(map[string]struct{})
+
+ var resolved bool
+ for _, wf := range w.Fields() {
+ rf, ok := c.getField(r.Fields(), wf, func(gfo *getFieldOptions) {
+ gfo.elemAlias = true
+ })
+ if !ok {
+ // The field was not found in the reader schema, it should be ignored.
+ f, _ := NewField(wf.Name(), wf.Type(), WithAliases(wf.aliases), WithOrder(wf.order))
+ f.def = wf.def
+ f.hasDef = wf.hasDef
+ f.action = FieldIgnore
+ fields = append(fields, f)
+
+ resolved = true
+ continue
+ }
+
+ ft, resolv, err := c.resolve(rf.Type(), wf.Type())
+ if err != nil {
+ return nil, false, err
+ }
+ f, _ := NewField(rf.Name(), ft, WithAliases(rf.aliases), WithOrder(rf.order))
+ f.def = rf.def
+ f.hasDef = rf.hasDef
+ fields = append(fields, f)
+ resolved = resolv || resolved
+
+ seen[rf.Name()] = struct{}{}
+ }
+
+ for _, rf := range r.Fields() {
+ if _, ok := seen[rf.Name()]; ok {
+ // This field has already been seen.
+ continue
+ }
+
+ // The schemas are already known to be compatible, so there must be a default on
+ // the field in the writer. Use the default.
+
+ f, _ := NewField(rf.Name(), rf.Type(), WithAliases(rf.aliases), WithOrder(rf.order))
+ f.def = rf.def
+ f.hasDef = rf.hasDef
+ f.action = FieldSetDefault
+ fields = append(fields, f)
+
+ resolved = true
+ }
+
+ schema, err := NewRecordSchema(r.Name(), r.Namespace(), fields,
+ WithAliases(r.Aliases()),
+ withWriterFingerprintIfResolved(writer.Fingerprint(), resolved),
+ )
+ return schema, resolved, err
+}
+
+func isNative(typ Type) bool {
+ switch typ {
+ case Null, Boolean, Int, Long, Float, Double, Bytes, String:
+ return true
+ default:
+ return false
+ }
+}
+
+func isPromotable(writerTyp, readerType Type) bool {
+ switch writerTyp {
+ case Int:
+ return readerType == Long || readerType == Float || readerType == Double
+ case Long:
+ return readerType == Float || readerType == Double
+ case Float:
+ return readerType == Double
+ case String:
+ return readerType == Bytes
+ case Bytes:
+ return readerType == String
+ default:
+ return false
+ }
+}