diff options
Diffstat (limited to 'vendor/github.com/hamba/avro/v2/codec_map.go')
| -rw-r--r-- | vendor/github.com/hamba/avro/v2/codec_map.go | 246 |
1 files changed, 246 insertions, 0 deletions
diff --git a/vendor/github.com/hamba/avro/v2/codec_map.go b/vendor/github.com/hamba/avro/v2/codec_map.go new file mode 100644 index 0000000..18c7de1 --- /dev/null +++ b/vendor/github.com/hamba/avro/v2/codec_map.go @@ -0,0 +1,246 @@ +package avro + +import ( + "encoding" + "errors" + "fmt" + "io" + "reflect" + "unsafe" + + "github.com/modern-go/reflect2" +) + +func createDecoderOfMap(d *decoderContext, schema *MapSchema, typ reflect2.Type) ValDecoder { + if typ.Kind() == reflect.Map { + keyType := typ.(reflect2.MapType).Key() + switch { + case keyType.Kind() == reflect.String: + return decoderOfMap(d, schema, typ) + case keyType.Implements(textUnmarshalerType): + return decoderOfMapUnmarshaler(d, schema, typ) + } + } + + return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} +} + +func createEncoderOfMap(e *encoderContext, schema *MapSchema, typ reflect2.Type) ValEncoder { + if typ.Kind() == reflect.Map { + keyType := typ.(reflect2.MapType).Key() + switch { + case keyType.Kind() == reflect.String: + return encoderOfMap(e, schema, typ) + case keyType.Implements(textMarshalerType): + return encoderOfMapMarshaler(e, schema, typ) + } + } + + return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} +} + +func decoderOfMap(d *decoderContext, m *MapSchema, typ reflect2.Type) ValDecoder { + mapType := typ.(*reflect2.UnsafeMapType) + decoder := decoderOfType(d, m.Values(), mapType.Elem()) + + return &mapDecoder{ + mapType: mapType, + elemType: mapType.Elem(), + decoder: decoder, + } +} + +type mapDecoder struct { + mapType *reflect2.UnsafeMapType + elemType reflect2.Type + decoder ValDecoder +} + +func (d *mapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { + if d.mapType.UnsafeIsNil(ptr) { + d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0)) + } + + for { + l, _ := r.ReadBlockHeader() + if l == 0 { + break + } + + for range l { + keyPtr := reflect2.PtrOf(r.ReadString()) + elemPtr := d.elemType.UnsafeNew() + d.decoder.Decode(elemPtr, r) + if r.Error != nil { + r.Error = fmt.Errorf("reading map[string]%s: %w", d.elemType.String(), r.Error) + return + } + + d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) + } + } + + if r.Error != nil && !errors.Is(r.Error, io.EOF) { + r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error) + } +} + +func decoderOfMapUnmarshaler(d *decoderContext, m *MapSchema, typ reflect2.Type) ValDecoder { + mapType := typ.(*reflect2.UnsafeMapType) + decoder := decoderOfType(d, m.Values(), mapType.Elem()) + + return &mapDecoderUnmarshaler{ + mapType: mapType, + keyType: mapType.Key(), + elemType: mapType.Elem(), + decoder: decoder, + } +} + +type mapDecoderUnmarshaler struct { + mapType *reflect2.UnsafeMapType + keyType reflect2.Type + elemType reflect2.Type + decoder ValDecoder +} + +func (d *mapDecoderUnmarshaler) Decode(ptr unsafe.Pointer, r *Reader) { + if d.mapType.UnsafeIsNil(ptr) { + d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0)) + } + + for { + l, _ := r.ReadBlockHeader() + if l == 0 { + break + } + + for range l { + keyPtr := d.keyType.UnsafeNew() + keyObj := d.keyType.UnsafeIndirect(keyPtr) + if reflect2.IsNil(keyObj) { + ptrType := d.keyType.(*reflect2.UnsafePtrType) + newPtr := ptrType.Elem().UnsafeNew() + *((*unsafe.Pointer)(keyPtr)) = newPtr + keyObj = d.keyType.UnsafeIndirect(keyPtr) + } + unmarshaler := keyObj.(encoding.TextUnmarshaler) + err := unmarshaler.UnmarshalText([]byte(r.ReadString())) + if err != nil { + r.ReportError("mapDecoderUnmarshaler", err.Error()) + return + } + + elemPtr := d.elemType.UnsafeNew() + d.decoder.Decode(elemPtr, r) + + d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) + } + } + + if r.Error != nil && !errors.Is(r.Error, io.EOF) { + r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error) + } +} + +func encoderOfMap(e *encoderContext, m *MapSchema, typ reflect2.Type) ValEncoder { + mapType := typ.(*reflect2.UnsafeMapType) + encoder := encoderOfType(e, m.Values(), mapType.Elem()) + + return &mapEncoder{ + blockLength: e.cfg.getBlockLength(), + mapType: mapType, + encoder: encoder, + } +} + +type mapEncoder struct { + blockLength int + mapType *reflect2.UnsafeMapType + encoder ValEncoder +} + +func (e *mapEncoder) Encode(ptr unsafe.Pointer, w *Writer) { + blockLength := e.blockLength + + iter := e.mapType.UnsafeIterate(ptr) + + for { + wrote := w.WriteBlockCB(func(w *Writer) int64 { + var i int + for i = 0; iter.HasNext() && i < blockLength; i++ { + keyPtr, elemPtr := iter.UnsafeNext() + w.WriteString(*((*string)(keyPtr))) + e.encoder.Encode(elemPtr, w) + } + + return int64(i) + }) + + if wrote == 0 { + break + } + } + + if w.Error != nil && !errors.Is(w.Error, io.EOF) { + w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error) + } +} + +func encoderOfMapMarshaler(e *encoderContext, m *MapSchema, typ reflect2.Type) ValEncoder { + mapType := typ.(*reflect2.UnsafeMapType) + encoder := encoderOfType(e, m.Values(), mapType.Elem()) + + return &mapEncoderMarshaller{ + blockLength: e.cfg.getBlockLength(), + mapType: mapType, + keyType: mapType.Key(), + encoder: encoder, + } +} + +type mapEncoderMarshaller struct { + blockLength int + mapType *reflect2.UnsafeMapType + keyType reflect2.Type + encoder ValEncoder +} + +func (e *mapEncoderMarshaller) Encode(ptr unsafe.Pointer, w *Writer) { + blockLength := e.blockLength + + iter := e.mapType.UnsafeIterate(ptr) + + for { + wrote := w.WriteBlockCB(func(w *Writer) int64 { + var i int + for i = 0; iter.HasNext() && i < blockLength; i++ { + keyPtr, elemPtr := iter.UnsafeNext() + + obj := e.keyType.UnsafeIndirect(keyPtr) + if e.keyType.IsNullable() && reflect2.IsNil(obj) { + w.Error = errors.New("avro: mapEncoderMarshaller: encoding nil TextMarshaller") + return int64(0) + } + marshaler := (obj).(encoding.TextMarshaler) + b, err := marshaler.MarshalText() + if err != nil { + w.Error = err + return int64(0) + } + w.WriteString(string(b)) + + e.encoder.Encode(elemPtr, w) + } + return int64(i) + }) + + if wrote == 0 { + break + } + } + + if w.Error != nil && !errors.Is(w.Error, io.EOF) { + w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error) + } +} |
