summaryrefslogtreecommitdiff
path: root/vendor/github.com/hamba/avro/v2/codec_union.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/hamba/avro/v2/codec_union.go')
-rw-r--r--vendor/github.com/hamba/avro/v2/codec_union.go460
1 files changed, 460 insertions, 0 deletions
diff --git a/vendor/github.com/hamba/avro/v2/codec_union.go b/vendor/github.com/hamba/avro/v2/codec_union.go
new file mode 100644
index 0000000..7d80b53
--- /dev/null
+++ b/vendor/github.com/hamba/avro/v2/codec_union.go
@@ -0,0 +1,460 @@
+package avro
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "strings"
+ "unsafe"
+
+ "github.com/modern-go/reflect2"
+)
+
+func createDecoderOfUnion(d *decoderContext, schema *UnionSchema, typ reflect2.Type) ValDecoder {
+ switch typ.Kind() {
+ case reflect.Map:
+ if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
+ typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
+ break
+ }
+ return decoderOfMapUnion(d, schema, typ)
+ case reflect.Slice:
+ if !schema.Nullable() {
+ break
+ }
+ return decoderOfNullableUnion(d, schema, typ)
+ case reflect.Ptr:
+ if !schema.Nullable() {
+ break
+ }
+ return decoderOfNullableUnion(d, schema, typ)
+ case reflect.Interface:
+ if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok {
+ dec, err := decoderOfResolvedUnion(d, schema)
+ if err != nil {
+ return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", schema.Type(), err)}
+ }
+ return dec
+ }
+ }
+
+ return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
+}
+
+func createEncoderOfUnion(e *encoderContext, schema *UnionSchema, typ reflect2.Type) ValEncoder {
+ switch typ.Kind() {
+ case reflect.Map:
+ if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
+ typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
+ break
+ }
+ return encoderOfMapUnion(e, schema, typ)
+ case reflect.Slice:
+ if !schema.Nullable() {
+ break
+ }
+ return encoderOfNullableUnion(e, schema, typ)
+ case reflect.Ptr:
+ if !schema.Nullable() {
+ break
+ }
+ return encoderOfNullableUnion(e, schema, typ)
+ }
+ return encoderOfResolverUnion(e, schema, typ)
+}
+
+func decoderOfMapUnion(d *decoderContext, union *UnionSchema, typ reflect2.Type) ValDecoder {
+ mapType := typ.(*reflect2.UnsafeMapType)
+
+ typeDecs := make([]ValDecoder, len(union.Types()))
+ for i, s := range union.Types() {
+ if s.Type() == Null {
+ continue
+ }
+ typeDecs[i] = newEfaceDecoder(d, s)
+ }
+
+ return &mapUnionDecoder{
+ cfg: d.cfg,
+ schema: union,
+ mapType: mapType,
+ elemType: mapType.Elem(),
+ typeDecs: typeDecs,
+ }
+}
+
+type mapUnionDecoder struct {
+ cfg *frozenConfig
+ schema *UnionSchema
+ mapType *reflect2.UnsafeMapType
+ elemType reflect2.Type
+ typeDecs []ValDecoder
+}
+
+func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
+ idx, resSchema := getUnionSchema(d.schema, r)
+ if resSchema == nil {
+ return
+ }
+
+ // In a null case, just return
+ if resSchema.Type() == Null {
+ return
+ }
+
+ if d.mapType.UnsafeIsNil(ptr) {
+ d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(1))
+ }
+
+ key := schemaTypeName(resSchema)
+ keyPtr := reflect2.PtrOf(key)
+
+ elemPtr := d.elemType.UnsafeNew()
+ d.typeDecs[idx].Decode(elemPtr, r)
+
+ d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
+}
+
+func encoderOfMapUnion(e *encoderContext, union *UnionSchema, _ reflect2.Type) ValEncoder {
+ return &mapUnionEncoder{
+ cfg: e.cfg,
+ schema: union,
+ }
+}
+
+type mapUnionEncoder struct {
+ cfg *frozenConfig
+ schema *UnionSchema
+}
+
+func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
+ m := *((*map[string]any)(ptr))
+
+ if len(m) > 1 {
+ w.Error = errors.New("avro: cannot encode union map with multiple entries")
+ return
+ }
+
+ name := "null"
+ val := any(nil)
+ for k, v := range m {
+ name = k
+ val = v
+ break
+ }
+
+ schema, pos := e.schema.Types().Get(name)
+ if schema == nil {
+ w.Error = fmt.Errorf("avro: unknown union type %s", name)
+ return
+ }
+
+ w.WriteInt(int32(pos))
+
+ if schema.Type() == Null && val == nil {
+ return
+ }
+
+ elemType := reflect2.TypeOf(val)
+ elemPtr := reflect2.PtrOf(val)
+
+ encoder := encoderOfType(newEncoderContext(e.cfg), schema, elemType)
+ if elemType.LikePtr() {
+ encoder = &onePtrEncoder{encoder}
+ }
+ encoder.Encode(elemPtr, w)
+}
+
+func decoderOfNullableUnion(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder {
+ union := schema.(*UnionSchema)
+ _, typeIdx := union.Indices()
+
+ var (
+ baseTyp reflect2.Type
+ isPtr bool
+ )
+ switch v := typ.(type) {
+ case *reflect2.UnsafePtrType:
+ baseTyp = v.Elem()
+ isPtr = true
+ case *reflect2.UnsafeSliceType:
+ baseTyp = v
+ }
+ decoder := decoderOfType(d, union.Types()[typeIdx], baseTyp)
+
+ return &unionNullableDecoder{
+ schema: union,
+ typ: baseTyp,
+ isPtr: isPtr,
+ decoder: decoder,
+ }
+}
+
+type unionNullableDecoder struct {
+ schema *UnionSchema
+ typ reflect2.Type
+ isPtr bool
+ decoder ValDecoder
+}
+
+func (d *unionNullableDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
+ _, schema := getUnionSchema(d.schema, r)
+ if schema == nil {
+ return
+ }
+
+ if schema.Type() == Null {
+ *((*unsafe.Pointer)(ptr)) = nil
+ return
+ }
+
+ // Handle the non-ptr case separately.
+ if !d.isPtr {
+ if d.typ.UnsafeIsNil(ptr) {
+ // Create a new instance.
+ newPtr := d.typ.UnsafeNew()
+ d.decoder.Decode(newPtr, r)
+ d.typ.UnsafeSet(ptr, newPtr)
+ return
+ }
+
+ // Reuse the existing instance.
+ d.decoder.Decode(ptr, r)
+ return
+ }
+
+ if *((*unsafe.Pointer)(ptr)) == nil {
+ // Create new instance.
+ newPtr := d.typ.UnsafeNew()
+ d.decoder.Decode(newPtr, r)
+ *((*unsafe.Pointer)(ptr)) = newPtr
+ return
+ }
+
+ // Reuse existing instance.
+ d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r)
+}
+
+func encoderOfNullableUnion(e *encoderContext, schema Schema, typ reflect2.Type) ValEncoder {
+ union := schema.(*UnionSchema)
+ nullIdx, typeIdx := union.Indices()
+
+ var (
+ baseTyp reflect2.Type
+ isPtr bool
+ )
+ switch v := typ.(type) {
+ case *reflect2.UnsafePtrType:
+ baseTyp = v.Elem()
+ isPtr = true
+ case *reflect2.UnsafeSliceType:
+ baseTyp = v
+ }
+ encoder := encoderOfType(e, union.Types()[typeIdx], baseTyp)
+
+ return &unionNullableEncoder{
+ schema: union,
+ encoder: encoder,
+ isPtr: isPtr,
+ nullIdx: int32(nullIdx),
+ typeIdx: int32(typeIdx),
+ }
+}
+
+type unionNullableEncoder struct {
+ schema *UnionSchema
+ encoder ValEncoder
+ isPtr bool
+ nullIdx int32
+ typeIdx int32
+}
+
+func (e *unionNullableEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
+ if *((*unsafe.Pointer)(ptr)) == nil {
+ w.WriteInt(e.nullIdx)
+ return
+ }
+
+ w.WriteInt(e.typeIdx)
+ newPtr := ptr
+ if e.isPtr {
+ newPtr = *((*unsafe.Pointer)(ptr))
+ }
+ e.encoder.Encode(newPtr, w)
+}
+
+func decoderOfResolvedUnion(d *decoderContext, schema Schema) (ValDecoder, error) {
+ union := schema.(*UnionSchema)
+
+ types := make([]reflect2.Type, len(union.Types()))
+ decoders := make([]ValDecoder, len(union.Types()))
+ for i, schema := range union.Types() {
+ name := unionResolutionName(schema)
+
+ typ, err := d.cfg.resolver.Type(name)
+ if err != nil {
+ if d.cfg.config.UnionResolutionError {
+ return nil, err
+ }
+
+ if d.cfg.config.PartialUnionTypeResolution {
+ decoders[i] = nil
+ types[i] = nil
+ continue
+ }
+
+ decoders = []ValDecoder{}
+ types = []reflect2.Type{}
+ break
+ }
+
+ decoder := decoderOfType(d, schema, typ)
+ decoders[i] = decoder
+ types[i] = typ
+ }
+
+ return &unionResolvedDecoder{
+ cfg: d.cfg,
+ schema: union,
+ types: types,
+ decoders: decoders,
+ }, nil
+}
+
+type unionResolvedDecoder struct {
+ cfg *frozenConfig
+ schema *UnionSchema
+ types []reflect2.Type
+ decoders []ValDecoder
+}
+
+func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
+ i, schema := getUnionSchema(d.schema, r)
+ if schema == nil {
+ return
+ }
+
+ pObj := (*any)(ptr)
+
+ if schema.Type() == Null {
+ *pObj = nil
+ return
+ }
+
+ if i >= len(d.decoders) || d.decoders[i] == nil {
+ if d.cfg.config.UnionResolutionError {
+ r.ReportError("decode union type", "unknown union type")
+ return
+ }
+
+ // We cannot resolve this, set it to the map type
+ name := schemaTypeName(schema)
+ obj := map[string]any{}
+ vTyp, err := genericReceiver(schema)
+ if err != nil {
+ r.ReportError("Union", err.Error())
+ return
+ }
+ obj[name] = genericDecode(vTyp, decoderOfType(newDecoderContext(d.cfg), schema, vTyp), r)
+
+ *pObj = obj
+ return
+ }
+
+ typ := d.types[i]
+ var newPtr unsafe.Pointer
+ switch typ.Kind() {
+ case reflect.Map:
+ mapType := typ.(*reflect2.UnsafeMapType)
+ newPtr = mapType.UnsafeMakeMap(1)
+
+ case reflect.Slice:
+ mapType := typ.(*reflect2.UnsafeSliceType)
+ newPtr = mapType.UnsafeMakeSlice(1, 1)
+
+ case reflect.Ptr:
+ elemType := typ.(*reflect2.UnsafePtrType).Elem()
+ newPtr = elemType.UnsafeNew()
+
+ default:
+ newPtr = typ.UnsafeNew()
+ }
+
+ d.decoders[i].Decode(newPtr, r)
+ *pObj = typ.UnsafeIndirect(newPtr)
+}
+
+func unionResolutionName(schema Schema) string {
+ name := schemaTypeName(schema)
+ switch schema.Type() {
+ case Map:
+ name += ":"
+ valSchema := schema.(*MapSchema).Values()
+ valName := schemaTypeName(valSchema)
+
+ name += valName
+
+ case Array:
+ name += ":"
+ itemSchema := schema.(*ArraySchema).Items()
+ itemName := schemaTypeName(itemSchema)
+
+ name += itemName
+ }
+
+ return name
+}
+
+func encoderOfResolverUnion(e *encoderContext, schema Schema, typ reflect2.Type) ValEncoder {
+ union := schema.(*UnionSchema)
+
+ names, err := e.cfg.resolver.Name(typ)
+ if err != nil {
+ return &errorEncoder{err: err}
+ }
+
+ var pos int
+ for _, name := range names {
+ if idx := strings.Index(name, ":"); idx > 0 {
+ name = name[:idx]
+ }
+
+ schema, pos = union.Types().Get(name)
+ if schema != nil {
+ break
+ }
+ }
+ if schema == nil {
+ return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])}
+ }
+
+ encoder := encoderOfType(e, schema, typ)
+
+ return &unionResolverEncoder{
+ pos: pos,
+ encoder: encoder,
+ }
+}
+
+type unionResolverEncoder struct {
+ pos int
+ encoder ValEncoder
+}
+
+func (e *unionResolverEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
+ w.WriteInt(int32(e.pos))
+
+ e.encoder.Encode(ptr, w)
+}
+
+func getUnionSchema(schema *UnionSchema, r *Reader) (int, Schema) {
+ types := schema.Types()
+
+ idx := int(r.ReadInt())
+ if idx < 0 || idx > len(types)-1 {
+ r.ReportError("decode union type", "unknown union type")
+ return 0, nil
+ }
+
+ return idx, types[idx]
+}