summaryrefslogtreecommitdiff
path: root/vendor/github.com/hamba/avro/v2/codec_map.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/hamba/avro/v2/codec_map.go')
-rw-r--r--vendor/github.com/hamba/avro/v2/codec_map.go246
1 files changed, 246 insertions, 0 deletions
diff --git a/vendor/github.com/hamba/avro/v2/codec_map.go b/vendor/github.com/hamba/avro/v2/codec_map.go
new file mode 100644
index 0000000..18c7de1
--- /dev/null
+++ b/vendor/github.com/hamba/avro/v2/codec_map.go
@@ -0,0 +1,246 @@
+package avro
+
+import (
+ "encoding"
+ "errors"
+ "fmt"
+ "io"
+ "reflect"
+ "unsafe"
+
+ "github.com/modern-go/reflect2"
+)
+
+func createDecoderOfMap(d *decoderContext, schema *MapSchema, typ reflect2.Type) ValDecoder {
+ if typ.Kind() == reflect.Map {
+ keyType := typ.(reflect2.MapType).Key()
+ switch {
+ case keyType.Kind() == reflect.String:
+ return decoderOfMap(d, schema, typ)
+ case keyType.Implements(textUnmarshalerType):
+ return decoderOfMapUnmarshaler(d, schema, typ)
+ }
+ }
+
+ return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
+}
+
+func createEncoderOfMap(e *encoderContext, schema *MapSchema, typ reflect2.Type) ValEncoder {
+ if typ.Kind() == reflect.Map {
+ keyType := typ.(reflect2.MapType).Key()
+ switch {
+ case keyType.Kind() == reflect.String:
+ return encoderOfMap(e, schema, typ)
+ case keyType.Implements(textMarshalerType):
+ return encoderOfMapMarshaler(e, schema, typ)
+ }
+ }
+
+ return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
+}
+
+func decoderOfMap(d *decoderContext, m *MapSchema, typ reflect2.Type) ValDecoder {
+ mapType := typ.(*reflect2.UnsafeMapType)
+ decoder := decoderOfType(d, m.Values(), mapType.Elem())
+
+ return &mapDecoder{
+ mapType: mapType,
+ elemType: mapType.Elem(),
+ decoder: decoder,
+ }
+}
+
+type mapDecoder struct {
+ mapType *reflect2.UnsafeMapType
+ elemType reflect2.Type
+ decoder ValDecoder
+}
+
+func (d *mapDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
+ if d.mapType.UnsafeIsNil(ptr) {
+ d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
+ }
+
+ for {
+ l, _ := r.ReadBlockHeader()
+ if l == 0 {
+ break
+ }
+
+ for range l {
+ keyPtr := reflect2.PtrOf(r.ReadString())
+ elemPtr := d.elemType.UnsafeNew()
+ d.decoder.Decode(elemPtr, r)
+ if r.Error != nil {
+ r.Error = fmt.Errorf("reading map[string]%s: %w", d.elemType.String(), r.Error)
+ return
+ }
+
+ d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
+ }
+ }
+
+ if r.Error != nil && !errors.Is(r.Error, io.EOF) {
+ r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
+ }
+}
+
+func decoderOfMapUnmarshaler(d *decoderContext, m *MapSchema, typ reflect2.Type) ValDecoder {
+ mapType := typ.(*reflect2.UnsafeMapType)
+ decoder := decoderOfType(d, m.Values(), mapType.Elem())
+
+ return &mapDecoderUnmarshaler{
+ mapType: mapType,
+ keyType: mapType.Key(),
+ elemType: mapType.Elem(),
+ decoder: decoder,
+ }
+}
+
+type mapDecoderUnmarshaler struct {
+ mapType *reflect2.UnsafeMapType
+ keyType reflect2.Type
+ elemType reflect2.Type
+ decoder ValDecoder
+}
+
+func (d *mapDecoderUnmarshaler) Decode(ptr unsafe.Pointer, r *Reader) {
+ if d.mapType.UnsafeIsNil(ptr) {
+ d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
+ }
+
+ for {
+ l, _ := r.ReadBlockHeader()
+ if l == 0 {
+ break
+ }
+
+ for range l {
+ keyPtr := d.keyType.UnsafeNew()
+ keyObj := d.keyType.UnsafeIndirect(keyPtr)
+ if reflect2.IsNil(keyObj) {
+ ptrType := d.keyType.(*reflect2.UnsafePtrType)
+ newPtr := ptrType.Elem().UnsafeNew()
+ *((*unsafe.Pointer)(keyPtr)) = newPtr
+ keyObj = d.keyType.UnsafeIndirect(keyPtr)
+ }
+ unmarshaler := keyObj.(encoding.TextUnmarshaler)
+ err := unmarshaler.UnmarshalText([]byte(r.ReadString()))
+ if err != nil {
+ r.ReportError("mapDecoderUnmarshaler", err.Error())
+ return
+ }
+
+ elemPtr := d.elemType.UnsafeNew()
+ d.decoder.Decode(elemPtr, r)
+
+ d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
+ }
+ }
+
+ if r.Error != nil && !errors.Is(r.Error, io.EOF) {
+ r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
+ }
+}
+
+func encoderOfMap(e *encoderContext, m *MapSchema, typ reflect2.Type) ValEncoder {
+ mapType := typ.(*reflect2.UnsafeMapType)
+ encoder := encoderOfType(e, m.Values(), mapType.Elem())
+
+ return &mapEncoder{
+ blockLength: e.cfg.getBlockLength(),
+ mapType: mapType,
+ encoder: encoder,
+ }
+}
+
+type mapEncoder struct {
+ blockLength int
+ mapType *reflect2.UnsafeMapType
+ encoder ValEncoder
+}
+
+func (e *mapEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
+ blockLength := e.blockLength
+
+ iter := e.mapType.UnsafeIterate(ptr)
+
+ for {
+ wrote := w.WriteBlockCB(func(w *Writer) int64 {
+ var i int
+ for i = 0; iter.HasNext() && i < blockLength; i++ {
+ keyPtr, elemPtr := iter.UnsafeNext()
+ w.WriteString(*((*string)(keyPtr)))
+ e.encoder.Encode(elemPtr, w)
+ }
+
+ return int64(i)
+ })
+
+ if wrote == 0 {
+ break
+ }
+ }
+
+ if w.Error != nil && !errors.Is(w.Error, io.EOF) {
+ w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error)
+ }
+}
+
+func encoderOfMapMarshaler(e *encoderContext, m *MapSchema, typ reflect2.Type) ValEncoder {
+ mapType := typ.(*reflect2.UnsafeMapType)
+ encoder := encoderOfType(e, m.Values(), mapType.Elem())
+
+ return &mapEncoderMarshaller{
+ blockLength: e.cfg.getBlockLength(),
+ mapType: mapType,
+ keyType: mapType.Key(),
+ encoder: encoder,
+ }
+}
+
+type mapEncoderMarshaller struct {
+ blockLength int
+ mapType *reflect2.UnsafeMapType
+ keyType reflect2.Type
+ encoder ValEncoder
+}
+
+func (e *mapEncoderMarshaller) Encode(ptr unsafe.Pointer, w *Writer) {
+ blockLength := e.blockLength
+
+ iter := e.mapType.UnsafeIterate(ptr)
+
+ for {
+ wrote := w.WriteBlockCB(func(w *Writer) int64 {
+ var i int
+ for i = 0; iter.HasNext() && i < blockLength; i++ {
+ keyPtr, elemPtr := iter.UnsafeNext()
+
+ obj := e.keyType.UnsafeIndirect(keyPtr)
+ if e.keyType.IsNullable() && reflect2.IsNil(obj) {
+ w.Error = errors.New("avro: mapEncoderMarshaller: encoding nil TextMarshaller")
+ return int64(0)
+ }
+ marshaler := (obj).(encoding.TextMarshaler)
+ b, err := marshaler.MarshalText()
+ if err != nil {
+ w.Error = err
+ return int64(0)
+ }
+ w.WriteString(string(b))
+
+ e.encoder.Encode(elemPtr, w)
+ }
+ return int64(i)
+ })
+
+ if wrote == 0 {
+ break
+ }
+ }
+
+ if w.Error != nil && !errors.Is(w.Error, io.EOF) {
+ w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error)
+ }
+}