diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/hamba/avro/v2/codec_union.go | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/hamba/avro/v2/codec_union.go')
| -rw-r--r-- | vendor/github.com/hamba/avro/v2/codec_union.go | 460 |
1 files changed, 460 insertions, 0 deletions
diff --git a/vendor/github.com/hamba/avro/v2/codec_union.go b/vendor/github.com/hamba/avro/v2/codec_union.go new file mode 100644 index 0000000..7d80b53 --- /dev/null +++ b/vendor/github.com/hamba/avro/v2/codec_union.go @@ -0,0 +1,460 @@ +package avro + +import ( + "errors" + "fmt" + "reflect" + "strings" + "unsafe" + + "github.com/modern-go/reflect2" +) + +func createDecoderOfUnion(d *decoderContext, schema *UnionSchema, typ reflect2.Type) ValDecoder { + switch typ.Kind() { + case reflect.Map: + if typ.(reflect2.MapType).Key().Kind() != reflect.String || + typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { + break + } + return decoderOfMapUnion(d, schema, typ) + case reflect.Slice: + if !schema.Nullable() { + break + } + return decoderOfNullableUnion(d, schema, typ) + case reflect.Ptr: + if !schema.Nullable() { + break + } + return decoderOfNullableUnion(d, schema, typ) + case reflect.Interface: + if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok { + dec, err := decoderOfResolvedUnion(d, schema) + if err != nil { + return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", schema.Type(), err)} + } + return dec + } + } + + return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} +} + +func createEncoderOfUnion(e *encoderContext, schema *UnionSchema, typ reflect2.Type) ValEncoder { + switch typ.Kind() { + case reflect.Map: + if typ.(reflect2.MapType).Key().Kind() != reflect.String || + typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { + break + } + return encoderOfMapUnion(e, schema, typ) + case reflect.Slice: + if !schema.Nullable() { + break + } + return encoderOfNullableUnion(e, schema, typ) + case reflect.Ptr: + if !schema.Nullable() { + break + } + return encoderOfNullableUnion(e, schema, typ) + } + return encoderOfResolverUnion(e, schema, typ) +} + +func decoderOfMapUnion(d *decoderContext, union *UnionSchema, typ reflect2.Type) ValDecoder { + mapType := typ.(*reflect2.UnsafeMapType) + + typeDecs := make([]ValDecoder, len(union.Types())) + for i, s := range union.Types() { + if s.Type() == Null { + continue + } + typeDecs[i] = newEfaceDecoder(d, s) + } + + return &mapUnionDecoder{ + cfg: d.cfg, + schema: union, + mapType: mapType, + elemType: mapType.Elem(), + typeDecs: typeDecs, + } +} + +type mapUnionDecoder struct { + cfg *frozenConfig + schema *UnionSchema + mapType *reflect2.UnsafeMapType + elemType reflect2.Type + typeDecs []ValDecoder +} + +func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + idx, resSchema := getUnionSchema(d.schema, r) + if resSchema == nil { + return + } + + // In a null case, just return + if resSchema.Type() == Null { + return + } + + if d.mapType.UnsafeIsNil(ptr) { + d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(1)) + } + + key := schemaTypeName(resSchema) + keyPtr := reflect2.PtrOf(key) + + elemPtr := d.elemType.UnsafeNew() + d.typeDecs[idx].Decode(elemPtr, r) + + d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) +} + +func encoderOfMapUnion(e *encoderContext, union *UnionSchema, _ reflect2.Type) ValEncoder { + return &mapUnionEncoder{ + cfg: e.cfg, + schema: union, + } +} + +type mapUnionEncoder struct { + cfg *frozenConfig + schema *UnionSchema +} + +func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) { + m := *((*map[string]any)(ptr)) + + if len(m) > 1 { + w.Error = errors.New("avro: cannot encode union map with multiple entries") + return + } + + name := "null" + val := any(nil) + for k, v := range m { + name = k + val = v + break + } + + schema, pos := e.schema.Types().Get(name) + if schema == nil { + w.Error = fmt.Errorf("avro: unknown union type %s", name) + return + } + + w.WriteInt(int32(pos)) + + if schema.Type() == Null && val == nil { + return + } + + elemType := reflect2.TypeOf(val) + elemPtr := reflect2.PtrOf(val) + + encoder := encoderOfType(newEncoderContext(e.cfg), schema, elemType) + if elemType.LikePtr() { + encoder = &onePtrEncoder{encoder} + } + encoder.Encode(elemPtr, w) +} + +func decoderOfNullableUnion(d *decoderContext, schema Schema, typ reflect2.Type) ValDecoder { + union := schema.(*UnionSchema) + _, typeIdx := union.Indices() + + var ( + baseTyp reflect2.Type + isPtr bool + ) + switch v := typ.(type) { + case *reflect2.UnsafePtrType: + baseTyp = v.Elem() + isPtr = true + case *reflect2.UnsafeSliceType: + baseTyp = v + } + decoder := decoderOfType(d, union.Types()[typeIdx], baseTyp) + + return &unionNullableDecoder{ + schema: union, + typ: baseTyp, + isPtr: isPtr, + decoder: decoder, + } +} + +type unionNullableDecoder struct { + schema *UnionSchema + typ reflect2.Type + isPtr bool + decoder ValDecoder +} + +func (d *unionNullableDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + _, schema := getUnionSchema(d.schema, r) + if schema == nil { + return + } + + if schema.Type() == Null { + *((*unsafe.Pointer)(ptr)) = nil + return + } + + // Handle the non-ptr case separately. + if !d.isPtr { + if d.typ.UnsafeIsNil(ptr) { + // Create a new instance. + newPtr := d.typ.UnsafeNew() + d.decoder.Decode(newPtr, r) + d.typ.UnsafeSet(ptr, newPtr) + return + } + + // Reuse the existing instance. + d.decoder.Decode(ptr, r) + return + } + + if *((*unsafe.Pointer)(ptr)) == nil { + // Create new instance. + newPtr := d.typ.UnsafeNew() + d.decoder.Decode(newPtr, r) + *((*unsafe.Pointer)(ptr)) = newPtr + return + } + + // Reuse existing instance. + d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r) +} + +func encoderOfNullableUnion(e *encoderContext, schema Schema, typ reflect2.Type) ValEncoder { + union := schema.(*UnionSchema) + nullIdx, typeIdx := union.Indices() + + var ( + baseTyp reflect2.Type + isPtr bool + ) + switch v := typ.(type) { + case *reflect2.UnsafePtrType: + baseTyp = v.Elem() + isPtr = true + case *reflect2.UnsafeSliceType: + baseTyp = v + } + encoder := encoderOfType(e, union.Types()[typeIdx], baseTyp) + + return &unionNullableEncoder{ + schema: union, + encoder: encoder, + isPtr: isPtr, + nullIdx: int32(nullIdx), + typeIdx: int32(typeIdx), + } +} + +type unionNullableEncoder struct { + schema *UnionSchema + encoder ValEncoder + isPtr bool + nullIdx int32 + typeIdx int32 +} + +func (e *unionNullableEncoder) Encode(ptr unsafe.Pointer, w *Writer) { + if *((*unsafe.Pointer)(ptr)) == nil { + w.WriteInt(e.nullIdx) + return + } + + w.WriteInt(e.typeIdx) + newPtr := ptr + if e.isPtr { + newPtr = *((*unsafe.Pointer)(ptr)) + } + e.encoder.Encode(newPtr, w) +} + +func decoderOfResolvedUnion(d *decoderContext, schema Schema) (ValDecoder, error) { + union := schema.(*UnionSchema) + + types := make([]reflect2.Type, len(union.Types())) + decoders := make([]ValDecoder, len(union.Types())) + for i, schema := range union.Types() { + name := unionResolutionName(schema) + + typ, err := d.cfg.resolver.Type(name) + if err != nil { + if d.cfg.config.UnionResolutionError { + return nil, err + } + + if d.cfg.config.PartialUnionTypeResolution { + decoders[i] = nil + types[i] = nil + continue + } + + decoders = []ValDecoder{} + types = []reflect2.Type{} + break + } + + decoder := decoderOfType(d, schema, typ) + decoders[i] = decoder + types[i] = typ + } + + return &unionResolvedDecoder{ + cfg: d.cfg, + schema: union, + types: types, + decoders: decoders, + }, nil +} + +type unionResolvedDecoder struct { + cfg *frozenConfig + schema *UnionSchema + types []reflect2.Type + decoders []ValDecoder +} + +func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + i, schema := getUnionSchema(d.schema, r) + if schema == nil { + return + } + + pObj := (*any)(ptr) + + if schema.Type() == Null { + *pObj = nil + return + } + + if i >= len(d.decoders) || d.decoders[i] == nil { + if d.cfg.config.UnionResolutionError { + r.ReportError("decode union type", "unknown union type") + return + } + + // We cannot resolve this, set it to the map type + name := schemaTypeName(schema) + obj := map[string]any{} + vTyp, err := genericReceiver(schema) + if err != nil { + r.ReportError("Union", err.Error()) + return + } + obj[name] = genericDecode(vTyp, decoderOfType(newDecoderContext(d.cfg), schema, vTyp), r) + + *pObj = obj + return + } + + typ := d.types[i] + var newPtr unsafe.Pointer + switch typ.Kind() { + case reflect.Map: + mapType := typ.(*reflect2.UnsafeMapType) + newPtr = mapType.UnsafeMakeMap(1) + + case reflect.Slice: + mapType := typ.(*reflect2.UnsafeSliceType) + newPtr = mapType.UnsafeMakeSlice(1, 1) + + case reflect.Ptr: + elemType := typ.(*reflect2.UnsafePtrType).Elem() + newPtr = elemType.UnsafeNew() + + default: + newPtr = typ.UnsafeNew() + } + + d.decoders[i].Decode(newPtr, r) + *pObj = typ.UnsafeIndirect(newPtr) +} + +func unionResolutionName(schema Schema) string { + name := schemaTypeName(schema) + switch schema.Type() { + case Map: + name += ":" + valSchema := schema.(*MapSchema).Values() + valName := schemaTypeName(valSchema) + + name += valName + + case Array: + name += ":" + itemSchema := schema.(*ArraySchema).Items() + itemName := schemaTypeName(itemSchema) + + name += itemName + } + + return name +} + +func encoderOfResolverUnion(e *encoderContext, schema Schema, typ reflect2.Type) ValEncoder { + union := schema.(*UnionSchema) + + names, err := e.cfg.resolver.Name(typ) + if err != nil { + return &errorEncoder{err: err} + } + + var pos int + for _, name := range names { + if idx := strings.Index(name, ":"); idx > 0 { + name = name[:idx] + } + + schema, pos = union.Types().Get(name) + if schema != nil { + break + } + } + if schema == nil { + return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])} + } + + encoder := encoderOfType(e, schema, typ) + + return &unionResolverEncoder{ + pos: pos, + encoder: encoder, + } +} + +type unionResolverEncoder struct { + pos int + encoder ValEncoder +} + +func (e *unionResolverEncoder) Encode(ptr unsafe.Pointer, w *Writer) { + w.WriteInt(int32(e.pos)) + + e.encoder.Encode(ptr, w) +} + +func getUnionSchema(schema *UnionSchema, r *Reader) (int, Schema) { + types := schema.Types() + + idx := int(r.ReadInt()) + if idx < 0 || idx > len(types)-1 { + r.ReportError("decode union type", "unknown union type") + return 0, nil + } + + return idx, types[idx] +} |
