diff options
Diffstat (limited to 'vendor/github.com/hamba/avro/v2/schema_compatibility.go')
| -rw-r--r-- | vendor/github.com/hamba/avro/v2/schema_compatibility.go | 487 |
1 files changed, 487 insertions, 0 deletions
diff --git a/vendor/github.com/hamba/avro/v2/schema_compatibility.go b/vendor/github.com/hamba/avro/v2/schema_compatibility.go new file mode 100644 index 0000000..0b1d9ac --- /dev/null +++ b/vendor/github.com/hamba/avro/v2/schema_compatibility.go @@ -0,0 +1,487 @@ +package avro + +import ( + "errors" + "fmt" + "sync" +) + +type recursionError struct{} + +func (e recursionError) Error() string { + return "" +} + +type compatKey struct { + reader [32]byte + writer [32]byte +} + +// SchemaCompatibility determines the compatibility of schemas. +type SchemaCompatibility struct { + cache sync.Map // map[compatKey]error +} + +// NewSchemaCompatibility creates a new schema compatibility instance. +func NewSchemaCompatibility() *SchemaCompatibility { + return &SchemaCompatibility{} +} + +// Compatible determines the compatibility if the reader and writer schemas. +func (c *SchemaCompatibility) Compatible(reader, writer Schema) error { + return c.compatible(reader, writer) +} + +func (c *SchemaCompatibility) compatible(reader, writer Schema) error { + key := compatKey{reader: reader.Fingerprint(), writer: writer.Fingerprint()} + if err, ok := c.cache.Load(key); ok { + if _, ok := err.(recursionError); ok { + // Break the recursion here. + return nil + } + + if err == nil { + return nil + } + + return err.(error) + } + + c.cache.Store(key, recursionError{}) + err := c.match(reader, writer) + if err != nil { + // We dont want to pay the cost of fmt.Errorf every time + err = errors.New(err.Error()) + } + c.cache.Store(key, err) + return err +} + +func (c *SchemaCompatibility) match(reader, writer Schema) error { + // If the schema is a reference, get the actual schema + if reader.Type() == Ref { + reader = reader.(*RefSchema).Schema() + } + if writer.Type() == Ref { + writer = writer.(*RefSchema).Schema() + } + + if reader.Type() != writer.Type() { + if writer.Type() == Union { + // Reader must be compatible with all types in writer + for _, schema := range writer.(*UnionSchema).Types() { + if err := c.compatible(reader, schema); err != nil { + return err + } + } + + return nil + } + + if reader.Type() == Union { + // Writer must be compatible with at least one reader schema + var err error + for _, schema := range reader.(*UnionSchema).Types() { + err = c.compatible(schema, writer) + if err == nil { + return nil + } + } + + return fmt.Errorf("reader union lacking writer schema %s", writer.Type()) + } + + switch writer.Type() { + case Int: + if reader.Type() == Long || reader.Type() == Float || reader.Type() == Double { + return nil + } + + case Long: + if reader.Type() == Float || reader.Type() == Double { + return nil + } + + case Float: + if reader.Type() == Double { + return nil + } + + case String: + if reader.Type() == Bytes { + return nil + } + + case Bytes: + if reader.Type() == String { + return nil + } + } + + return fmt.Errorf("reader schema %s not compatible with writer schema %s", reader.Type(), writer.Type()) + } + + switch reader.Type() { + case Array: + return c.compatible(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items()) + + case Map: + return c.compatible(reader.(*MapSchema).Values(), writer.(*MapSchema).Values()) + + case Fixed: + r := reader.(*FixedSchema) + w := writer.(*FixedSchema) + + if err := c.checkSchemaName(r, w); err != nil { + return err + } + + if err := c.checkFixedSize(r, w); err != nil { + return err + } + + case Enum: + r := reader.(*EnumSchema) + w := writer.(*EnumSchema) + + if err := c.checkSchemaName(r, w); err != nil { + return err + } + + if err := c.checkEnumSymbols(r, w); err != nil { + if r.HasDefault() { + return nil + } + return err + } + + case Record: + r := reader.(*RecordSchema) + w := writer.(*RecordSchema) + + if err := c.checkSchemaName(r, w); err != nil { + return err + } + + if err := c.checkRecordFields(r, w); err != nil { + return err + } + + case Union: + for _, schema := range writer.(*UnionSchema).Types() { + if err := c.compatible(reader, schema); err != nil { + return err + } + } + } + + return nil +} + +func (c *SchemaCompatibility) checkSchemaName(reader, writer NamedSchema) error { + if reader.Name() != writer.Name() { + if c.contains(reader.Aliases(), writer.FullName()) { + return nil + } + return fmt.Errorf("reader schema %s and writer schema %s names do not match", reader.FullName(), writer.FullName()) + } + + return nil +} + +func (c *SchemaCompatibility) checkFixedSize(reader, writer *FixedSchema) error { + if reader.Size() != writer.Size() { + return fmt.Errorf("%s reader and writer fixed sizes do not match", reader.FullName()) + } + + return nil +} + +func (c *SchemaCompatibility) checkEnumSymbols(reader, writer *EnumSchema) error { + for _, symbol := range writer.Symbols() { + if !c.contains(reader.Symbols(), symbol) { + return fmt.Errorf("reader %s is missing symbol %s", reader.FullName(), symbol) + } + } + + return nil +} + +func (c *SchemaCompatibility) checkRecordFields(reader, writer *RecordSchema) error { + for _, field := range reader.Fields() { + f, ok := c.getField(writer.Fields(), field, func(gfo *getFieldOptions) { + gfo.fieldAlias = true + }) + if !ok { + if field.HasDefault() { + continue + } + + return fmt.Errorf("reader field %s is missing in writer schema and has no default", field.Name()) + } + + if err := c.compatible(field.Type(), f.Type()); err != nil { + return err + } + } + + return nil +} + +func (c *SchemaCompatibility) contains(a []string, s string) bool { + for _, str := range a { + if str == s { + return true + } + } + + return false +} + +type getFieldOptions struct { + fieldAlias bool + elemAlias bool +} + +func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*getFieldOptions)) (*Field, bool) { + opt := getFieldOptions{} + for _, fn := range optFns { + fn(&opt) + } + for _, field := range a { + if field.Name() == f.Name() { + return field, true + } + if opt.fieldAlias { + if c.contains(f.Aliases(), field.Name()) { + return field, true + } + } + if opt.elemAlias { + if c.contains(field.Aliases(), f.Name()) { + return field, true + } + } + } + + return nil, false +} + +// Resolve returns a composite schema that allows decoding data written by the writer schema, +// and makes necessary adjustments to support the reader schema. +// +// It fails if the writer and reader schemas are not compatible. +func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) { + if err := c.compatible(reader, writer); err != nil { + return nil, err + } + + schema, _, err := c.resolve(reader, writer) + return schema, err +} + +// resolve requires the reader's schema to be already compatible with the writer's. +func (c *SchemaCompatibility) resolve(reader, writer Schema) (schema Schema, resolved bool, err error) { + if reader.Type() == Ref { + reader = reader.(*RefSchema).Schema() + } + if writer.Type() == Ref { + writer = writer.(*RefSchema).Schema() + } + + if writer.Type() != reader.Type() { + if reader.Type() == Union { + for _, schema := range reader.(*UnionSchema).Types() { + // Compatibility is not guaranteed for every Union reader schema. + // Therefore, we need to check compatibility in every iteration. + if err := c.compatible(schema, writer); err != nil { + continue + } + sch, _, err := c.resolve(schema, writer) + if err != nil { + continue + } + return sch, true, nil + } + + return nil, false, fmt.Errorf("reader union lacking writer schema %s", writer.Type()) + } + + if writer.Type() == Union { + schemas := make([]Schema, 0) + for _, schema := range writer.(*UnionSchema).Types() { + sch, _, err := c.resolve(reader, schema) + if err != nil { + return nil, false, err + } + schemas = append(schemas, sch) + } + s, err := NewUnionSchema(schemas, withWriterFingerprint(writer.Fingerprint())) + return s, true, err + } + + if isPromotable(writer.Type(), reader.Type()) { + r := NewPrimitiveSchema(reader.Type(), reader.(*PrimitiveSchema).Logical(), + withWriterFingerprint(writer.Fingerprint()), + ) + r.encodedType = writer.Type() + return r, true, nil + } + + return nil, false, fmt.Errorf("failed to resolve composite schema for %s and %s", reader.Type(), writer.Type()) + } + + if isNative(writer.Type()) { + return reader, false, nil + } + + if writer.Type() == Enum { + r := reader.(*EnumSchema) + w := writer.(*EnumSchema) + if err = c.checkEnumSymbols(r, w); err != nil { + if r.HasDefault() { + enum, _ := NewEnumSchema(r.Name(), r.Namespace(), r.Symbols(), + WithAliases(r.Aliases()), + WithDefault(r.Default()), + withWriterFingerprint(w.Fingerprint()), + ) + enum.encodedSymbols = w.Symbols() + return enum, true, nil + } + + return nil, false, err + } + return reader, false, nil + } + + if writer.Type() == Fixed { + return reader, false, nil + } + + if writer.Type() == Union { + schemas := make([]Schema, 0) + for _, s := range writer.(*UnionSchema).Types() { + sch, resolv, err := c.resolve(reader, s) + if err != nil { + return nil, false, err + } + schemas = append(schemas, sch) + resolved = resolv || resolved + } + s, err := NewUnionSchema(schemas, withWriterFingerprintIfResolved(writer.Fingerprint(), resolved)) + if err != nil { + return nil, false, err + } + return s, resolved, nil + } + + if writer.Type() == Array { + schema, resolved, err = c.resolve(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items()) + if err != nil { + return nil, false, err + } + return NewArraySchema(schema, withWriterFingerprintIfResolved(writer.Fingerprint(), resolved)), resolved, nil + } + + if writer.Type() == Map { + schema, resolved, err = c.resolve(reader.(*MapSchema).Values(), writer.(*MapSchema).Values()) + if err != nil { + return nil, false, err + } + return NewMapSchema(schema, withWriterFingerprintIfResolved(writer.Fingerprint(), resolved)), resolved, nil + } + + if writer.Type() == Record { + return c.resolveRecord(reader, writer) + } + + return nil, false, fmt.Errorf("failed to resolve composite schema for %s and %s", reader.Type(), writer.Type()) +} + +func (c *SchemaCompatibility) resolveRecord(reader, writer Schema) (Schema, bool, error) { + w := writer.(*RecordSchema) + r := reader.(*RecordSchema) + + fields := make([]*Field, 0) + seen := make(map[string]struct{}) + + var resolved bool + for _, wf := range w.Fields() { + rf, ok := c.getField(r.Fields(), wf, func(gfo *getFieldOptions) { + gfo.elemAlias = true + }) + if !ok { + // The field was not found in the reader schema, it should be ignored. + f, _ := NewField(wf.Name(), wf.Type(), WithAliases(wf.aliases), WithOrder(wf.order)) + f.def = wf.def + f.hasDef = wf.hasDef + f.action = FieldIgnore + fields = append(fields, f) + + resolved = true + continue + } + + ft, resolv, err := c.resolve(rf.Type(), wf.Type()) + if err != nil { + return nil, false, err + } + f, _ := NewField(rf.Name(), ft, WithAliases(rf.aliases), WithOrder(rf.order)) + f.def = rf.def + f.hasDef = rf.hasDef + fields = append(fields, f) + resolved = resolv || resolved + + seen[rf.Name()] = struct{}{} + } + + for _, rf := range r.Fields() { + if _, ok := seen[rf.Name()]; ok { + // This field has already been seen. + continue + } + + // The schemas are already known to be compatible, so there must be a default on + // the field in the writer. Use the default. + + f, _ := NewField(rf.Name(), rf.Type(), WithAliases(rf.aliases), WithOrder(rf.order)) + f.def = rf.def + f.hasDef = rf.hasDef + f.action = FieldSetDefault + fields = append(fields, f) + + resolved = true + } + + schema, err := NewRecordSchema(r.Name(), r.Namespace(), fields, + WithAliases(r.Aliases()), + withWriterFingerprintIfResolved(writer.Fingerprint(), resolved), + ) + return schema, resolved, err +} + +func isNative(typ Type) bool { + switch typ { + case Null, Boolean, Int, Long, Float, Double, Bytes, String: + return true + default: + return false + } +} + +func isPromotable(writerTyp, readerType Type) bool { + switch writerTyp { + case Int: + return readerType == Long || readerType == Float || readerType == Double + case Long: + return readerType == Float || readerType == Double + case Float: + return readerType == Double + case String: + return readerType == Bytes + case Bytes: + return readerType == String + default: + return false + } +} |
