diff options
Diffstat (limited to 'vendor/github.com/jhump/protoreflect/dynamic')
14 files changed, 7030 insertions, 0 deletions
diff --git a/vendor/github.com/jhump/protoreflect/dynamic/binary.go b/vendor/github.com/jhump/protoreflect/dynamic/binary.go new file mode 100644 index 0000000..39e077a --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/binary.go @@ -0,0 +1,193 @@ +package dynamic + +// Binary serialization and de-serialization for dynamic messages + +import ( + "fmt" + "io" + + "github.com/golang/protobuf/proto" + + "github.com/jhump/protoreflect/codec" +) + +// defaultDeterminism, if true, will mean that calls to Marshal will produce +// deterministic output. This is used to make the output of proto.Marshal(...) +// deterministic (since there is no way to have that convey determinism intent). +// **This is only used from tests.** +var defaultDeterminism = false + +// Marshal serializes this message to bytes, returning an error if the operation +// fails. The resulting bytes are in the standard protocol buffer binary format. +func (m *Message) Marshal() ([]byte, error) { + var b codec.Buffer + b.SetDeterministic(defaultDeterminism) + if err := m.marshal(&b); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// MarshalAppend behaves exactly the same as Marshal, except instead of allocating a +// new byte slice to marshal into, it uses the provided byte slice. The backing array +// for the returned byte slice *may* be the same as the one that was passed in, but +// it's not guaranteed as a new backing array will automatically be allocated if +// more bytes need to be written than the provided buffer has capacity for. +func (m *Message) MarshalAppend(b []byte) ([]byte, error) { + codedBuf := codec.NewBuffer(b) + codedBuf.SetDeterministic(defaultDeterminism) + if err := m.marshal(codedBuf); err != nil { + return nil, err + } + return codedBuf.Bytes(), nil +} + +// MarshalDeterministic serializes this message to bytes in a deterministic way, +// returning an error if the operation fails. This differs from Marshal in that +// map keys will be sorted before serializing to bytes. The protobuf spec does +// not define ordering for map entries, so Marshal will use standard Go map +// iteration order (which will be random). But for cases where determinism is +// more important than performance, use this method instead. +func (m *Message) MarshalDeterministic() ([]byte, error) { + var b codec.Buffer + b.SetDeterministic(true) + if err := m.marshal(&b); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// MarshalAppendDeterministic behaves exactly the same as MarshalDeterministic, +// except instead of allocating a new byte slice to marshal into, it uses the +// provided byte slice. The backing array for the returned byte slice *may* be +// the same as the one that was passed in, but it's not guaranteed as a new +// backing array will automatically be allocated if more bytes need to be written +// than the provided buffer has capacity for. +func (m *Message) MarshalAppendDeterministic(b []byte) ([]byte, error) { + codedBuf := codec.NewBuffer(b) + codedBuf.SetDeterministic(true) + if err := m.marshal(codedBuf); err != nil { + return nil, err + } + return codedBuf.Bytes(), nil +} + +func (m *Message) marshal(b *codec.Buffer) error { + if m.GetMessageDescriptor().GetMessageOptions().GetMessageSetWireFormat() { + return fmt.Errorf("%s is a message set; marshaling message sets is not implemented", m.GetMessageDescriptor().GetFullyQualifiedName()) + } + if err := m.marshalKnownFields(b); err != nil { + return err + } + return m.marshalUnknownFields(b) +} + +func (m *Message) marshalKnownFields(b *codec.Buffer) error { + for _, tag := range m.knownFieldTags() { + itag := int32(tag) + val := m.values[itag] + fd := m.FindFieldDescriptor(itag) + if fd == nil { + panic(fmt.Sprintf("Couldn't find field for tag %d", itag)) + } + if err := b.EncodeFieldValue(fd, val); err != nil { + return err + } + } + return nil +} + +func (m *Message) marshalUnknownFields(b *codec.Buffer) error { + for _, tag := range m.unknownFieldTags() { + itag := int32(tag) + sl := m.unknownFields[itag] + for _, u := range sl { + if err := b.EncodeTagAndWireType(itag, u.Encoding); err != nil { + return err + } + switch u.Encoding { + case proto.WireBytes: + if err := b.EncodeRawBytes(u.Contents); err != nil { + return err + } + case proto.WireStartGroup: + _, _ = b.Write(u.Contents) + if err := b.EncodeTagAndWireType(itag, proto.WireEndGroup); err != nil { + return err + } + case proto.WireFixed32: + if err := b.EncodeFixed32(u.Value); err != nil { + return err + } + case proto.WireFixed64: + if err := b.EncodeFixed64(u.Value); err != nil { + return err + } + case proto.WireVarint: + if err := b.EncodeVarint(u.Value); err != nil { + return err + } + default: + return codec.ErrBadWireType + } + } + } + return nil +} + +// Unmarshal de-serializes the message that is present in the given bytes into +// this message. It first resets the current message. It returns an error if the +// given bytes do not contain a valid encoding of this message type. +func (m *Message) Unmarshal(b []byte) error { + m.Reset() + if err := m.UnmarshalMerge(b); err != nil { + return err + } + return m.Validate() +} + +// UnmarshalMerge de-serializes the message that is present in the given bytes +// into this message. Unlike Unmarshal, it does not first reset the message, +// instead merging the data in the given bytes into the existing data in this +// message. +func (m *Message) UnmarshalMerge(b []byte) error { + return m.unmarshal(codec.NewBuffer(b), false) +} + +func (m *Message) unmarshal(buf *codec.Buffer, isGroup bool) error { + if m.GetMessageDescriptor().GetMessageOptions().GetMessageSetWireFormat() { + return fmt.Errorf("%s is a message set; unmarshaling message sets is not implemented", m.GetMessageDescriptor().GetFullyQualifiedName()) + } + for !buf.EOF() { + fd, val, err := buf.DecodeFieldValue(m.FindFieldDescriptor, m.mf) + if err != nil { + if err == codec.ErrWireTypeEndGroup { + if isGroup { + // finished parsing group + return nil + } + return codec.ErrBadWireType + } + return err + } + + if fd == nil { + if m.unknownFields == nil { + m.unknownFields = map[int32][]UnknownField{} + } + uv := val.(codec.UnknownField) + u := UnknownField{ + Encoding: uv.Encoding, + Value: uv.Value, + Contents: uv.Contents, + } + m.unknownFields[uv.Tag] = append(m.unknownFields[uv.Tag], u) + } else if err := mergeField(m, fd, val); err != nil { + return err + } + } + if isGroup { + return io.ErrUnexpectedEOF + } + return nil +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/doc.go b/vendor/github.com/jhump/protoreflect/dynamic/doc.go new file mode 100644 index 0000000..59b77eb --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/doc.go @@ -0,0 +1,167 @@ +// Package dynamic provides an implementation for a dynamic protobuf message. +// +// The dynamic message is essentially a message descriptor along with a map of +// tag numbers to values. It has a broad API for interacting with the message, +// including inspection and modification. Generally, most operations have two +// forms: a regular method that panics on bad input or error and a "Try" form +// of the method that will instead return an error. +// +// A dynamic message can optionally be constructed with a MessageFactory. The +// MessageFactory has various registries that may be used by the dynamic message, +// such as during de-serialization. The message factory is "inherited" by any +// other dynamic messages created, such as nested messages that are created +// during de-serialization. Similarly, any dynamic message created using +// MessageFactory.NewMessage will be associated with that factory, which in turn +// will be used to create other messages or parse extension fields during +// de-serialization. +// +// # Field Types +// +// The types of values expected by setters and returned by getters are the +// same as protoc generates for scalar fields. For repeated fields, there are +// methods for getting and setting values at a particular index or for adding +// an element. Similarly, for map fields, there are methods for getting and +// setting values for a particular key. +// +// If you use GetField for a repeated field, it will return a copy of all +// elements as a slice []interface{}. Similarly, using GetField for a map field +// will return a copy of all mappings as a map[interface{}]interface{}. You can +// also use SetField to supply an entire slice or map for repeated or map fields. +// The slice need not be []interface{} but can actually be typed according to +// the field's expected type. For example, a repeated uint64 field can be set +// using a slice of type []uint64. +// +// Descriptors for map fields describe them as repeated fields with a nested +// message type. The nested message type is a special generated type that +// represents a single mapping: key and value pair. The dynamic message has some +// special affordances for this representation. For example, you can use +// SetField to set a map field using a slice of these entry messages. Internally, +// the slice of entries will be converted to an actual map. Similarly, you can +// use AddRepeatedField with an entry message to add (or overwrite) a mapping. +// However, you cannot use GetRepeatedField or SetRepeatedField to modify maps, +// since those take numeric index arguments which are not relevant to maps +// (since maps in Go have no defined ordering). +// +// When setting field values in dynamic messages, the type-checking is lenient +// in that it accepts any named type with the right kind. So a string field can +// be assigned to any type that is defined as a string. Enum fields require +// int32 values (or any type that is defined as an int32). +// +// Unlike normal use of numeric values in Go, values will be automatically +// widened when assigned. So, for example, an int64 field can be set using an +// int32 value since it can be safely widened without truncation or loss of +// precision. Similar goes for uint32 values being converted to uint64 and +// float32 being converted to float64. Narrowing conversions are not done, +// however. Also, unsigned values will never be automatically converted to +// signed (and vice versa), and floating point values will never be +// automatically converted to integral values (and vice versa). Since the bit +// width of int and uint fields is allowed to be platform dependent, but will +// always be less than or equal to 64, they can only be used as values for +// int64 and uint64 fields, respectively. They cannot be used to set int32 or +// uint32 fields, which includes enums fields. +// +// Fields whose type is a nested message can have values set to either other +// dynamic messages or generated messages (e.g. pointers to structs generated by +// protoc). Getting a value for such a field will return the actual type it is +// set to (e.g. either a dynamic message or a generated message). If the value +// is not set and the message uses proto2 syntax, the default message returned +// will be whatever is returned by the dynamic message's MessageFactory (if the +// dynamic message was not created with a factory, it will use the logic of the +// zero value factory). In most typical cases, it will return a dynamic message, +// but if the factory is configured with a KnownTypeRegistry, or if the field's +// type is a well-known type, it will return a zero value generated message. +// +// # Unrecognized Fields +// +// Unrecognized fields are preserved by the dynamic message when unmarshaling +// from the standard binary format. If the message's MessageFactory was +// configured with an ExtensionRegistry, it will be used to identify and parse +// extension fields for the message. +// +// Unrecognized fields can dynamically become recognized fields if the +// application attempts to retrieve an unrecognized field's value using a +// FieldDescriptor. In this case, the given FieldDescriptor is used to parse the +// unknown field and move the parsed value into the message's set of known +// fields. This behavior is most suited to the use of extensions, where an +// ExtensionRegistry is not setup with all known extensions ahead of time. But +// it can even happen for non-extension fields! Here's an example scenario where +// a non-extension field can initially be unknown and become known: +// +// 1. A dynamic message is created with a descriptor, A, and then +// de-serialized from a stream of bytes. The stream includes an +// unrecognized tag T. The message will include tag T in its unrecognized +// field set. +// 2. Another call site retrieves a newer descriptor, A', which includes a +// newly added field with tag T. +// 3. That other call site then uses a FieldDescriptor to access the value of +// the new field. This will cause the dynamic message to parse the bytes +// for the unknown tag T and store them as a known field. +// 4. Subsequent operations for tag T, including setting the field using only +// tag number or de-serializing a stream that includes tag T, will operate +// as if that tag were part of the original descriptor, A. +// +// # Compatibility +// +// In addition to implementing the proto.Message interface, the included +// Message type also provides an XXX_MessageName() method, so it can work with +// proto.MessageName. And it provides a Descriptor() method that behaves just +// like the method of the same signature in messages generated by protoc. +// Because of this, it is actually compatible with proto.Message in many (though +// not all) contexts. In particular, it is compatible with proto.Marshal and +// proto.Unmarshal for serializing and de-serializing messages. +// +// The dynamic message supports binary and text marshaling, using protobuf's +// well-defined binary format and the same text format that protoc-generated +// types use. It also supports JSON serialization/de-serialization by +// implementing the json.Marshaler and json.Unmarshaler interfaces. And dynamic +// messages can safely be used with the jsonpb package for JSON serialization +// and de-serialization. +// +// In addition to implementing the proto.Message interface and numerous related +// methods, it also provides inter-op with generated messages via conversion. +// The ConvertTo, ConvertFrom, MergeInto, and MergeFrom methods copy message +// contents from a dynamic message to a generated message and vice versa. +// +// When copying from a generated message into a dynamic message, if the +// generated message contains fields unknown to the dynamic message (e.g. not +// present in the descriptor used to create the dynamic message), these fields +// become known to the dynamic message (as per behavior described above in +// "Unrecognized Fields"). If the generated message has unrecognized fields of +// its own, including unrecognized extensions, they are preserved in the dynamic +// message. It is possible that the dynamic message knows about fields that the +// generated message did not, like if it has a different version of the +// descriptor or its MessageFactory has an ExtensionRegistry that knows about +// different extensions than were linked into the program. In this case, these +// unrecognized fields in the generated message will be known fields in the +// dynamic message. +// +// Similarly, when copying from a dynamic message into a generated message, if +// the dynamic message has unrecognized fields they can be preserved in the +// generated message (currently only for syntax proto2 since proto3 generated +// messages do not preserve unrecognized fields). If the generated message knows +// about fields that the dynamic message does not, these unrecognized fields may +// become known fields in the generated message. +// +// # Registries +// +// This package also contains a couple of registries, for managing known types +// and descriptors. +// +// The KnownTypeRegistry allows de-serialization of a dynamic message to use +// generated message types, instead of dynamic messages, for some kinds of +// nested message fields. This is particularly useful for working with proto +// messages that have special encodings as JSON (e.g. the well-known types), +// since the dynamic message does not try to handle these special cases in its +// JSON marshaling facilities. +// +// The ExtensionRegistry allows for recognizing and parsing extensions fields +// (for proto2 messages). +// +// Deprecated: This module was created for use with the older "v1" Protobuf API +// in github.com/golang/protobuf. However, much of this module is no longer +// necessary as the newer "v2" API in google.golang.org/protobuf provides similar +// capabilities. Instead of using this github.com/jhump/protoreflect/dynamic package, +// see [google.golang.org/protobuf/types/dynamicpb]. +// +// [google.golang.org/protobuf/types/dynamicpb]: https://pkg.go.dev/google.golang.org/protobuf/types/dynamicpb +package dynamic diff --git a/vendor/github.com/jhump/protoreflect/dynamic/dynamic_message.go b/vendor/github.com/jhump/protoreflect/dynamic/dynamic_message.go new file mode 100644 index 0000000..ff136b0 --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/dynamic_message.go @@ -0,0 +1,2830 @@ +package dynamic + +import ( + "bytes" + "compress/gzip" + "errors" + "fmt" + "reflect" + "sort" + "strings" + + "github.com/golang/protobuf/proto" + protov2 "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" + + "github.com/jhump/protoreflect/codec" + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/internal" +) + +// ErrUnknownTagNumber is an error that is returned when an operation refers +// to an unknown tag number. +var ErrUnknownTagNumber = errors.New("unknown tag number") + +// UnknownTagNumberError is the same as ErrUnknownTagNumber. +// Deprecated: use ErrUnknownTagNumber +var UnknownTagNumberError = ErrUnknownTagNumber + +// ErrUnknownFieldName is an error that is returned when an operation refers +// to an unknown field name. +var ErrUnknownFieldName = errors.New("unknown field name") + +// UnknownFieldNameError is the same as ErrUnknownFieldName. +// Deprecated: use ErrUnknownFieldName +var UnknownFieldNameError = ErrUnknownFieldName + +// ErrFieldIsNotMap is an error that is returned when map-related operations +// are attempted with fields that are not maps. +var ErrFieldIsNotMap = errors.New("field is not a map type") + +// FieldIsNotMapError is the same as ErrFieldIsNotMap. +// Deprecated: use ErrFieldIsNotMap +var FieldIsNotMapError = ErrFieldIsNotMap + +// ErrFieldIsNotRepeated is an error that is returned when repeated field +// operations are attempted with fields that are not repeated. +var ErrFieldIsNotRepeated = errors.New("field is not repeated") + +// FieldIsNotRepeatedError is the same as ErrFieldIsNotRepeated. +// Deprecated: use ErrFieldIsNotRepeated +var FieldIsNotRepeatedError = ErrFieldIsNotRepeated + +// ErrIndexOutOfRange is an error that is returned when an invalid index is +// provided when access a single element of a repeated field. +var ErrIndexOutOfRange = errors.New("index is out of range") + +// IndexOutOfRangeError is the same as ErrIndexOutOfRange. +// Deprecated: use ErrIndexOutOfRange +var IndexOutOfRangeError = ErrIndexOutOfRange + +// ErrNumericOverflow is an error returned by operations that encounter a +// numeric value that is too large, for example de-serializing a value into an +// int32 field when the value is larger that can fit into a 32-bit value. +var ErrNumericOverflow = errors.New("numeric value is out of range") + +// NumericOverflowError is the same as ErrNumericOverflow. +// Deprecated: use ErrNumericOverflow +var NumericOverflowError = ErrNumericOverflow + +var typeOfProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem() +var typeOfDynamicMessage = reflect.TypeOf((*Message)(nil)) +var typeOfBytes = reflect.TypeOf(([]byte)(nil)) + +// Message is a dynamic protobuf message. Instead of a generated struct, +// like most protobuf messages, this is a map of field number to values and +// a message descriptor, which is used to validate the field values and +// also to de-serialize messages (from the standard binary format, as well +// as from the text format and from JSON). +type Message struct { + md *desc.MessageDescriptor + er *ExtensionRegistry + mf *MessageFactory + extraFields map[int32]*desc.FieldDescriptor + values map[int32]interface{} + unknownFields map[int32][]UnknownField +} + +// UnknownField represents a field that was parsed from the binary wire +// format for a message, but was not a recognized field number. Enough +// information is preserved so that re-serializing the message won't lose +// any of the unrecognized data. +type UnknownField struct { + // Encoding indicates how the unknown field was encoded on the wire. If it + // is proto.WireBytes or proto.WireGroupStart then Contents will be set to + // the raw bytes. If it is proto.WireTypeFixed32 then the data is in the least + // significant 32 bits of Value. Otherwise, the data is in all 64 bits of + // Value. + Encoding int8 + Contents []byte + Value uint64 +} + +// NewMessage creates a new dynamic message for the type represented by the given +// message descriptor. During de-serialization, a default MessageFactory is used to +// instantiate any nested message fields and no extension fields will be parsed. To +// use a custom MessageFactory or ExtensionRegistry, use MessageFactory.NewMessage. +func NewMessage(md *desc.MessageDescriptor) *Message { + return NewMessageWithMessageFactory(md, nil) +} + +// NewMessageWithExtensionRegistry creates a new dynamic message for the type +// represented by the given message descriptor. During de-serialization, the given +// ExtensionRegistry is used to parse extension fields and nested messages will be +// instantiated using dynamic.NewMessageFactoryWithExtensionRegistry(er). +func NewMessageWithExtensionRegistry(md *desc.MessageDescriptor, er *ExtensionRegistry) *Message { + mf := NewMessageFactoryWithExtensionRegistry(er) + return NewMessageWithMessageFactory(md, mf) +} + +// NewMessageWithMessageFactory creates a new dynamic message for the type +// represented by the given message descriptor. During de-serialization, the given +// MessageFactory is used to instantiate nested messages. +func NewMessageWithMessageFactory(md *desc.MessageDescriptor, mf *MessageFactory) *Message { + var er *ExtensionRegistry + if mf != nil { + er = mf.er + } + return &Message{ + md: md, + mf: mf, + er: er, + } +} + +// AsDynamicMessage converts the given message to a dynamic message. If the +// given message is dynamic, it is returned. Otherwise, a dynamic message is +// created using NewMessage. +func AsDynamicMessage(msg proto.Message) (*Message, error) { + return AsDynamicMessageWithMessageFactory(msg, nil) +} + +// AsDynamicMessageWithExtensionRegistry converts the given message to a dynamic +// message. If the given message is dynamic, it is returned. Otherwise, a +// dynamic message is created using NewMessageWithExtensionRegistry. +func AsDynamicMessageWithExtensionRegistry(msg proto.Message, er *ExtensionRegistry) (*Message, error) { + mf := NewMessageFactoryWithExtensionRegistry(er) + return AsDynamicMessageWithMessageFactory(msg, mf) +} + +// AsDynamicMessageWithMessageFactory converts the given message to a dynamic +// message. If the given message is dynamic, it is returned. Otherwise, a +// dynamic message is created using NewMessageWithMessageFactory. +func AsDynamicMessageWithMessageFactory(msg proto.Message, mf *MessageFactory) (*Message, error) { + if dm, ok := msg.(*Message); ok { + return dm, nil + } + md, err := desc.LoadMessageDescriptorForMessage(msg) + if err != nil { + return nil, err + } + dm := NewMessageWithMessageFactory(md, mf) + err = dm.mergeFrom(msg) + if err != nil { + return nil, err + } + return dm, nil +} + +// GetMessageDescriptor returns a descriptor for this message's type. +func (m *Message) GetMessageDescriptor() *desc.MessageDescriptor { + return m.md +} + +// GetKnownFields returns a slice of descriptors for all known fields. The +// fields will not be in any defined order. +func (m *Message) GetKnownFields() []*desc.FieldDescriptor { + if len(m.extraFields) == 0 { + return m.md.GetFields() + } + flds := make([]*desc.FieldDescriptor, len(m.md.GetFields()), len(m.md.GetFields())+len(m.extraFields)) + copy(flds, m.md.GetFields()) + for _, fld := range m.extraFields { + if !fld.IsExtension() { + flds = append(flds, fld) + } + } + return flds +} + +// GetKnownExtensions returns a slice of descriptors for all extensions known by +// the message's extension registry. The fields will not be in any defined order. +func (m *Message) GetKnownExtensions() []*desc.FieldDescriptor { + if !m.md.IsExtendable() { + return nil + } + exts := m.er.AllExtensionsForType(m.md.GetFullyQualifiedName()) + for _, fld := range m.extraFields { + if fld.IsExtension() { + exts = append(exts, fld) + } + } + return exts +} + +// GetUnknownFields returns a slice of tag numbers for all unknown fields that +// this message contains. The tags will not be in any defined order. +func (m *Message) GetUnknownFields() []int32 { + flds := make([]int32, 0, len(m.unknownFields)) + for tag := range m.unknownFields { + flds = append(flds, tag) + } + return flds +} + +// Descriptor returns the serialized form of the file descriptor in which the +// message was defined and a path to the message type therein. This mimics the +// method of the same name on message types generated by protoc. +func (m *Message) Descriptor() ([]byte, []int) { + // get encoded file descriptor + b, err := proto.Marshal(m.md.GetFile().AsProto()) + if err != nil { + panic(fmt.Sprintf("failed to get encoded descriptor for %s: %v", m.md.GetFile().GetName(), err)) + } + var zippedBytes bytes.Buffer + w := gzip.NewWriter(&zippedBytes) + if _, err := w.Write(b); err != nil { + panic(fmt.Sprintf("failed to get encoded descriptor for %s: %v", m.md.GetFile().GetName(), err)) + } + if err := w.Close(); err != nil { + panic(fmt.Sprintf("failed to get an encoded descriptor for %s: %v", m.md.GetFile().GetName(), err)) + } + + // and path to message + path := []int{} + var d desc.Descriptor + name := m.md.GetFullyQualifiedName() + for d = m.md.GetParent(); d != nil; name, d = d.GetFullyQualifiedName(), d.GetParent() { + found := false + switch d := d.(type) { + case (*desc.FileDescriptor): + for i, md := range d.GetMessageTypes() { + if md.GetFullyQualifiedName() == name { + found = true + path = append(path, i) + } + } + case (*desc.MessageDescriptor): + for i, md := range d.GetNestedMessageTypes() { + if md.GetFullyQualifiedName() == name { + found = true + path = append(path, i) + } + } + } + if !found { + panic(fmt.Sprintf("failed to compute descriptor path for %s", m.md.GetFullyQualifiedName())) + } + } + // reverse the path + i := 0 + j := len(path) - 1 + for i < j { + path[i], path[j] = path[j], path[i] + i++ + j-- + } + + return zippedBytes.Bytes(), path +} + +// XXX_MessageName returns the fully qualified name of this message's type. This +// allows dynamic messages to be used with proto.MessageName. +func (m *Message) XXX_MessageName() string { + return m.md.GetFullyQualifiedName() +} + +// FindFieldDescriptor returns a field descriptor for the given tag number. This +// searches known fields in the descriptor, known fields discovered during calls +// to GetField or SetField, and extension fields known by the message's extension +// registry. It returns nil if the tag is unknown. +func (m *Message) FindFieldDescriptor(tagNumber int32) *desc.FieldDescriptor { + fd := m.md.FindFieldByNumber(tagNumber) + if fd != nil { + return fd + } + fd = m.er.FindExtension(m.md.GetFullyQualifiedName(), tagNumber) + if fd != nil { + return fd + } + return m.extraFields[tagNumber] +} + +// FindFieldDescriptorByName returns a field descriptor for the given field +// name. This searches known fields in the descriptor, known fields discovered +// during calls to GetField or SetField, and extension fields known by the +// message's extension registry. It returns nil if the name is unknown. If the +// given name refers to an extension, it should be fully qualified and may be +// optionally enclosed in parentheses or brackets. +func (m *Message) FindFieldDescriptorByName(name string) *desc.FieldDescriptor { + if name == "" { + return nil + } + fd := m.md.FindFieldByName(name) + if fd != nil { + return fd + } + mustBeExt := false + if name[0] == '(' { + if name[len(name)-1] != ')' { + // malformed name + return nil + } + mustBeExt = true + name = name[1 : len(name)-1] + } else if name[0] == '[' { + if name[len(name)-1] != ']' { + // malformed name + return nil + } + mustBeExt = true + name = name[1 : len(name)-1] + } + fd = m.er.FindExtensionByName(m.md.GetFullyQualifiedName(), name) + if fd != nil { + return fd + } + for _, fd := range m.extraFields { + if fd.IsExtension() && name == fd.GetFullyQualifiedName() { + return fd + } else if !mustBeExt && !fd.IsExtension() && name == fd.GetName() { + return fd + } + } + + return nil +} + +// FindFieldDescriptorByJSONName returns a field descriptor for the given JSON +// name. This searches known fields in the descriptor, known fields discovered +// during calls to GetField or SetField, and extension fields known by the +// message's extension registry. If no field matches the given JSON name, it +// will fall back to searching field names (e.g. FindFieldDescriptorByName). If +// this also yields no match, nil is returned. +func (m *Message) FindFieldDescriptorByJSONName(name string) *desc.FieldDescriptor { + if name == "" { + return nil + } + fd := m.md.FindFieldByJSONName(name) + if fd != nil { + return fd + } + mustBeExt := false + if name[0] == '(' { + if name[len(name)-1] != ')' { + // malformed name + return nil + } + mustBeExt = true + name = name[1 : len(name)-1] + } else if name[0] == '[' { + if name[len(name)-1] != ']' { + // malformed name + return nil + } + mustBeExt = true + name = name[1 : len(name)-1] + } + fd = m.er.FindExtensionByJSONName(m.md.GetFullyQualifiedName(), name) + if fd != nil { + return fd + } + for _, fd := range m.extraFields { + if fd.IsExtension() && name == fd.GetFullyQualifiedJSONName() { + return fd + } else if !mustBeExt && !fd.IsExtension() && name == fd.GetJSONName() { + return fd + } + } + + // try non-JSON names + return m.FindFieldDescriptorByName(name) +} + +func (m *Message) checkField(fd *desc.FieldDescriptor) error { + return checkField(fd, m.md) +} + +func checkField(fd *desc.FieldDescriptor, md *desc.MessageDescriptor) error { + if fd.GetOwner().GetFullyQualifiedName() != md.GetFullyQualifiedName() { + return fmt.Errorf("given field, %s, is for wrong message type: %s; expecting %s", fd.GetName(), fd.GetOwner().GetFullyQualifiedName(), md.GetFullyQualifiedName()) + } + if fd.IsExtension() && !md.IsExtension(fd.GetNumber()) { + return fmt.Errorf("given field, %s, is an extension but is not in message extension range: %v", fd.GetFullyQualifiedName(), md.GetExtensionRanges()) + } + return nil +} + +// GetField returns the value for the given field descriptor. It panics if an +// error is encountered. See TryGetField. +func (m *Message) GetField(fd *desc.FieldDescriptor) interface{} { + if v, err := m.TryGetField(fd); err != nil { + panic(err.Error()) + } else { + return v + } +} + +// TryGetField returns the value for the given field descriptor. An error is +// returned if the given field descriptor does not belong to the right message +// type. +// +// The Go type of the returned value, for scalar fields, is the same as protoc +// would generate for the field (in a non-dynamic message). The table below +// lists the scalar types and the corresponding Go types. +// +// +-------------------------+-----------+ +// | Declared Type | Go Type | +// +-------------------------+-----------+ +// | int32, sint32, sfixed32 | int32 | +// | int64, sint64, sfixed64 | int64 | +// | uint32, fixed32 | uint32 | +// | uint64, fixed64 | uint64 | +// | float | float32 | +// | double | double32 | +// | bool | bool | +// | string | string | +// | bytes | []byte | +// +-------------------------+-----------+ +// +// Values for enum fields will always be int32 values. You can use the enum +// descriptor associated with the field to lookup value names with those values. +// Values for message type fields may be an instance of the generated type *or* +// may be another *dynamic.Message that represents the type. +// +// If the given field is a map field, the returned type will be +// map[interface{}]interface{}. The actual concrete types of keys and values is +// as described above. If the given field is a (non-map) repeated field, the +// returned type is always []interface{}; the type of the actual elements is as +// described above. +// +// If this message has no value for the given field, its default value is +// returned. If the message is defined in a file with "proto3" syntax, the +// default is always the zero value for the field. The default value for map and +// repeated fields is a nil map or slice (respectively). For field's whose types +// is a message, the default value is an empty message for "proto2" syntax or a +// nil message for "proto3" syntax. Note that the in the latter case, a non-nil +// interface with a nil pointer is returned, not a nil interface. Also note that +// whether the returned value is an empty message or nil depends on if *this* +// message was defined as "proto3" syntax, not the message type referred to by +// the field's type. +// +// If the given field descriptor is not known (e.g. not present in the message +// descriptor) but corresponds to an unknown field, the unknown value will be +// parsed and become known. The parsed value will be returned, or an error will +// be returned if the unknown value cannot be parsed according to the field +// descriptor's type information. +func (m *Message) TryGetField(fd *desc.FieldDescriptor) (interface{}, error) { + if err := m.checkField(fd); err != nil { + return nil, err + } + return m.getField(fd) +} + +// GetFieldByName returns the value for the field with the given name. It panics +// if an error is encountered. See TryGetFieldByName. +func (m *Message) GetFieldByName(name string) interface{} { + if v, err := m.TryGetFieldByName(name); err != nil { + panic(err.Error()) + } else { + return v + } +} + +// TryGetFieldByName returns the value for the field with the given name. An +// error is returned if the given name is unknown. If the given name refers to +// an extension field, it should be fully qualified and optionally enclosed in +// parenthesis or brackets. +// +// If this message has no value for the given field, its default value is +// returned. (See TryGetField for more info on types and default field values.) +func (m *Message) TryGetFieldByName(name string) (interface{}, error) { + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return nil, UnknownFieldNameError + } + return m.getField(fd) +} + +// GetFieldByNumber returns the value for the field with the given tag number. +// It panics if an error is encountered. See TryGetFieldByNumber. +func (m *Message) GetFieldByNumber(tagNumber int) interface{} { + if v, err := m.TryGetFieldByNumber(tagNumber); err != nil { + panic(err.Error()) + } else { + return v + } +} + +// TryGetFieldByNumber returns the value for the field with the given tag +// number. An error is returned if the given tag is unknown. +// +// If this message has no value for the given field, its default value is +// returned. (See TryGetField for more info on types and default field values.) +func (m *Message) TryGetFieldByNumber(tagNumber int) (interface{}, error) { + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return nil, UnknownTagNumberError + } + return m.getField(fd) +} + +func (m *Message) getField(fd *desc.FieldDescriptor) (interface{}, error) { + return m.doGetField(fd, false) +} + +func (m *Message) doGetField(fd *desc.FieldDescriptor, nilIfAbsent bool) (interface{}, error) { + res := m.values[fd.GetNumber()] + if res == nil { + var err error + if res, err = m.parseUnknownField(fd); err != nil { + return nil, err + } + if res == nil { + if nilIfAbsent { + return nil, nil + } else { + def := fd.GetDefaultValue() + if def != nil { + return def, nil + } + // GetDefaultValue only returns nil for message types + md := fd.GetMessageType() + if m.md.IsProto3() { + return nilMessage(md), nil + } else { + // for proto2, return default instance of message + return m.mf.NewMessage(md), nil + } + } + } + } + rt := reflect.TypeOf(res) + if rt.Kind() == reflect.Map { + // make defensive copies to prevent caller from storing illegal keys and values + m := res.(map[interface{}]interface{}) + res := map[interface{}]interface{}{} + for k, v := range m { + res[k] = v + } + return res, nil + } else if rt.Kind() == reflect.Slice && rt != typeOfBytes { + // make defensive copies to prevent caller from storing illegal elements + sl := res.([]interface{}) + res := make([]interface{}, len(sl)) + copy(res, sl) + return res, nil + } + return res, nil +} + +func nilMessage(md *desc.MessageDescriptor) interface{} { + // try to return a proper nil pointer + msgType := proto.MessageType(md.GetFullyQualifiedName()) + if msgType != nil && msgType.Implements(typeOfProtoMessage) { + return reflect.Zero(msgType).Interface().(proto.Message) + } + // fallback to nil dynamic message pointer + return (*Message)(nil) +} + +// HasField returns true if this message has a value for the given field. If the +// given field is not valid (e.g. belongs to a different message type), false is +// returned. If this message is defined in a file with "proto3" syntax, this +// will return false even if a field was explicitly assigned its zero value (the +// zero values for a field are intentionally indistinguishable from absent). +func (m *Message) HasField(fd *desc.FieldDescriptor) bool { + if err := m.checkField(fd); err != nil { + return false + } + return m.HasFieldNumber(int(fd.GetNumber())) +} + +// HasFieldName returns true if this message has a value for a field with the +// given name. If the given name is unknown, this returns false. +func (m *Message) HasFieldName(name string) bool { + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return false + } + return m.HasFieldNumber(int(fd.GetNumber())) +} + +// HasFieldNumber returns true if this message has a value for a field with the +// given tag number. If the given tag is unknown, this returns false. +func (m *Message) HasFieldNumber(tagNumber int) bool { + if _, ok := m.values[int32(tagNumber)]; ok { + return true + } + _, ok := m.unknownFields[int32(tagNumber)] + return ok +} + +// SetField sets the value for the given field descriptor to the given value. It +// panics if an error is encountered. See TrySetField. +func (m *Message) SetField(fd *desc.FieldDescriptor, val interface{}) { + if err := m.TrySetField(fd, val); err != nil { + panic(err.Error()) + } +} + +// TrySetField sets the value for the given field descriptor to the given value. +// An error is returned if the given field descriptor does not belong to the +// right message type or if the given value is not a correct/compatible type for +// the given field. +// +// The Go type expected for a field is the same as TryGetField would return for +// the field. So message values can be supplied as either the correct generated +// message type or as a *dynamic.Message. +// +// Since it is cumbersome to work with dynamic messages, some concessions are +// made to simplify usage regarding types: +// +// 1. If a numeric type is provided that can be converted *without loss or +// overflow*, it is accepted. This allows for setting int64 fields using int +// or int32 values. Similarly for uint64 with uint and uint32 values and for +// float64 fields with float32 values. +// 2. The value can be a named type, as long as its underlying type is correct. +// 3. Map and repeated fields can be set using any kind of concrete map or +// slice type, as long as the values within are all of the correct type. So +// a field defined as a 'map<string, int32>` can be set using a +// map[string]int32, a map[string]interface{}, or even a +// map[interface{}]interface{}. +// 4. Finally, dynamic code that chooses to not treat maps as a special-case +// find that they can set map fields using a slice where each element is a +// message that matches the implicit map-entry field message type. +// +// If the given field descriptor is not known (e.g. not present in the message +// descriptor) it will become known. Subsequent operations using tag numbers or +// names will be able to resolve the newly-known type. If the message has a +// value for the unknown value, it is cleared, replaced by the given known +// value. +func (m *Message) TrySetField(fd *desc.FieldDescriptor, val interface{}) error { + if err := m.checkField(fd); err != nil { + return err + } + return m.setField(fd, val) +} + +// SetFieldByName sets the value for the field with the given name to the given +// value. It panics if an error is encountered. See TrySetFieldByName. +func (m *Message) SetFieldByName(name string, val interface{}) { + if err := m.TrySetFieldByName(name, val); err != nil { + panic(err.Error()) + } +} + +// TrySetFieldByName sets the value for the field with the given name to the +// given value. An error is returned if the given name is unknown or if the +// given value has an incorrect type. If the given name refers to an extension +// field, it should be fully qualified and optionally enclosed in parenthesis or +// brackets. +// +// (See TrySetField for more info on types.) +func (m *Message) TrySetFieldByName(name string, val interface{}) error { + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return UnknownFieldNameError + } + return m.setField(fd, val) +} + +// SetFieldByNumber sets the value for the field with the given tag number to +// the given value. It panics if an error is encountered. See +// TrySetFieldByNumber. +func (m *Message) SetFieldByNumber(tagNumber int, val interface{}) { + if err := m.TrySetFieldByNumber(tagNumber, val); err != nil { + panic(err.Error()) + } +} + +// TrySetFieldByNumber sets the value for the field with the given tag number to +// the given value. An error is returned if the given tag is unknown or if the +// given value has an incorrect type. +// +// (See TrySetField for more info on types.) +func (m *Message) TrySetFieldByNumber(tagNumber int, val interface{}) error { + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return UnknownTagNumberError + } + return m.setField(fd, val) +} + +func (m *Message) setField(fd *desc.FieldDescriptor, val interface{}) error { + var err error + if val, err = validFieldValue(fd, val); err != nil { + return err + } + m.internalSetField(fd, val) + return nil +} + +func (m *Message) internalSetField(fd *desc.FieldDescriptor, val interface{}) { + if fd.IsRepeated() { + // Unset fields and zero-length fields are indistinguishable, in both + // proto2 and proto3 syntax + if reflect.ValueOf(val).Len() == 0 { + if m.values != nil { + delete(m.values, fd.GetNumber()) + } + return + } + } else if m.md.IsProto3() && fd.GetOneOf() == nil { + // proto3 considers fields that are set to their zero value as unset + // (we already handled repeated fields above) + var equal bool + if b, ok := val.([]byte); ok { + // can't compare slices, so we have to special-case []byte values + equal = ok && bytes.Equal(b, fd.GetDefaultValue().([]byte)) + } else { + defVal := fd.GetDefaultValue() + equal = defVal == val + if !equal && defVal == nil { + // above just checks if value is the nil interface, + // but we should also test if the given value is a + // nil pointer + rv := reflect.ValueOf(val) + if rv.Kind() == reflect.Ptr && rv.IsNil() { + equal = true + } + } + } + if equal { + if m.values != nil { + delete(m.values, fd.GetNumber()) + } + return + } + } + if m.values == nil { + m.values = map[int32]interface{}{} + } + m.values[fd.GetNumber()] = val + // if this field is part of a one-of, make sure all other one-of choices are cleared + od := fd.GetOneOf() + if od != nil { + for _, other := range od.GetChoices() { + if other.GetNumber() != fd.GetNumber() { + delete(m.values, other.GetNumber()) + } + } + } + // also clear any unknown fields + if m.unknownFields != nil { + delete(m.unknownFields, fd.GetNumber()) + } + // and add this field if it was previously unknown + if existing := m.FindFieldDescriptor(fd.GetNumber()); existing == nil { + m.addField(fd) + } +} + +func (m *Message) addField(fd *desc.FieldDescriptor) { + if m.extraFields == nil { + m.extraFields = map[int32]*desc.FieldDescriptor{} + } + m.extraFields[fd.GetNumber()] = fd +} + +// ClearField removes any value for the given field. It panics if an error is +// encountered. See TryClearField. +func (m *Message) ClearField(fd *desc.FieldDescriptor) { + if err := m.TryClearField(fd); err != nil { + panic(err.Error()) + } +} + +// TryClearField removes any value for the given field. An error is returned if +// the given field descriptor does not belong to the right message type. +func (m *Message) TryClearField(fd *desc.FieldDescriptor) error { + if err := m.checkField(fd); err != nil { + return err + } + m.clearField(fd) + return nil +} + +// ClearFieldByName removes any value for the field with the given name. It +// panics if an error is encountered. See TryClearFieldByName. +func (m *Message) ClearFieldByName(name string) { + if err := m.TryClearFieldByName(name); err != nil { + panic(err.Error()) + } +} + +// TryClearFieldByName removes any value for the field with the given name. An +// error is returned if the given name is unknown. If the given name refers to +// an extension field, it should be fully qualified and optionally enclosed in +// parenthesis or brackets. +func (m *Message) TryClearFieldByName(name string) error { + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return UnknownFieldNameError + } + m.clearField(fd) + return nil +} + +// ClearFieldByNumber removes any value for the field with the given tag number. +// It panics if an error is encountered. See TryClearFieldByNumber. +func (m *Message) ClearFieldByNumber(tagNumber int) { + if err := m.TryClearFieldByNumber(tagNumber); err != nil { + panic(err.Error()) + } +} + +// TryClearFieldByNumber removes any value for the field with the given tag +// number. An error is returned if the given tag is unknown. +func (m *Message) TryClearFieldByNumber(tagNumber int) error { + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return UnknownTagNumberError + } + m.clearField(fd) + return nil +} + +func (m *Message) clearField(fd *desc.FieldDescriptor) { + // clear value + if m.values != nil { + delete(m.values, fd.GetNumber()) + } + // also clear any unknown fields + if m.unknownFields != nil { + delete(m.unknownFields, fd.GetNumber()) + } + // and add this field if it was previously unknown + if existing := m.FindFieldDescriptor(fd.GetNumber()); existing == nil { + m.addField(fd) + } +} + +// GetOneOfField returns which of the given one-of's fields is set and the +// corresponding value. It panics if an error is encountered. See +// TryGetOneOfField. +func (m *Message) GetOneOfField(od *desc.OneOfDescriptor) (*desc.FieldDescriptor, interface{}) { + if fd, val, err := m.TryGetOneOfField(od); err != nil { + panic(err.Error()) + } else { + return fd, val + } +} + +// TryGetOneOfField returns which of the given one-of's fields is set and the +// corresponding value. An error is returned if the given one-of belongs to the +// wrong message type. If the given one-of has no field set, this method will +// return nil, nil. +// +// The type of the value, if one is set, is the same as would be returned by +// TryGetField using the returned field descriptor. +// +// Like with TryGetField, if the given one-of contains any fields that are not +// known (e.g. not present in this message's descriptor), they will become known +// and any unknown value will be parsed (and become a known value on success). +func (m *Message) TryGetOneOfField(od *desc.OneOfDescriptor) (*desc.FieldDescriptor, interface{}, error) { + if od.GetOwner().GetFullyQualifiedName() != m.md.GetFullyQualifiedName() { + return nil, nil, fmt.Errorf("given one-of, %s, is for wrong message type: %s; expecting %s", od.GetName(), od.GetOwner().GetFullyQualifiedName(), m.md.GetFullyQualifiedName()) + } + for _, fd := range od.GetChoices() { + val, err := m.doGetField(fd, true) + if err != nil { + return nil, nil, err + } + if val != nil { + return fd, val, nil + } + } + return nil, nil, nil +} + +// ClearOneOfField removes any value for any of the given one-of's fields. It +// panics if an error is encountered. See TryClearOneOfField. +func (m *Message) ClearOneOfField(od *desc.OneOfDescriptor) { + if err := m.TryClearOneOfField(od); err != nil { + panic(err.Error()) + } +} + +// TryClearOneOfField removes any value for any of the given one-of's fields. An +// error is returned if the given one-of descriptor does not belong to the right +// message type. +func (m *Message) TryClearOneOfField(od *desc.OneOfDescriptor) error { + if od.GetOwner().GetFullyQualifiedName() != m.md.GetFullyQualifiedName() { + return fmt.Errorf("given one-of, %s, is for wrong message type: %s; expecting %s", od.GetName(), od.GetOwner().GetFullyQualifiedName(), m.md.GetFullyQualifiedName()) + } + for _, fd := range od.GetChoices() { + m.clearField(fd) + } + return nil +} + +// GetMapField returns the value for the given map field descriptor and given +// key. It panics if an error is encountered. See TryGetMapField. +func (m *Message) GetMapField(fd *desc.FieldDescriptor, key interface{}) interface{} { + if v, err := m.TryGetMapField(fd, key); err != nil { + panic(err.Error()) + } else { + return v + } +} + +// TryGetMapField returns the value for the given map field descriptor and given +// key. An error is returned if the given field descriptor does not belong to +// the right message type or if it is not a map field. +// +// If the map field does not contain the requested key, this method returns +// nil, nil. The Go type of the value returned mirrors the type that protoc +// would generate for the field. (See TryGetField for more details on types). +// +// If the given field descriptor is not known (e.g. not present in the message +// descriptor) but corresponds to an unknown field, the unknown value will be +// parsed and become known. The parsed value will be searched for the requested +// key and any value returned. An error will be returned if the unknown value +// cannot be parsed according to the field descriptor's type information. +func (m *Message) TryGetMapField(fd *desc.FieldDescriptor, key interface{}) (interface{}, error) { + if err := m.checkField(fd); err != nil { + return nil, err + } + return m.getMapField(fd, key) +} + +// GetMapFieldByName returns the value for the map field with the given name and +// given key. It panics if an error is encountered. See TryGetMapFieldByName. +func (m *Message) GetMapFieldByName(name string, key interface{}) interface{} { + if v, err := m.TryGetMapFieldByName(name, key); err != nil { + panic(err.Error()) + } else { + return v + } +} + +// TryGetMapFieldByName returns the value for the map field with the given name +// and given key. An error is returned if the given name is unknown or if it +// names a field that is not a map field. +// +// If this message has no value for the given field or the value has no value +// for the requested key, then this method returns nil, nil. +// +// (See TryGetField for more info on types.) +func (m *Message) TryGetMapFieldByName(name string, key interface{}) (interface{}, error) { + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return nil, UnknownFieldNameError + } + return m.getMapField(fd, key) +} + +// GetMapFieldByNumber returns the value for the map field with the given tag +// number and given key. It panics if an error is encountered. See +// TryGetMapFieldByNumber. +func (m *Message) GetMapFieldByNumber(tagNumber int, key interface{}) interface{} { + if v, err := m.TryGetMapFieldByNumber(tagNumber, key); err != nil { + panic(err.Error()) + } else { + return v + } +} + +// TryGetMapFieldByNumber returns the value for the map field with the given tag +// number and given key. An error is returned if the given tag is unknown or if +// it indicates a field that is not a map field. +// +// If this message has no value for the given field or the value has no value +// for the requested key, then this method returns nil, nil. +// +// (See TryGetField for more info on types.) +func (m *Message) TryGetMapFieldByNumber(tagNumber int, key interface{}) (interface{}, error) { + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return nil, UnknownTagNumberError + } + return m.getMapField(fd, key) +} + +func (m *Message) getMapField(fd *desc.FieldDescriptor, key interface{}) (interface{}, error) { + if !fd.IsMap() { + return nil, FieldIsNotMapError + } + kfd := fd.GetMessageType().GetFields()[0] + ki, err := validElementFieldValue(kfd, key, false) + if err != nil { + return nil, err + } + mp := m.values[fd.GetNumber()] + if mp == nil { + if mp, err = m.parseUnknownField(fd); err != nil { + return nil, err + } else if mp == nil { + return nil, nil + } + } + return mp.(map[interface{}]interface{})[ki], nil +} + +// ForEachMapFieldEntry executes the given function for each entry in the map +// value for the given field descriptor. It stops iteration if the function +// returns false. It panics if an error is encountered. See +// TryForEachMapFieldEntry. +func (m *Message) ForEachMapFieldEntry(fd *desc.FieldDescriptor, fn func(key, val interface{}) bool) { + if err := m.TryForEachMapFieldEntry(fd, fn); err != nil { + panic(err.Error()) + } +} + +// TryForEachMapFieldEntry executes the given function for each entry in the map +// value for the given field descriptor. An error is returned if the given field +// descriptor does not belong to the right message type or if it is not a map +// field. +// +// Iteration ends either when all entries have been examined or when the given +// function returns false. So the function is expected to return true for normal +// iteration and false to break out. If this message has no value for the given +// field, it returns without invoking the given function. +// +// The Go type of the key and value supplied to the function mirrors the type +// that protoc would generate for the field. (See TryGetField for more details +// on types). +// +// If the given field descriptor is not known (e.g. not present in the message +// descriptor) but corresponds to an unknown field, the unknown value will be +// parsed and become known. The parsed value will be searched for the requested +// key and any value returned. An error will be returned if the unknown value +// cannot be parsed according to the field descriptor's type information. +func (m *Message) TryForEachMapFieldEntry(fd *desc.FieldDescriptor, fn func(key, val interface{}) bool) error { + if err := m.checkField(fd); err != nil { + return err + } + return m.forEachMapFieldEntry(fd, fn) +} + +// ForEachMapFieldEntryByName executes the given function for each entry in the +// map value for the field with the given name. It stops iteration if the +// function returns false. It panics if an error is encountered. See +// TryForEachMapFieldEntryByName. +func (m *Message) ForEachMapFieldEntryByName(name string, fn func(key, val interface{}) bool) { + if err := m.TryForEachMapFieldEntryByName(name, fn); err != nil { + panic(err.Error()) + } +} + +// TryForEachMapFieldEntryByName executes the given function for each entry in +// the map value for the field with the given name. It stops iteration if the +// function returns false. An error is returned if the given name is unknown or +// if it names a field that is not a map field. +// +// If this message has no value for the given field, it returns without ever +// invoking the given function. +// +// (See TryGetField for more info on types supplied to the function.) +func (m *Message) TryForEachMapFieldEntryByName(name string, fn func(key, val interface{}) bool) error { + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return UnknownFieldNameError + } + return m.forEachMapFieldEntry(fd, fn) +} + +// ForEachMapFieldEntryByNumber executes the given function for each entry in +// the map value for the field with the given tag number. It stops iteration if +// the function returns false. It panics if an error is encountered. See +// TryForEachMapFieldEntryByNumber. +func (m *Message) ForEachMapFieldEntryByNumber(tagNumber int, fn func(key, val interface{}) bool) { + if err := m.TryForEachMapFieldEntryByNumber(tagNumber, fn); err != nil { + panic(err.Error()) + } +} + +// TryForEachMapFieldEntryByNumber executes the given function for each entry in +// the map value for the field with the given tag number. It stops iteration if +// the function returns false. An error is returned if the given tag is unknown +// or if it indicates a field that is not a map field. +// +// If this message has no value for the given field, it returns without ever +// invoking the given function. +// +// (See TryGetField for more info on types supplied to the function.) +func (m *Message) TryForEachMapFieldEntryByNumber(tagNumber int, fn func(key, val interface{}) bool) error { + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return UnknownTagNumberError + } + return m.forEachMapFieldEntry(fd, fn) +} + +func (m *Message) forEachMapFieldEntry(fd *desc.FieldDescriptor, fn func(key, val interface{}) bool) error { + if !fd.IsMap() { + return FieldIsNotMapError + } + mp := m.values[fd.GetNumber()] + if mp == nil { + if mp, err := m.parseUnknownField(fd); err != nil { + return err + } else if mp == nil { + return nil + } + } + for k, v := range mp.(map[interface{}]interface{}) { + if !fn(k, v) { + break + } + } + return nil +} + +// PutMapField sets the value for the given map field descriptor and given key +// to the given value. It panics if an error is encountered. See TryPutMapField. +func (m *Message) PutMapField(fd *desc.FieldDescriptor, key interface{}, val interface{}) { + if err := m.TryPutMapField(fd, key, val); err != nil { + panic(err.Error()) + } +} + +// TryPutMapField sets the value for the given map field descriptor and given +// key to the given value. An error is returned if the given field descriptor +// does not belong to the right message type, if the given field is not a map +// field, or if the given value is not a correct/compatible type for the given +// field. +// +// The Go type expected for a field is the same as required by TrySetField for +// a field with the same type as the map's value type. +// +// If the given field descriptor is not known (e.g. not present in the message +// descriptor) it will become known. Subsequent operations using tag numbers or +// names will be able to resolve the newly-known type. If the message has a +// value for the unknown value, it is cleared, replaced by the given known +// value. +func (m *Message) TryPutMapField(fd *desc.FieldDescriptor, key interface{}, val interface{}) error { + if err := m.checkField(fd); err != nil { + return err + } + return m.putMapField(fd, key, val) +} + +// PutMapFieldByName sets the value for the map field with the given name and +// given key to the given value. It panics if an error is encountered. See +// TryPutMapFieldByName. +func (m *Message) PutMapFieldByName(name string, key interface{}, val interface{}) { + if err := m.TryPutMapFieldByName(name, key, val); err != nil { + panic(err.Error()) + } +} + +// TryPutMapFieldByName sets the value for the map field with the given name and +// the given key to the given value. An error is returned if the given name is +// unknown, if it names a field that is not a map, or if the given value has an +// incorrect type. +// +// (See TrySetField for more info on types.) +func (m *Message) TryPutMapFieldByName(name string, key interface{}, val interface{}) error { + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return UnknownFieldNameError + } + return m.putMapField(fd, key, val) +} + +// PutMapFieldByNumber sets the value for the map field with the given tag +// number and given key to the given value. It panics if an error is +// encountered. See TryPutMapFieldByNumber. +func (m *Message) PutMapFieldByNumber(tagNumber int, key interface{}, val interface{}) { + if err := m.TryPutMapFieldByNumber(tagNumber, key, val); err != nil { + panic(err.Error()) + } +} + +// TryPutMapFieldByNumber sets the value for the map field with the given tag +// number and the given key to the given value. An error is returned if the +// given tag is unknown, if it indicates a field that is not a map, or if the +// given value has an incorrect type. +// +// (See TrySetField for more info on types.) +func (m *Message) TryPutMapFieldByNumber(tagNumber int, key interface{}, val interface{}) error { + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return UnknownTagNumberError + } + return m.putMapField(fd, key, val) +} + +func (m *Message) putMapField(fd *desc.FieldDescriptor, key interface{}, val interface{}) error { + if !fd.IsMap() { + return FieldIsNotMapError + } + kfd := fd.GetMessageType().GetFields()[0] + ki, err := validElementFieldValue(kfd, key, false) + if err != nil { + return err + } + vfd := fd.GetMessageType().GetFields()[1] + vi, err := validElementFieldValue(vfd, val, true) + if err != nil { + return err + } + mp := m.values[fd.GetNumber()] + if mp == nil { + if mp, err = m.parseUnknownField(fd); err != nil { + return err + } else if mp == nil { + m.internalSetField(fd, map[interface{}]interface{}{ki: vi}) + return nil + } + } + mp.(map[interface{}]interface{})[ki] = vi + return nil +} + +// RemoveMapField changes the value for the given field descriptor by removing +// any value associated with the given key. It panics if an error is +// encountered. See TryRemoveMapField. +func (m *Message) RemoveMapField(fd *desc.FieldDescriptor, key interface{}) { + if err := m.TryRemoveMapField(fd, key); err != nil { + panic(err.Error()) + } +} + +// TryRemoveMapField changes the value for the given field descriptor by +// removing any value associated with the given key. An error is returned if the +// given field descriptor does not belong to the right message type or if the +// given field is not a map field. +// +// If the given field descriptor is not known (e.g. not present in the message +// descriptor) it will become known. Subsequent operations using tag numbers or +// names will be able to resolve the newly-known type. If the message has a +// value for the unknown value, it is parsed and any value for the given key +// removed. +func (m *Message) TryRemoveMapField(fd *desc.FieldDescriptor, key interface{}) error { + if err := m.checkField(fd); err != nil { + return err + } + return m.removeMapField(fd, key) +} + +// RemoveMapFieldByName changes the value for the field with the given name by +// removing any value associated with the given key. It panics if an error is +// encountered. See TryRemoveMapFieldByName. +func (m *Message) RemoveMapFieldByName(name string, key interface{}) { + if err := m.TryRemoveMapFieldByName(name, key); err != nil { + panic(err.Error()) + } +} + +// TryRemoveMapFieldByName changes the value for the field with the given name +// by removing any value associated with the given key. An error is returned if +// the given name is unknown or if it names a field that is not a map. +func (m *Message) TryRemoveMapFieldByName(name string, key interface{}) error { + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return UnknownFieldNameError + } + return m.removeMapField(fd, key) +} + +// RemoveMapFieldByNumber changes the value for the field with the given tag +// number by removing any value associated with the given key. It panics if an +// error is encountered. See TryRemoveMapFieldByNumber. +func (m *Message) RemoveMapFieldByNumber(tagNumber int, key interface{}) { + if err := m.TryRemoveMapFieldByNumber(tagNumber, key); err != nil { + panic(err.Error()) + } +} + +// TryRemoveMapFieldByNumber changes the value for the field with the given tag +// number by removing any value associated with the given key. An error is +// returned if the given tag is unknown or if it indicates a field that is not +// a map. +func (m *Message) TryRemoveMapFieldByNumber(tagNumber int, key interface{}) error { + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return UnknownTagNumberError + } + return m.removeMapField(fd, key) +} + +func (m *Message) removeMapField(fd *desc.FieldDescriptor, key interface{}) error { + if !fd.IsMap() { + return FieldIsNotMapError + } + kfd := fd.GetMessageType().GetFields()[0] + ki, err := validElementFieldValue(kfd, key, false) + if err != nil { + return err + } + mp := m.values[fd.GetNumber()] + if mp == nil { + if mp, err = m.parseUnknownField(fd); err != nil { + return err + } else if mp == nil { + return nil + } + } + res := mp.(map[interface{}]interface{}) + delete(res, ki) + if len(res) == 0 { + delete(m.values, fd.GetNumber()) + } + return nil +} + +// FieldLength returns the number of elements in this message for the given +// field descriptor. It panics if an error is encountered. See TryFieldLength. +func (m *Message) FieldLength(fd *desc.FieldDescriptor) int { + l, err := m.TryFieldLength(fd) + if err != nil { + panic(err.Error()) + } + return l +} + +// TryFieldLength returns the number of elements in this message for the given +// field descriptor. An error is returned if the given field descriptor does not +// belong to the right message type or if it is neither a map field nor a +// repeated field. +func (m *Message) TryFieldLength(fd *desc.FieldDescriptor) (int, error) { + if err := m.checkField(fd); err != nil { + return 0, err + } + return m.fieldLength(fd) +} + +// FieldLengthByName returns the number of elements in this message for the +// field with the given name. It panics if an error is encountered. See +// TryFieldLengthByName. +func (m *Message) FieldLengthByName(name string) int { + l, err := m.TryFieldLengthByName(name) + if err != nil { + panic(err.Error()) + } + return l +} + +// TryFieldLengthByName returns the number of elements in this message for the +// field with the given name. An error is returned if the given name is unknown +// or if the named field is neither a map field nor a repeated field. +func (m *Message) TryFieldLengthByName(name string) (int, error) { + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return 0, UnknownFieldNameError + } + return m.fieldLength(fd) +} + +// FieldLengthByNumber returns the number of elements in this message for the +// field with the given tag number. It panics if an error is encountered. See +// TryFieldLengthByNumber. +func (m *Message) FieldLengthByNumber(tagNumber int32) int { + l, err := m.TryFieldLengthByNumber(tagNumber) + if err != nil { + panic(err.Error()) + } + return l +} + +// TryFieldLengthByNumber returns the number of elements in this message for the +// field with the given tag number. An error is returned if the given tag is +// unknown or if the named field is neither a map field nor a repeated field. +func (m *Message) TryFieldLengthByNumber(tagNumber int32) (int, error) { + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return 0, UnknownTagNumberError + } + return m.fieldLength(fd) +} + +func (m *Message) fieldLength(fd *desc.FieldDescriptor) (int, error) { + if !fd.IsRepeated() { + return 0, FieldIsNotRepeatedError + } + val := m.values[fd.GetNumber()] + if val == nil { + var err error + if val, err = m.parseUnknownField(fd); err != nil { + return 0, err + } else if val == nil { + return 0, nil + } + } + if sl, ok := val.([]interface{}); ok { + return len(sl), nil + } else if mp, ok := val.(map[interface{}]interface{}); ok { + return len(mp), nil + } + return 0, nil +} + +// GetRepeatedField returns the value for the given repeated field descriptor at +// the given index. It panics if an error is encountered. See +// TryGetRepeatedField. +func (m *Message) GetRepeatedField(fd *desc.FieldDescriptor, index int) interface{} { + if v, err := m.TryGetRepeatedField(fd, index); err != nil { + panic(err.Error()) + } else { + return v + } +} + +// TryGetRepeatedField returns the value for the given repeated field descriptor +// at the given index. An error is returned if the given field descriptor does +// not belong to the right message type, if it is not a repeated field, or if +// the given index is out of range (less than zero or greater than or equal to +// the length of the repeated field). Also, even though map fields technically +// are repeated fields, if the given field is a map field an error will result: +// map representation does not lend itself to random access by index. +// +// The Go type of the value returned mirrors the type that protoc would generate +// for the field's element type. (See TryGetField for more details on types). +// +// If the given field descriptor is not known (e.g. not present in the message +// descriptor) but corresponds to an unknown field, the unknown value will be +// parsed and become known. The value at the given index in the parsed value +// will be returned. An error will be returned if the unknown value cannot be +// parsed according to the field descriptor's type information. +func (m *Message) TryGetRepeatedField(fd *desc.FieldDescriptor, index int) (interface{}, error) { + if index < 0 { + return nil, IndexOutOfRangeError + } + if err := m.checkField(fd); err != nil { + return nil, err + } + return m.getRepeatedField(fd, index) +} + +// GetRepeatedFieldByName returns the value for the repeated field with the +// given name at the given index. It panics if an error is encountered. See +// TryGetRepeatedFieldByName. +func (m *Message) GetRepeatedFieldByName(name string, index int) interface{} { + if v, err := m.TryGetRepeatedFieldByName(name, index); err != nil { + panic(err.Error()) + } else { + return v + } +} + +// TryGetRepeatedFieldByName returns the value for the repeated field with the +// given name at the given index. An error is returned if the given name is +// unknown, if it names a field that is not a repeated field (or is a map +// field), or if the given index is out of range (less than zero or greater +// than or equal to the length of the repeated field). +// +// (See TryGetField for more info on types.) +func (m *Message) TryGetRepeatedFieldByName(name string, index int) (interface{}, error) { + if index < 0 { + return nil, IndexOutOfRangeError + } + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return nil, UnknownFieldNameError + } + return m.getRepeatedField(fd, index) +} + +// GetRepeatedFieldByNumber returns the value for the repeated field with the +// given tag number at the given index. It panics if an error is encountered. +// See TryGetRepeatedFieldByNumber. +func (m *Message) GetRepeatedFieldByNumber(tagNumber int, index int) interface{} { + if v, err := m.TryGetRepeatedFieldByNumber(tagNumber, index); err != nil { + panic(err.Error()) + } else { + return v + } +} + +// TryGetRepeatedFieldByNumber returns the value for the repeated field with the +// given tag number at the given index. An error is returned if the given tag is +// unknown, if it indicates a field that is not a repeated field (or is a map +// field), or if the given index is out of range (less than zero or greater than +// or equal to the length of the repeated field). +// +// (See TryGetField for more info on types.) +func (m *Message) TryGetRepeatedFieldByNumber(tagNumber int, index int) (interface{}, error) { + if index < 0 { + return nil, IndexOutOfRangeError + } + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return nil, UnknownTagNumberError + } + return m.getRepeatedField(fd, index) +} + +func (m *Message) getRepeatedField(fd *desc.FieldDescriptor, index int) (interface{}, error) { + if fd.IsMap() || !fd.IsRepeated() { + return nil, FieldIsNotRepeatedError + } + sl := m.values[fd.GetNumber()] + if sl == nil { + var err error + if sl, err = m.parseUnknownField(fd); err != nil { + return nil, err + } else if sl == nil { + return nil, IndexOutOfRangeError + } + } + res := sl.([]interface{}) + if index >= len(res) { + return nil, IndexOutOfRangeError + } + return res[index], nil +} + +// AddRepeatedField appends the given value to the given repeated field. It +// panics if an error is encountered. See TryAddRepeatedField. +func (m *Message) AddRepeatedField(fd *desc.FieldDescriptor, val interface{}) { + if err := m.TryAddRepeatedField(fd, val); err != nil { + panic(err.Error()) + } +} + +// TryAddRepeatedField appends the given value to the given repeated field. An +// error is returned if the given field descriptor does not belong to the right +// message type, if the given field is not repeated, or if the given value is +// not a correct/compatible type for the given field. If the given field is a +// map field, the call will succeed if the given value is an instance of the +// map's entry message type. +// +// The Go type expected for a field is the same as required by TrySetField for +// a non-repeated field of the same type. +// +// If the given field descriptor is not known (e.g. not present in the message +// descriptor) it will become known. Subsequent operations using tag numbers or +// names will be able to resolve the newly-known type. If the message has a +// value for the unknown value, it is parsed and the given value is appended to +// it. +func (m *Message) TryAddRepeatedField(fd *desc.FieldDescriptor, val interface{}) error { + if err := m.checkField(fd); err != nil { + return err + } + return m.addRepeatedField(fd, val) +} + +// AddRepeatedFieldByName appends the given value to the repeated field with the +// given name. It panics if an error is encountered. See +// TryAddRepeatedFieldByName. +func (m *Message) AddRepeatedFieldByName(name string, val interface{}) { + if err := m.TryAddRepeatedFieldByName(name, val); err != nil { + panic(err.Error()) + } +} + +// TryAddRepeatedFieldByName appends the given value to the repeated field with +// the given name. An error is returned if the given name is unknown, if it +// names a field that is not repeated, or if the given value has an incorrect +// type. +// +// (See TrySetField for more info on types.) +func (m *Message) TryAddRepeatedFieldByName(name string, val interface{}) error { + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return UnknownFieldNameError + } + return m.addRepeatedField(fd, val) +} + +// AddRepeatedFieldByNumber appends the given value to the repeated field with +// the given tag number. It panics if an error is encountered. See +// TryAddRepeatedFieldByNumber. +func (m *Message) AddRepeatedFieldByNumber(tagNumber int, val interface{}) { + if err := m.TryAddRepeatedFieldByNumber(tagNumber, val); err != nil { + panic(err.Error()) + } +} + +// TryAddRepeatedFieldByNumber appends the given value to the repeated field +// with the given tag number. An error is returned if the given tag is unknown, +// if it indicates a field that is not repeated, or if the given value has an +// incorrect type. +// +// (See TrySetField for more info on types.) +func (m *Message) TryAddRepeatedFieldByNumber(tagNumber int, val interface{}) error { + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return UnknownTagNumberError + } + return m.addRepeatedField(fd, val) +} + +func (m *Message) addRepeatedField(fd *desc.FieldDescriptor, val interface{}) error { + if !fd.IsRepeated() { + return FieldIsNotRepeatedError + } + val, err := validElementFieldValue(fd, val, false) + if err != nil { + return err + } + + if fd.IsMap() { + // We're lenient. Just as we allow setting a map field to a slice of entry messages, we also allow + // adding entries one at a time (as if the field were a normal repeated field). + msg := val.(proto.Message) + dm, err := asDynamicMessage(msg, fd.GetMessageType(), m.mf) + if err != nil { + return err + } + k, err := dm.TryGetFieldByNumber(1) + if err != nil { + return err + } + v, err := dm.TryGetFieldByNumber(2) + if err != nil { + return err + } + return m.putMapField(fd, k, v) + } + + sl := m.values[fd.GetNumber()] + if sl == nil { + if sl, err = m.parseUnknownField(fd); err != nil { + return err + } else if sl == nil { + sl = []interface{}{} + } + } + res := sl.([]interface{}) + res = append(res, val) + m.internalSetField(fd, res) + return nil +} + +// SetRepeatedField sets the value for the given repeated field descriptor and +// given index to the given value. It panics if an error is encountered. See +// SetRepeatedField. +func (m *Message) SetRepeatedField(fd *desc.FieldDescriptor, index int, val interface{}) { + if err := m.TrySetRepeatedField(fd, index, val); err != nil { + panic(err.Error()) + } +} + +// TrySetRepeatedField sets the value for the given repeated field descriptor +// and given index to the given value. An error is returned if the given field +// descriptor does not belong to the right message type, if the given field is +// not repeated, or if the given value is not a correct/compatible type for the +// given field. Also, even though map fields technically are repeated fields, if +// the given field is a map field an error will result: map representation does +// not lend itself to random access by index. +// +// The Go type expected for a field is the same as required by TrySetField for +// a non-repeated field of the same type. +// +// If the given field descriptor is not known (e.g. not present in the message +// descriptor) it will become known. Subsequent operations using tag numbers or +// names will be able to resolve the newly-known type. If the message has a +// value for the unknown value, it is parsed and the element at the given index +// is replaced with the given value. +func (m *Message) TrySetRepeatedField(fd *desc.FieldDescriptor, index int, val interface{}) error { + if index < 0 { + return IndexOutOfRangeError + } + if err := m.checkField(fd); err != nil { + return err + } + return m.setRepeatedField(fd, index, val) +} + +// SetRepeatedFieldByName sets the value for the repeated field with the given +// name and given index to the given value. It panics if an error is +// encountered. See TrySetRepeatedFieldByName. +func (m *Message) SetRepeatedFieldByName(name string, index int, val interface{}) { + if err := m.TrySetRepeatedFieldByName(name, index, val); err != nil { + panic(err.Error()) + } +} + +// TrySetRepeatedFieldByName sets the value for the repeated field with the +// given name and the given index to the given value. An error is returned if +// the given name is unknown, if it names a field that is not repeated (or is a +// map field), or if the given value has an incorrect type. +// +// (See TrySetField for more info on types.) +func (m *Message) TrySetRepeatedFieldByName(name string, index int, val interface{}) error { + if index < 0 { + return IndexOutOfRangeError + } + fd := m.FindFieldDescriptorByName(name) + if fd == nil { + return UnknownFieldNameError + } + return m.setRepeatedField(fd, index, val) +} + +// SetRepeatedFieldByNumber sets the value for the repeated field with the given +// tag number and given index to the given value. It panics if an error is +// encountered. See TrySetRepeatedFieldByNumber. +func (m *Message) SetRepeatedFieldByNumber(tagNumber int, index int, val interface{}) { + if err := m.TrySetRepeatedFieldByNumber(tagNumber, index, val); err != nil { + panic(err.Error()) + } +} + +// TrySetRepeatedFieldByNumber sets the value for the repeated field with the +// given tag number and the given index to the given value. An error is returned +// if the given tag is unknown, if it indicates a field that is not repeated (or +// is a map field), or if the given value has an incorrect type. +// +// (See TrySetField for more info on types.) +func (m *Message) TrySetRepeatedFieldByNumber(tagNumber int, index int, val interface{}) error { + if index < 0 { + return IndexOutOfRangeError + } + fd := m.FindFieldDescriptor(int32(tagNumber)) + if fd == nil { + return UnknownTagNumberError + } + return m.setRepeatedField(fd, index, val) +} + +func (m *Message) setRepeatedField(fd *desc.FieldDescriptor, index int, val interface{}) error { + if fd.IsMap() || !fd.IsRepeated() { + return FieldIsNotRepeatedError + } + val, err := validElementFieldValue(fd, val, false) + if err != nil { + return err + } + sl := m.values[fd.GetNumber()] + if sl == nil { + if sl, err = m.parseUnknownField(fd); err != nil { + return err + } else if sl == nil { + return IndexOutOfRangeError + } + } + res := sl.([]interface{}) + if index >= len(res) { + return IndexOutOfRangeError + } + res[index] = val + return nil +} + +// GetUnknownField gets the value(s) for the given unknown tag number. If this +// message has no unknown fields with the given tag, nil is returned. +func (m *Message) GetUnknownField(tagNumber int32) []UnknownField { + if u, ok := m.unknownFields[tagNumber]; ok { + return u + } else { + return nil + } +} + +func (m *Message) parseUnknownField(fd *desc.FieldDescriptor) (interface{}, error) { + unks, ok := m.unknownFields[fd.GetNumber()] + if !ok { + return nil, nil + } + var v interface{} + var sl []interface{} + var mp map[interface{}]interface{} + if fd.IsMap() { + mp = map[interface{}]interface{}{} + } + var err error + for _, unk := range unks { + var val interface{} + if unk.Encoding == proto.WireBytes || unk.Encoding == proto.WireStartGroup { + val, err = codec.DecodeLengthDelimitedField(fd, unk.Contents, m.mf) + } else { + val, err = codec.DecodeScalarField(fd, unk.Value) + } + if err != nil { + return nil, err + } + if fd.IsMap() { + newEntry := val.(*Message) + kk, err := newEntry.TryGetFieldByNumber(1) + if err != nil { + return nil, err + } + vv, err := newEntry.TryGetFieldByNumber(2) + if err != nil { + return nil, err + } + mp[kk] = vv + v = mp + } else if fd.IsRepeated() { + t := reflect.TypeOf(val) + if t.Kind() == reflect.Slice && t != typeOfBytes { + // append slices if we unmarshalled a packed repeated field + newVals := val.([]interface{}) + sl = append(sl, newVals...) + } else { + sl = append(sl, val) + } + v = sl + } else { + v = val + } + } + m.internalSetField(fd, v) + return v, nil +} + +func validFieldValue(fd *desc.FieldDescriptor, val interface{}) (interface{}, error) { + return validFieldValueForRv(fd, reflect.ValueOf(val)) +} + +func validFieldValueForRv(fd *desc.FieldDescriptor, val reflect.Value) (interface{}, error) { + if fd.IsMap() && val.Kind() == reflect.Map { + return validFieldValueForMapField(fd, val) + } + + if fd.IsRepeated() { // this will also catch map fields where given value was not a map + if val.Kind() != reflect.Array && val.Kind() != reflect.Slice { + if fd.IsMap() { + return nil, fmt.Errorf("value for map field must be a map; instead was %v", val.Type()) + } else { + return nil, fmt.Errorf("value for repeated field must be a slice; instead was %v", val.Type()) + } + } + + if fd.IsMap() { + // value should be a slice of entry messages that we need convert into a map[interface{}]interface{} + m := map[interface{}]interface{}{} + for i := 0; i < val.Len(); i++ { + e, err := validElementFieldValue(fd, val.Index(i).Interface(), false) + if err != nil { + return nil, err + } + msg := e.(proto.Message) + dm, err := asDynamicMessage(msg, fd.GetMessageType(), nil) + if err != nil { + return nil, err + } + k, err := dm.TryGetFieldByNumber(1) + if err != nil { + return nil, err + } + v, err := dm.TryGetFieldByNumber(2) + if err != nil { + return nil, err + } + m[k] = v + } + return m, nil + } + + // make a defensive copy while checking contents (also converts to []interface{}) + s := make([]interface{}, val.Len()) + for i := 0; i < val.Len(); i++ { + ev := val.Index(i) + if ev.Kind() == reflect.Interface { + // unwrap it + ev = reflect.ValueOf(ev.Interface()) + } + e, err := validElementFieldValueForRv(fd, ev, false) + if err != nil { + return nil, err + } + s[i] = e + } + + return s, nil + } + + return validElementFieldValueForRv(fd, val, false) +} + +func asDynamicMessage(m proto.Message, md *desc.MessageDescriptor, mf *MessageFactory) (*Message, error) { + if dm, ok := m.(*Message); ok { + return dm, nil + } + dm := NewMessageWithMessageFactory(md, mf) + if err := dm.mergeFrom(m); err != nil { + return nil, err + } + return dm, nil +} + +func validElementFieldValue(fd *desc.FieldDescriptor, val interface{}, allowNilMessage bool) (interface{}, error) { + return validElementFieldValueForRv(fd, reflect.ValueOf(val), allowNilMessage) +} + +func validElementFieldValueForRv(fd *desc.FieldDescriptor, val reflect.Value, allowNilMessage bool) (interface{}, error) { + t := fd.GetType() + if !val.IsValid() { + return nil, typeError(fd, nil) + } + + switch t { + case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32, + descriptorpb.FieldDescriptorProto_TYPE_INT32, + descriptorpb.FieldDescriptorProto_TYPE_SINT32, + descriptorpb.FieldDescriptorProto_TYPE_ENUM: + return toInt32(reflect.Indirect(val), fd) + + case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64, + descriptorpb.FieldDescriptorProto_TYPE_INT64, + descriptorpb.FieldDescriptorProto_TYPE_SINT64: + return toInt64(reflect.Indirect(val), fd) + + case descriptorpb.FieldDescriptorProto_TYPE_FIXED32, + descriptorpb.FieldDescriptorProto_TYPE_UINT32: + return toUint32(reflect.Indirect(val), fd) + + case descriptorpb.FieldDescriptorProto_TYPE_FIXED64, + descriptorpb.FieldDescriptorProto_TYPE_UINT64: + return toUint64(reflect.Indirect(val), fd) + + case descriptorpb.FieldDescriptorProto_TYPE_FLOAT: + return toFloat32(reflect.Indirect(val), fd) + + case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: + return toFloat64(reflect.Indirect(val), fd) + + case descriptorpb.FieldDescriptorProto_TYPE_BOOL: + return toBool(reflect.Indirect(val), fd) + + case descriptorpb.FieldDescriptorProto_TYPE_BYTES: + return toBytes(reflect.Indirect(val), fd) + + case descriptorpb.FieldDescriptorProto_TYPE_STRING: + return toString(reflect.Indirect(val), fd) + + case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, + descriptorpb.FieldDescriptorProto_TYPE_GROUP: + m, err := asMessage(val, fd.GetFullyQualifiedName()) + // check that message is correct type + if err != nil { + return nil, err + } + var msgType string + if dm, ok := m.(*Message); ok { + if allowNilMessage && dm == nil { + // if dm == nil, we'll panic below, so early out if that is allowed + // (only allowed for map values, to indicate an entry w/ no value) + return m, nil + } + msgType = dm.GetMessageDescriptor().GetFullyQualifiedName() + } else { + msgType = proto.MessageName(m) + } + if msgType != fd.GetMessageType().GetFullyQualifiedName() { + return nil, fmt.Errorf("message field %s requires value of type %s; received %s", fd.GetFullyQualifiedName(), fd.GetMessageType().GetFullyQualifiedName(), msgType) + } + return m, nil + + default: + return nil, fmt.Errorf("unable to handle unrecognized field type: %v", fd.GetType()) + } +} + +func toInt32(v reflect.Value, fd *desc.FieldDescriptor) (int32, error) { + if v.Kind() == reflect.Int32 { + return int32(v.Int()), nil + } + return 0, typeError(fd, v.Type()) +} + +func toUint32(v reflect.Value, fd *desc.FieldDescriptor) (uint32, error) { + if v.Kind() == reflect.Uint32 { + return uint32(v.Uint()), nil + } + return 0, typeError(fd, v.Type()) +} + +func toFloat32(v reflect.Value, fd *desc.FieldDescriptor) (float32, error) { + if v.Kind() == reflect.Float32 { + return float32(v.Float()), nil + } + return 0, typeError(fd, v.Type()) +} + +func toInt64(v reflect.Value, fd *desc.FieldDescriptor) (int64, error) { + if v.Kind() == reflect.Int64 || v.Kind() == reflect.Int || v.Kind() == reflect.Int32 { + return v.Int(), nil + } + return 0, typeError(fd, v.Type()) +} + +func toUint64(v reflect.Value, fd *desc.FieldDescriptor) (uint64, error) { + if v.Kind() == reflect.Uint64 || v.Kind() == reflect.Uint || v.Kind() == reflect.Uint32 { + return v.Uint(), nil + } + return 0, typeError(fd, v.Type()) +} + +func toFloat64(v reflect.Value, fd *desc.FieldDescriptor) (float64, error) { + if v.Kind() == reflect.Float64 || v.Kind() == reflect.Float32 { + return v.Float(), nil + } + return 0, typeError(fd, v.Type()) +} + +func toBool(v reflect.Value, fd *desc.FieldDescriptor) (bool, error) { + if v.Kind() == reflect.Bool { + return v.Bool(), nil + } + return false, typeError(fd, v.Type()) +} + +func toBytes(v reflect.Value, fd *desc.FieldDescriptor) ([]byte, error) { + if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 { + return v.Bytes(), nil + } + return nil, typeError(fd, v.Type()) +} + +func toString(v reflect.Value, fd *desc.FieldDescriptor) (string, error) { + if v.Kind() == reflect.String { + return v.String(), nil + } + return "", typeError(fd, v.Type()) +} + +func typeError(fd *desc.FieldDescriptor, t reflect.Type) error { + return fmt.Errorf( + "%s field %s is not compatible with value of type %v", + getTypeString(fd), fd.GetFullyQualifiedName(), t) +} + +func getTypeString(fd *desc.FieldDescriptor) string { + return strings.ToLower(fd.GetType().String()) +} + +func asMessage(v reflect.Value, fieldName string) (proto.Message, error) { + t := v.Type() + // we need a pointer to a struct that implements proto.Message + if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct || !t.Implements(typeOfProtoMessage) { + return nil, fmt.Errorf("message field %s requires is not compatible with value of type %v", fieldName, v.Type()) + } + return v.Interface().(proto.Message), nil +} + +// Reset resets this message to an empty message. It removes all values set in +// the message. +func (m *Message) Reset() { + for k := range m.values { + delete(m.values, k) + } + for k := range m.unknownFields { + delete(m.unknownFields, k) + } +} + +// String returns this message rendered in compact text format. +func (m *Message) String() string { + b, err := m.MarshalText() + if err != nil { + panic(fmt.Sprintf("Failed to create string representation of message: %s", err.Error())) + } + return string(b) +} + +// ProtoMessage is present to satisfy the proto.Message interface. +func (m *Message) ProtoMessage() { +} + +// ConvertTo converts this dynamic message into the given message. This is +// shorthand for resetting then merging: +// +// target.Reset() +// m.MergeInto(target) +func (m *Message) ConvertTo(target proto.Message) error { + if err := m.checkType(target); err != nil { + return err + } + + target.Reset() + return m.mergeInto(target, defaultDeterminism) +} + +// ConvertToDeterministic converts this dynamic message into the given message. +// It is just like ConvertTo, but it attempts to produce deterministic results. +// That means that if the target is a generated message (not another dynamic +// message) and the current runtime is unaware of any fields or extensions that +// are present in m, they will be serialized into the target's unrecognized +// fields deterministically. +func (m *Message) ConvertToDeterministic(target proto.Message) error { + if err := m.checkType(target); err != nil { + return err + } + + target.Reset() + return m.mergeInto(target, true) +} + +// ConvertFrom converts the given message into this dynamic message. This is +// shorthand for resetting then merging: +// +// m.Reset() +// m.MergeFrom(target) +func (m *Message) ConvertFrom(target proto.Message) error { + if err := m.checkType(target); err != nil { + return err + } + + m.Reset() + return m.mergeFrom(target) +} + +// MergeInto merges this dynamic message into the given message. All field +// values in this message will be set on the given message. For map fields, +// entries are added to the given message (if the given message has existing +// values for like keys, they are overwritten). For slice fields, elements are +// added. +// +// If the given message has a different set of known fields, it is possible for +// some known fields in this message to be represented as unknown fields in the +// given message after merging, and vice versa. +func (m *Message) MergeInto(target proto.Message) error { + if err := m.checkType(target); err != nil { + return err + } + return m.mergeInto(target, defaultDeterminism) +} + +// MergeIntoDeterministic merges this dynamic message into the given message. +// It is just like MergeInto, but it attempts to produce deterministic results. +// That means that if the target is a generated message (not another dynamic +// message) and the current runtime is unaware of any fields or extensions that +// are present in m, they will be serialized into the target's unrecognized +// fields deterministically. +func (m *Message) MergeIntoDeterministic(target proto.Message) error { + if err := m.checkType(target); err != nil { + return err + } + return m.mergeInto(target, true) +} + +// MergeFrom merges the given message into this dynamic message. All field +// values in the given message will be set on this message. For map fields, +// entries are added to this message (if this message has existing values for +// like keys, they are overwritten). For slice fields, elements are added. +// +// If the given message has a different set of known fields, it is possible for +// some known fields in that message to be represented as unknown fields in this +// message after merging, and vice versa. +func (m *Message) MergeFrom(source proto.Message) error { + if err := m.checkType(source); err != nil { + return err + } + return m.mergeFrom(source) +} + +// Merge implements the proto.Merger interface so that dynamic messages are +// compatible with the proto.Merge function. It delegates to MergeFrom but will +// panic on error as the proto.Merger interface doesn't allow for returning an +// error. +// +// Unlike nearly all other methods, this method can work if this message's type +// is not defined (such as instantiating the message without using NewMessage). +// This is strictly so that dynamic message's are compatible with the +// proto.Clone function, which instantiates a new message via reflection (thus +// its message descriptor will not be set) and than calls Merge. +func (m *Message) Merge(source proto.Message) { + if m.md == nil { + // To support proto.Clone, initialize the descriptor from the source. + if dm, ok := source.(*Message); ok { + m.md = dm.md + // also make sure the clone uses the same message factory and + // extensions and also knows about the same extra fields (if any) + m.mf = dm.mf + m.er = dm.er + m.extraFields = dm.extraFields + } else if md, err := desc.LoadMessageDescriptorForMessage(source); err != nil { + panic(err.Error()) + } else { + m.md = md + } + } + + if err := m.MergeFrom(source); err != nil { + panic(err.Error()) + } +} + +func (m *Message) checkType(target proto.Message) error { + if dm, ok := target.(*Message); ok { + if dm.md.GetFullyQualifiedName() != m.md.GetFullyQualifiedName() { + return fmt.Errorf("given message has wrong type: %q; expecting %q", dm.md.GetFullyQualifiedName(), m.md.GetFullyQualifiedName()) + } + return nil + } + + msgName := proto.MessageName(target) + if msgName != m.md.GetFullyQualifiedName() { + return fmt.Errorf("given message has wrong type: %q; expecting %q", msgName, m.md.GetFullyQualifiedName()) + } + return nil +} + +func (m *Message) mergeInto(pm proto.Message, deterministic bool) error { + if dm, ok := pm.(*Message); ok { + return dm.mergeFrom(m) + } + + target := reflect.ValueOf(pm) + if target.Kind() == reflect.Ptr { + target = target.Elem() + } + + // track tags for which the dynamic message has data but the given + // message doesn't know about it + unknownTags := map[int32]struct{}{} + for tag := range m.values { + unknownTags[tag] = struct{}{} + } + + // check that we can successfully do the merge + structProps := proto.GetProperties(reflect.TypeOf(pm).Elem()) + for _, prop := range structProps.Prop { + if prop.Tag == 0 { + continue // one-of or special field (such as XXX_unrecognized, etc.) + } + tag := int32(prop.Tag) + v, ok := m.values[tag] + if !ok { + continue + } + if unknownTags != nil { + delete(unknownTags, tag) + } + f := target.FieldByName(prop.Name) + ft := f.Type() + val := reflect.ValueOf(v) + if !canConvert(val, ft) { + return fmt.Errorf("cannot convert %v to %v", val.Type(), ft) + } + } + // check one-of fields + for _, oop := range structProps.OneofTypes { + prop := oop.Prop + tag := int32(prop.Tag) + v, ok := m.values[tag] + if !ok { + continue + } + if unknownTags != nil { + delete(unknownTags, tag) + } + stf, ok := oop.Type.Elem().FieldByName(prop.Name) + if !ok { + return fmt.Errorf("one-of field indicates struct field name %s, but type %v has no such field", prop.Name, oop.Type.Elem()) + } + ft := stf.Type + val := reflect.ValueOf(v) + if !canConvert(val, ft) { + return fmt.Errorf("cannot convert %v to %v", val.Type(), ft) + } + } + // and check extensions, too + for tag, ext := range proto.RegisteredExtensions(pm) { + v, ok := m.values[tag] + if !ok { + continue + } + if unknownTags != nil { + delete(unknownTags, tag) + } + ft := reflect.TypeOf(ext.ExtensionType) + val := reflect.ValueOf(v) + if !canConvert(val, ft) { + return fmt.Errorf("cannot convert %v to %v", val.Type(), ft) + } + } + + // now actually perform the merge + for _, prop := range structProps.Prop { + v, ok := m.values[int32(prop.Tag)] + if !ok { + continue + } + f := target.FieldByName(prop.Name) + if err := mergeVal(reflect.ValueOf(v), f, deterministic); err != nil { + return err + } + } + // merge one-ofs + for _, oop := range structProps.OneofTypes { + prop := oop.Prop + tag := int32(prop.Tag) + v, ok := m.values[tag] + if !ok { + continue + } + oov := reflect.New(oop.Type.Elem()) + f := oov.Elem().FieldByName(prop.Name) + if err := mergeVal(reflect.ValueOf(v), f, deterministic); err != nil { + return err + } + target.Field(oop.Field).Set(oov) + } + // merge extensions, too + for tag, ext := range proto.RegisteredExtensions(pm) { + v, ok := m.values[tag] + if !ok { + continue + } + e := reflect.New(reflect.TypeOf(ext.ExtensionType)).Elem() + if err := mergeVal(reflect.ValueOf(v), e, deterministic); err != nil { + return err + } + if err := proto.SetExtension(pm, ext, e.Interface()); err != nil { + // shouldn't happen since we already checked that the extension type was compatible above + return err + } + } + + // if we have fields that the given message doesn't know about, add to its unknown fields + if len(unknownTags) > 0 { + var b codec.Buffer + b.SetDeterministic(deterministic) + if deterministic { + // if we need to emit things deterministically, sort the + // extensions by their tag number + sortedUnknownTags := make([]int32, 0, len(unknownTags)) + for tag := range unknownTags { + sortedUnknownTags = append(sortedUnknownTags, tag) + } + sort.Slice(sortedUnknownTags, func(i, j int) bool { + return sortedUnknownTags[i] < sortedUnknownTags[j] + }) + for _, tag := range sortedUnknownTags { + fd := m.FindFieldDescriptor(tag) + if err := b.EncodeFieldValue(fd, m.values[tag]); err != nil { + return err + } + } + } else { + for tag := range unknownTags { + fd := m.FindFieldDescriptor(tag) + if err := b.EncodeFieldValue(fd, m.values[tag]); err != nil { + return err + } + } + } + + internal.SetUnrecognized(pm, b.Bytes()) + } + + // finally, convey unknown fields into the given message by letting it unmarshal them + // (this will append to its unknown fields if not known; if somehow the given message recognizes + // a field even though the dynamic message did not, it will get correctly unmarshalled) + if unknownTags != nil && len(m.unknownFields) > 0 { + var b codec.Buffer + _ = m.marshalUnknownFields(&b) + _ = proto.UnmarshalMerge(b.Bytes(), pm) + } + + return nil +} + +func canConvert(src reflect.Value, target reflect.Type) bool { + if src.Kind() == reflect.Interface { + src = reflect.ValueOf(src.Interface()) + } + srcType := src.Type() + // we allow convertible types instead of requiring exact types so that calling + // code can, for example, assign an enum constant to an enum field. In that case, + // one type is the enum type (a sub-type of int32) and the other may be the int32 + // type. So we automatically do the conversion in that case. + if srcType.ConvertibleTo(target) { + return true + } else if target.Kind() == reflect.Ptr && srcType.ConvertibleTo(target.Elem()) { + return true + } else if target.Kind() == reflect.Slice { + if srcType.Kind() != reflect.Slice { + return false + } + et := target.Elem() + for i := 0; i < src.Len(); i++ { + if !canConvert(src.Index(i), et) { + return false + } + } + return true + } else if target.Kind() == reflect.Map { + if srcType.Kind() != reflect.Map { + return false + } + return canConvertMap(src, target) + } else if srcType == typeOfDynamicMessage && target.Implements(typeOfProtoMessage) { + z := reflect.Zero(target).Interface() + msgType := proto.MessageName(z.(proto.Message)) + return msgType == src.Interface().(*Message).GetMessageDescriptor().GetFullyQualifiedName() + } else { + return false + } +} + +func mergeVal(src, target reflect.Value, deterministic bool) error { + if src.Kind() == reflect.Interface && !src.IsNil() { + src = src.Elem() + } + srcType := src.Type() + targetType := target.Type() + if srcType.ConvertibleTo(targetType) { + if targetType.Implements(typeOfProtoMessage) && !target.IsNil() { + Merge(target.Interface().(proto.Message), src.Convert(targetType).Interface().(proto.Message)) + } else { + target.Set(src.Convert(targetType)) + } + } else if targetType.Kind() == reflect.Ptr && srcType.ConvertibleTo(targetType.Elem()) { + if !src.CanAddr() { + target.Set(reflect.New(targetType.Elem())) + target.Elem().Set(src.Convert(targetType.Elem())) + } else { + target.Set(src.Addr().Convert(targetType)) + } + } else if targetType.Kind() == reflect.Slice { + l := target.Len() + newL := l + src.Len() + if target.Cap() < newL { + // expand capacity of the slice and copy + newSl := reflect.MakeSlice(targetType, newL, newL) + for i := 0; i < target.Len(); i++ { + newSl.Index(i).Set(target.Index(i)) + } + target.Set(newSl) + } else { + target.SetLen(newL) + } + for i := 0; i < src.Len(); i++ { + dest := target.Index(l + i) + if dest.Kind() == reflect.Ptr { + dest.Set(reflect.New(dest.Type().Elem())) + } + if err := mergeVal(src.Index(i), dest, deterministic); err != nil { + return err + } + } + } else if targetType.Kind() == reflect.Map { + return mergeMapVal(src, target, targetType, deterministic) + } else if srcType == typeOfDynamicMessage && targetType.Implements(typeOfProtoMessage) { + dm := src.Interface().(*Message) + if target.IsNil() { + target.Set(reflect.New(targetType.Elem())) + } + m := target.Interface().(proto.Message) + if err := dm.mergeInto(m, deterministic); err != nil { + return err + } + } else { + return fmt.Errorf("cannot convert %v to %v", srcType, targetType) + } + return nil +} + +func (m *Message) mergeFrom(pm proto.Message) error { + if dm, ok := pm.(*Message); ok { + // if given message is also a dynamic message, we merge differently + for tag, v := range dm.values { + fd := m.FindFieldDescriptor(tag) + if fd == nil { + fd = dm.FindFieldDescriptor(tag) + } + if err := mergeField(m, fd, v); err != nil { + return err + } + } + return nil + } + + pmrv := reflect.ValueOf(pm) + if pmrv.IsNil() { + // nil is an empty message, so nothing to do + return nil + } + + // check that we can successfully do the merge + src := pmrv.Elem() + values := map[*desc.FieldDescriptor]interface{}{} + props := proto.GetProperties(reflect.TypeOf(pm).Elem()) + if props == nil { + return fmt.Errorf("could not determine message properties to merge for %v", reflect.TypeOf(pm).Elem()) + } + + // regular fields + for _, prop := range props.Prop { + if prop.Tag == 0 { + continue // one-of or special field (such as XXX_unrecognized, etc.) + } + fd := m.FindFieldDescriptor(int32(prop.Tag)) + if fd == nil { + // Our descriptor has different fields than this message object. So + // try to reflect on the message object's fields. + md, err := desc.LoadMessageDescriptorForMessage(pm) + if err != nil { + return err + } + fd = md.FindFieldByNumber(int32(prop.Tag)) + if fd == nil { + return fmt.Errorf("message descriptor %q did not contain field for tag %d (%q)", md.GetFullyQualifiedName(), prop.Tag, prop.Name) + } + } + rv := src.FieldByName(prop.Name) + if (rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Slice) && rv.IsNil() { + continue + } + if v, err := validFieldValueForRv(fd, rv); err != nil { + return err + } else { + values[fd] = v + } + } + + // one-of fields + for _, oop := range props.OneofTypes { + oov := src.Field(oop.Field).Elem() + if !oov.IsValid() || oov.Type() != oop.Type { + // this field is unset (in other words, one-of message field is not currently set to this option) + continue + } + prop := oop.Prop + rv := oov.Elem().FieldByName(prop.Name) + fd := m.FindFieldDescriptor(int32(prop.Tag)) + if fd == nil { + // Our descriptor has different fields than this message object. So + // try to reflect on the message object's fields. + md, err := desc.LoadMessageDescriptorForMessage(pm) + if err != nil { + return err + } + fd = md.FindFieldByNumber(int32(prop.Tag)) + if fd == nil { + return fmt.Errorf("message descriptor %q did not contain field for tag %d (%q in one-of %q)", md.GetFullyQualifiedName(), prop.Tag, prop.Name, src.Type().Field(oop.Field).Name) + } + } + if v, err := validFieldValueForRv(fd, rv); err != nil { + return err + } else { + values[fd] = v + } + } + + // extension fields + rexts, _ := proto.ExtensionDescs(pm) + for _, ed := range rexts { + v, _ := proto.GetExtension(pm, ed) + if v == nil { + continue + } + if ed.ExtensionType == nil { + // unrecognized extension: we'll handle that below when we + // handle other unrecognized fields + continue + } + fd := m.er.FindExtension(m.md.GetFullyQualifiedName(), ed.Field) + if fd == nil { + var err error + if fd, err = desc.LoadFieldDescriptorForExtension(ed); err != nil { + return err + } + } + if v, err := validFieldValue(fd, v); err != nil { + return err + } else { + values[fd] = v + } + } + + // With API v2, it is possible that the new protoreflect interfaces + // were used to store an extension, which means it can't be returned + // by proto.ExtensionDescs and it's also not in the unrecognized data. + // So we have a separate loop to trawl through it... + var err error + proto.MessageReflect(pm).Range(func(fld protoreflect.FieldDescriptor, val protoreflect.Value) bool { + if !fld.IsExtension() { + // normal field... we already got it above + return true + } + xt := fld.(protoreflect.ExtensionTypeDescriptor) + if _, ok := xt.Type().(*proto.ExtensionDesc); ok { + // known extension... we already got it above + return true + } + var fd *desc.FieldDescriptor + fd, err = desc.WrapField(fld) + if err != nil { + return false + } + v := convertProtoReflectValue(val) + if v, err = validFieldValue(fd, v); err != nil { + return false + } + values[fd] = v + return true + }) + if err != nil { + return err + } + + // unrecognized extensions fields: + // In API v2 of proto, some extensions may NEITHER be included in ExtensionDescs + // above NOR included in unrecognized fields below. These are extensions that use + // a custom extension type (not a generated one -- i.e. not a linked in extension). + mr := proto.MessageReflect(pm) + var extBytes []byte + var retErr error + mr.Range(func(fld protoreflect.FieldDescriptor, val protoreflect.Value) bool { + if !fld.IsExtension() { + // normal field, already processed above + return true + } + if extd, ok := fld.(protoreflect.ExtensionTypeDescriptor); ok { + if _, ok := extd.Type().(*proto.ExtensionDesc); ok { + // normal known extension, already processed above + return true + } + } + + // marshal the extension to bytes and then handle as unknown field below + mr.New() + mr.Set(fld, val) + extBytes, retErr = protov2.MarshalOptions{}.MarshalAppend(extBytes, mr.Interface()) + return retErr == nil + }) + if retErr != nil { + return retErr + } + + // now actually perform the merge + for fd, v := range values { + if err := mergeField(m, fd, v); err != nil { + return err + } + } + + if len(extBytes) > 0 { + // treating unrecognized extensions like unknown fields: best-effort + // ignore any error returned: pulling in unknown fields is best-effort + _ = m.UnmarshalMerge(extBytes) + } + + data := internal.GetUnrecognized(pm) + if len(data) > 0 { + // ignore any error returned: pulling in unknown fields is best-effort + _ = m.UnmarshalMerge(data) + } + + return nil +} + +func convertProtoReflectValue(v protoreflect.Value) interface{} { + val := v.Interface() + switch val := val.(type) { + case protoreflect.Message: + return val.Interface() + case protoreflect.Map: + mp := make(map[interface{}]interface{}, val.Len()) + val.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + mp[convertProtoReflectValue(k.Value())] = convertProtoReflectValue(v) + return true + }) + return mp + case protoreflect.List: + sl := make([]interface{}, val.Len()) + for i := 0; i < val.Len(); i++ { + sl[i] = convertProtoReflectValue(val.Get(i)) + } + return sl + case protoreflect.EnumNumber: + return int32(val) + default: + return val + } +} + +// Validate checks that all required fields are present. It returns an error if any are absent. +func (m *Message) Validate() error { + missingFields := m.findMissingFields() + if len(missingFields) == 0 { + return nil + } + return fmt.Errorf("some required fields missing: %v", strings.Join(missingFields, ", ")) +} + +func (m *Message) findMissingFields() []string { + if m.md.IsProto3() { + // proto3 does not allow required fields + return nil + } + var missingFields []string + for _, fd := range m.md.GetFields() { + if fd.IsRequired() { + if _, ok := m.values[fd.GetNumber()]; !ok { + missingFields = append(missingFields, fd.GetName()) + } + } + } + return missingFields +} + +// ValidateRecursive checks that all required fields are present and also +// recursively validates all fields who are also messages. It returns an error +// if any required fields, in this message or nested within, are absent. +func (m *Message) ValidateRecursive() error { + return m.validateRecursive("") +} + +func (m *Message) validateRecursive(prefix string) error { + if missingFields := m.findMissingFields(); len(missingFields) > 0 { + for i := range missingFields { + missingFields[i] = fmt.Sprintf("%s%s", prefix, missingFields[i]) + } + return fmt.Errorf("some required fields missing: %v", strings.Join(missingFields, ", ")) + } + + for tag, fld := range m.values { + fd := m.FindFieldDescriptor(tag) + var chprefix string + var md *desc.MessageDescriptor + checkMsg := func(pm proto.Message) error { + var dm *Message + if d, ok := pm.(*Message); ok { + dm = d + } else if pm != nil { + dm = m.mf.NewDynamicMessage(md) + if err := dm.ConvertFrom(pm); err != nil { + return nil + } + } + if dm == nil { + return nil + } + if err := dm.validateRecursive(chprefix); err != nil { + return err + } + return nil + } + isMap := fd.IsMap() + if isMap && fd.GetMapValueType().GetMessageType() != nil { + md = fd.GetMapValueType().GetMessageType() + mp := fld.(map[interface{}]interface{}) + for k, v := range mp { + chprefix = fmt.Sprintf("%s%s[%v].", prefix, getName(fd), k) + if err := checkMsg(v.(proto.Message)); err != nil { + return err + } + } + } else if !isMap && fd.GetMessageType() != nil { + md = fd.GetMessageType() + if fd.IsRepeated() { + sl := fld.([]interface{}) + for i, v := range sl { + chprefix = fmt.Sprintf("%s%s[%d].", prefix, getName(fd), i) + if err := checkMsg(v.(proto.Message)); err != nil { + return err + } + } + } else { + chprefix = fmt.Sprintf("%s%s.", prefix, getName(fd)) + if err := checkMsg(fld.(proto.Message)); err != nil { + return err + } + } + } + } + + return nil +} + +func getName(fd *desc.FieldDescriptor) string { + if fd.IsExtension() { + return fmt.Sprintf("(%s)", fd.GetFullyQualifiedName()) + } else { + return fd.GetName() + } +} + +// knownFieldTags return tags of present and recognized fields, in sorted order. +func (m *Message) knownFieldTags() []int { + if len(m.values) == 0 { + return []int(nil) + } + + keys := make([]int, len(m.values)) + i := 0 + for k := range m.values { + keys[i] = int(k) + i++ + } + + sort.Ints(keys) + return keys +} + +// allKnownFieldTags return tags of present and recognized fields, including +// those that are unset, in sorted order. This only includes extensions that are +// present. Known but not-present extensions are not included in the returned +// set of tags. +func (m *Message) allKnownFieldTags() []int { + fds := m.md.GetFields() + keys := make([]int, 0, len(fds)+len(m.extraFields)) + + for k := range m.values { + keys = append(keys, int(k)) + } + + // also include known fields that are not present + for _, fd := range fds { + if _, ok := m.values[fd.GetNumber()]; !ok { + keys = append(keys, int(fd.GetNumber())) + } + } + for _, fd := range m.extraFields { + if !fd.IsExtension() { // skip extensions that are not present + if _, ok := m.values[fd.GetNumber()]; !ok { + keys = append(keys, int(fd.GetNumber())) + } + } + } + + sort.Ints(keys) + return keys +} + +// unknownFieldTags return tags of present but unrecognized fields, in sorted order. +func (m *Message) unknownFieldTags() []int { + if len(m.unknownFields) == 0 { + return []int(nil) + } + keys := make([]int, len(m.unknownFields)) + i := 0 + for k := range m.unknownFields { + keys[i] = int(k) + i++ + } + sort.Ints(keys) + return keys +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/equal.go b/vendor/github.com/jhump/protoreflect/dynamic/equal.go new file mode 100644 index 0000000..e44c6c5 --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/equal.go @@ -0,0 +1,157 @@ +package dynamic + +import ( + "bytes" + "reflect" + + "github.com/golang/protobuf/proto" + + "github.com/jhump/protoreflect/desc" +) + +// Equal returns true if the given two dynamic messages are equal. Two messages are equal when they +// have the same message type and same fields set to equal values. For proto3 messages, fields set +// to their zero value are considered unset. +func Equal(a, b *Message) bool { + if a == b { + return true + } + if (a == nil) != (b == nil) { + return false + } + if a.md.GetFullyQualifiedName() != b.md.GetFullyQualifiedName() { + return false + } + if len(a.values) != len(b.values) { + return false + } + if len(a.unknownFields) != len(b.unknownFields) { + return false + } + for tag, aval := range a.values { + bval, ok := b.values[tag] + if !ok { + return false + } + if !fieldsEqual(aval, bval) { + return false + } + } + for tag, au := range a.unknownFields { + bu, ok := b.unknownFields[tag] + if !ok { + return false + } + if len(au) != len(bu) { + return false + } + for i, aval := range au { + bval := bu[i] + if aval.Encoding != bval.Encoding { + return false + } + if aval.Encoding == proto.WireBytes || aval.Encoding == proto.WireStartGroup { + if !bytes.Equal(aval.Contents, bval.Contents) { + return false + } + } else if aval.Value != bval.Value { + return false + } + } + } + // all checks pass! + return true +} + +func fieldsEqual(aval, bval interface{}) bool { + arv := reflect.ValueOf(aval) + brv := reflect.ValueOf(bval) + if arv.Type() != brv.Type() { + // it is possible that one is a dynamic message and one is not + apm, ok := aval.(proto.Message) + if !ok { + return false + } + bpm, ok := bval.(proto.Message) + if !ok { + return false + } + return MessagesEqual(apm, bpm) + + } else { + switch arv.Kind() { + case reflect.Ptr: + apm, ok := aval.(proto.Message) + if !ok { + // Don't know how to compare pointer values that aren't messages! + // Maybe this should panic? + return false + } + bpm := bval.(proto.Message) // we know it will succeed because we know a and b have same type + return MessagesEqual(apm, bpm) + + case reflect.Map: + return mapsEqual(arv, brv) + + case reflect.Slice: + if arv.Type() == typeOfBytes { + return bytes.Equal(aval.([]byte), bval.([]byte)) + } else { + return slicesEqual(arv, brv) + } + + default: + return aval == bval + } + } +} + +func slicesEqual(a, b reflect.Value) bool { + if a.Len() != b.Len() { + return false + } + for i := 0; i < a.Len(); i++ { + ai := a.Index(i) + bi := b.Index(i) + if !fieldsEqual(ai.Interface(), bi.Interface()) { + return false + } + } + return true +} + +// MessagesEqual returns true if the given two messages are equal. Use this instead of proto.Equal +// when one or both of the messages might be a dynamic message. +func MessagesEqual(a, b proto.Message) bool { + da, aok := a.(*Message) + db, bok := b.(*Message) + // Both dynamic messages + if aok && bok { + return Equal(da, db) + } + // Neither dynamic messages + if !aok && !bok { + return proto.Equal(a, b) + } + // Mixed + if bok { + // we want a to be the dynamic one + b, da = a, db + } + + // Instead of panic'ing below if we have a nil dynamic message, check + // now and return false if the input message is not also nil. + if da == nil { + return isNil(b) + } + + md, err := desc.LoadMessageDescriptorForMessage(b) + if err != nil { + return false + } + db = NewMessageWithMessageFactory(md, da.mf) + if db.ConvertFrom(b) != nil { + return false + } + return Equal(da, db) +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/extension.go b/vendor/github.com/jhump/protoreflect/dynamic/extension.go new file mode 100644 index 0000000..1d38161 --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/extension.go @@ -0,0 +1,46 @@ +package dynamic + +import ( + "fmt" + + "github.com/golang/protobuf/proto" + + "github.com/jhump/protoreflect/codec" + "github.com/jhump/protoreflect/desc" +) + +// SetExtension sets the given extension value. If the given message is not a +// dynamic message, the given extension may not be recognized (or may differ +// from the compiled and linked in version of the extension. So in that case, +// this function will serialize the given value to bytes and then use +// proto.SetRawExtension to set the value. +func SetExtension(msg proto.Message, extd *desc.FieldDescriptor, val interface{}) error { + if !extd.IsExtension() { + return fmt.Errorf("given field %s is not an extension", extd.GetFullyQualifiedName()) + } + + if dm, ok := msg.(*Message); ok { + return dm.TrySetField(extd, val) + } + + md, err := desc.LoadMessageDescriptorForMessage(msg) + if err != nil { + return err + } + if err := checkField(extd, md); err != nil { + return err + } + + val, err = validFieldValue(extd, val) + if err != nil { + return err + } + + var b codec.Buffer + b.SetDeterministic(defaultDeterminism) + if err := b.EncodeFieldValue(extd, val); err != nil { + return err + } + proto.SetRawExtension(msg, extd.GetNumber(), b.Bytes()) + return nil +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/extension_registry.go b/vendor/github.com/jhump/protoreflect/dynamic/extension_registry.go new file mode 100644 index 0000000..6876827 --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/extension_registry.go @@ -0,0 +1,241 @@ +package dynamic + +import ( + "fmt" + "reflect" + "sync" + + "github.com/golang/protobuf/proto" + + "github.com/jhump/protoreflect/desc" +) + +// ExtensionRegistry is a registry of known extension fields. This is used to parse +// extension fields encountered when de-serializing a dynamic message. +type ExtensionRegistry struct { + includeDefault bool + mu sync.RWMutex + exts map[string]map[int32]*desc.FieldDescriptor +} + +// NewExtensionRegistryWithDefaults is a registry that includes all "default" extensions, +// which are those that are statically linked into the current program (e.g. registered by +// protoc-generated code via proto.RegisterExtension). Extensions explicitly added to the +// registry will override any default extensions that are for the same extendee and have the +// same tag number and/or name. +func NewExtensionRegistryWithDefaults() *ExtensionRegistry { + return &ExtensionRegistry{includeDefault: true} +} + +// AddExtensionDesc adds the given extensions to the registry. +func (r *ExtensionRegistry) AddExtensionDesc(exts ...*proto.ExtensionDesc) error { + flds := make([]*desc.FieldDescriptor, len(exts)) + for i, ext := range exts { + fd, err := desc.LoadFieldDescriptorForExtension(ext) + if err != nil { + return err + } + flds[i] = fd + } + r.mu.Lock() + defer r.mu.Unlock() + if r.exts == nil { + r.exts = map[string]map[int32]*desc.FieldDescriptor{} + } + for _, fd := range flds { + r.putExtensionLocked(fd) + } + return nil +} + +// AddExtension adds the given extensions to the registry. The given extensions +// will overwrite any previously added extensions that are for the same extendee +// message and same extension tag number. +func (r *ExtensionRegistry) AddExtension(exts ...*desc.FieldDescriptor) error { + for _, ext := range exts { + if !ext.IsExtension() { + return fmt.Errorf("given field is not an extension: %s", ext.GetFullyQualifiedName()) + } + } + r.mu.Lock() + defer r.mu.Unlock() + if r.exts == nil { + r.exts = map[string]map[int32]*desc.FieldDescriptor{} + } + for _, ext := range exts { + r.putExtensionLocked(ext) + } + return nil +} + +// AddExtensionsFromFile adds to the registry all extension fields defined in the given file descriptor. +func (r *ExtensionRegistry) AddExtensionsFromFile(fd *desc.FileDescriptor) { + r.mu.Lock() + defer r.mu.Unlock() + r.addExtensionsFromFileLocked(fd, false, nil) +} + +// AddExtensionsFromFileRecursively adds to the registry all extension fields defined in the give file +// descriptor and also recursively adds all extensions defined in that file's dependencies. This adds +// extensions from the entire transitive closure for the given file. +func (r *ExtensionRegistry) AddExtensionsFromFileRecursively(fd *desc.FileDescriptor) { + r.mu.Lock() + defer r.mu.Unlock() + already := map[*desc.FileDescriptor]struct{}{} + r.addExtensionsFromFileLocked(fd, true, already) +} + +func (r *ExtensionRegistry) addExtensionsFromFileLocked(fd *desc.FileDescriptor, recursive bool, alreadySeen map[*desc.FileDescriptor]struct{}) { + if _, ok := alreadySeen[fd]; ok { + return + } + + if r.exts == nil { + r.exts = map[string]map[int32]*desc.FieldDescriptor{} + } + for _, ext := range fd.GetExtensions() { + r.putExtensionLocked(ext) + } + for _, msg := range fd.GetMessageTypes() { + r.addExtensionsFromMessageLocked(msg) + } + + if recursive { + alreadySeen[fd] = struct{}{} + for _, dep := range fd.GetDependencies() { + r.addExtensionsFromFileLocked(dep, recursive, alreadySeen) + } + } +} + +func (r *ExtensionRegistry) addExtensionsFromMessageLocked(md *desc.MessageDescriptor) { + for _, ext := range md.GetNestedExtensions() { + r.putExtensionLocked(ext) + } + for _, msg := range md.GetNestedMessageTypes() { + r.addExtensionsFromMessageLocked(msg) + } +} + +func (r *ExtensionRegistry) putExtensionLocked(fd *desc.FieldDescriptor) { + msgName := fd.GetOwner().GetFullyQualifiedName() + m := r.exts[msgName] + if m == nil { + m = map[int32]*desc.FieldDescriptor{} + r.exts[msgName] = m + } + m[fd.GetNumber()] = fd +} + +// FindExtension queries for the extension field with the given extendee name (must be a fully-qualified +// message name) and tag number. If no extension is known, nil is returned. +func (r *ExtensionRegistry) FindExtension(messageName string, tagNumber int32) *desc.FieldDescriptor { + if r == nil { + return nil + } + r.mu.RLock() + defer r.mu.RUnlock() + fd := r.exts[messageName][tagNumber] + if fd == nil && r.includeDefault { + ext := getDefaultExtensions(messageName)[tagNumber] + if ext != nil { + fd, _ = desc.LoadFieldDescriptorForExtension(ext) + } + } + return fd +} + +// FindExtensionByName queries for the extension field with the given extendee name (must be a fully-qualified +// message name) and field name (must also be a fully-qualified extension name). If no extension is known, nil +// is returned. +func (r *ExtensionRegistry) FindExtensionByName(messageName string, fieldName string) *desc.FieldDescriptor { + if r == nil { + return nil + } + r.mu.RLock() + defer r.mu.RUnlock() + for _, fd := range r.exts[messageName] { + if fd.GetFullyQualifiedName() == fieldName { + return fd + } + } + if r.includeDefault { + for _, ext := range getDefaultExtensions(messageName) { + fd, _ := desc.LoadFieldDescriptorForExtension(ext) + if fd.GetFullyQualifiedName() == fieldName { + return fd + } + } + } + return nil +} + +// FindExtensionByJSONName queries for the extension field with the given extendee name (must be a fully-qualified +// message name) and JSON field name (must also be a fully-qualified name). If no extension is known, nil is returned. +// The fully-qualified JSON name is the same as the extension's normal fully-qualified name except that the last +// component uses the field's JSON name (if present). +func (r *ExtensionRegistry) FindExtensionByJSONName(messageName string, fieldName string) *desc.FieldDescriptor { + if r == nil { + return nil + } + r.mu.RLock() + defer r.mu.RUnlock() + for _, fd := range r.exts[messageName] { + if fd.GetFullyQualifiedJSONName() == fieldName { + return fd + } + } + if r.includeDefault { + for _, ext := range getDefaultExtensions(messageName) { + fd, _ := desc.LoadFieldDescriptorForExtension(ext) + if fd.GetFullyQualifiedJSONName() == fieldName { + return fd + } + } + } + return nil +} + +func getDefaultExtensions(messageName string) map[int32]*proto.ExtensionDesc { + t := proto.MessageType(messageName) + if t != nil { + msg := reflect.Zero(t).Interface().(proto.Message) + return proto.RegisteredExtensions(msg) + } + return nil +} + +// AllExtensionsForType returns all known extension fields for the given extendee name (must be a +// fully-qualified message name). +func (r *ExtensionRegistry) AllExtensionsForType(messageName string) []*desc.FieldDescriptor { + if r == nil { + return []*desc.FieldDescriptor(nil) + } + r.mu.RLock() + defer r.mu.RUnlock() + flds := r.exts[messageName] + var ret []*desc.FieldDescriptor + if r.includeDefault { + exts := getDefaultExtensions(messageName) + if len(exts) > 0 || len(flds) > 0 { + ret = make([]*desc.FieldDescriptor, 0, len(exts)+len(flds)) + } + for tag, ext := range exts { + if _, ok := flds[tag]; ok { + // skip default extension and use the one explicitly registered instead + continue + } + fd, _ := desc.LoadFieldDescriptorForExtension(ext) + if fd != nil { + ret = append(ret, fd) + } + } + } else if len(flds) > 0 { + ret = make([]*desc.FieldDescriptor, 0, len(flds)) + } + + for _, ext := range flds { + ret = append(ret, ext) + } + return ret +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/grpcdynamic/stub.go b/vendor/github.com/jhump/protoreflect/dynamic/grpcdynamic/stub.go new file mode 100644 index 0000000..6fca393 --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/grpcdynamic/stub.go @@ -0,0 +1,310 @@ +// Package grpcdynamic provides a dynamic RPC stub. It can be used to invoke RPC +// method where only method descriptors are known. The actual request and response +// messages may be dynamic messages. +package grpcdynamic + +import ( + "context" + "fmt" + "io" + + "github.com/golang/protobuf/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/dynamic" +) + +// Stub is an RPC client stub, used for dynamically dispatching RPCs to a server. +type Stub struct { + channel Channel + mf *dynamic.MessageFactory +} + +// Channel represents the operations necessary to issue RPCs via gRPC. The +// *grpc.ClientConn type provides this interface and will typically the concrete +// type used to construct Stubs. But the use of this interface allows +// construction of stubs that use alternate concrete types as the transport for +// RPC operations. +type Channel = grpc.ClientConnInterface + +// NewStub creates a new RPC stub that uses the given channel for dispatching RPCs. +func NewStub(channel Channel) Stub { + return NewStubWithMessageFactory(channel, nil) +} + +// NewStubWithMessageFactory creates a new RPC stub that uses the given channel for +// dispatching RPCs and the given MessageFactory for creating response messages. +func NewStubWithMessageFactory(channel Channel, mf *dynamic.MessageFactory) Stub { + return Stub{channel: channel, mf: mf} +} + +func requestMethod(md *desc.MethodDescriptor) string { + return fmt.Sprintf("/%s/%s", md.GetService().GetFullyQualifiedName(), md.GetName()) +} + +// InvokeRpc sends a unary RPC and returns the response. Use this for unary methods. +func (s Stub) InvokeRpc(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (proto.Message, error) { + if method.IsClientStreaming() || method.IsServerStreaming() { + return nil, fmt.Errorf("InvokeRpc is for unary methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) + } + if err := checkMessageType(method.GetInputType(), request); err != nil { + return nil, err + } + resp := s.mf.NewMessage(method.GetOutputType()) + if err := s.channel.Invoke(ctx, requestMethod(method), request, resp, opts...); err != nil { + return nil, err + } + return resp, nil +} + +// InvokeRpcServerStream sends a unary RPC and returns the response stream. Use this for server-streaming methods. +func (s Stub) InvokeRpcServerStream(ctx context.Context, method *desc.MethodDescriptor, request proto.Message, opts ...grpc.CallOption) (*ServerStream, error) { + if method.IsClientStreaming() || !method.IsServerStreaming() { + return nil, fmt.Errorf("InvokeRpcServerStream is for server-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) + } + if err := checkMessageType(method.GetInputType(), request); err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(ctx) + sd := grpc.StreamDesc{ + StreamName: method.GetName(), + ServerStreams: method.IsServerStreaming(), + ClientStreams: method.IsClientStreaming(), + } + if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil { + cancel() + return nil, err + } else { + err = cs.SendMsg(request) + if err != nil { + cancel() + return nil, err + } + err = cs.CloseSend() + if err != nil { + cancel() + return nil, err + } + go func() { + // when the new stream is finished, also cleanup the parent context + <-cs.Context().Done() + cancel() + }() + return &ServerStream{cs, method.GetOutputType(), s.mf}, nil + } +} + +// InvokeRpcClientStream creates a new stream that is used to send request messages and, at the end, +// receive the response message. Use this for client-streaming methods. +func (s Stub) InvokeRpcClientStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*ClientStream, error) { + if !method.IsClientStreaming() || method.IsServerStreaming() { + return nil, fmt.Errorf("InvokeRpcClientStream is for client-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) + } + ctx, cancel := context.WithCancel(ctx) + sd := grpc.StreamDesc{ + StreamName: method.GetName(), + ServerStreams: method.IsServerStreaming(), + ClientStreams: method.IsClientStreaming(), + } + if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil { + cancel() + return nil, err + } else { + go func() { + // when the new stream is finished, also cleanup the parent context + <-cs.Context().Done() + cancel() + }() + return &ClientStream{cs, method, s.mf, cancel}, nil + } +} + +// InvokeRpcBidiStream creates a new stream that is used to both send request messages and receive response +// messages. Use this for bidi-streaming methods. +func (s Stub) InvokeRpcBidiStream(ctx context.Context, method *desc.MethodDescriptor, opts ...grpc.CallOption) (*BidiStream, error) { + if !method.IsClientStreaming() || !method.IsServerStreaming() { + return nil, fmt.Errorf("InvokeRpcBidiStream is for bidi-streaming methods; %q is %s", method.GetFullyQualifiedName(), methodType(method)) + } + sd := grpc.StreamDesc{ + StreamName: method.GetName(), + ServerStreams: method.IsServerStreaming(), + ClientStreams: method.IsClientStreaming(), + } + if cs, err := s.channel.NewStream(ctx, &sd, requestMethod(method), opts...); err != nil { + return nil, err + } else { + return &BidiStream{cs, method.GetInputType(), method.GetOutputType(), s.mf}, nil + } +} + +func methodType(md *desc.MethodDescriptor) string { + if md.IsClientStreaming() && md.IsServerStreaming() { + return "bidi-streaming" + } else if md.IsClientStreaming() { + return "client-streaming" + } else if md.IsServerStreaming() { + return "server-streaming" + } else { + return "unary" + } +} + +func checkMessageType(md *desc.MessageDescriptor, msg proto.Message) error { + var typeName string + if dm, ok := msg.(*dynamic.Message); ok { + typeName = dm.GetMessageDescriptor().GetFullyQualifiedName() + } else { + typeName = proto.MessageName(msg) + } + if typeName != md.GetFullyQualifiedName() { + return fmt.Errorf("expecting message of type %s; got %s", md.GetFullyQualifiedName(), typeName) + } + return nil +} + +// ServerStream represents a response stream from a server. Messages in the stream can be queried +// as can header and trailer metadata sent by the server. +type ServerStream struct { + stream grpc.ClientStream + respType *desc.MessageDescriptor + mf *dynamic.MessageFactory +} + +// Header returns any header metadata sent by the server (blocks if necessary until headers are +// received). +func (s *ServerStream) Header() (metadata.MD, error) { + return s.stream.Header() +} + +// Trailer returns the trailer metadata sent by the server. It must only be called after +// RecvMsg returns a non-nil error (which may be EOF for normal completion of stream). +func (s *ServerStream) Trailer() metadata.MD { + return s.stream.Trailer() +} + +// Context returns the context associated with this streaming operation. +func (s *ServerStream) Context() context.Context { + return s.stream.Context() +} + +// RecvMsg returns the next message in the response stream or an error. If the stream +// has completed normally, the error is io.EOF. Otherwise, the error indicates the +// nature of the abnormal termination of the stream. +func (s *ServerStream) RecvMsg() (proto.Message, error) { + resp := s.mf.NewMessage(s.respType) + if err := s.stream.RecvMsg(resp); err != nil { + return nil, err + } else { + return resp, nil + } +} + +// ClientStream represents a response stream from a client. Messages in the stream can be sent +// and, when done, the unary server message and header and trailer metadata can be queried. +type ClientStream struct { + stream grpc.ClientStream + method *desc.MethodDescriptor + mf *dynamic.MessageFactory + cancel context.CancelFunc +} + +// Header returns any header metadata sent by the server (blocks if necessary until headers are +// received). +func (s *ClientStream) Header() (metadata.MD, error) { + return s.stream.Header() +} + +// Trailer returns the trailer metadata sent by the server. It must only be called after +// RecvMsg returns a non-nil error (which may be EOF for normal completion of stream). +func (s *ClientStream) Trailer() metadata.MD { + return s.stream.Trailer() +} + +// Context returns the context associated with this streaming operation. +func (s *ClientStream) Context() context.Context { + return s.stream.Context() +} + +// SendMsg sends a request message to the server. +func (s *ClientStream) SendMsg(m proto.Message) error { + if err := checkMessageType(s.method.GetInputType(), m); err != nil { + return err + } + return s.stream.SendMsg(m) +} + +// CloseAndReceive closes the outgoing request stream and then blocks for the server's response. +func (s *ClientStream) CloseAndReceive() (proto.Message, error) { + if err := s.stream.CloseSend(); err != nil { + return nil, err + } + resp := s.mf.NewMessage(s.method.GetOutputType()) + if err := s.stream.RecvMsg(resp); err != nil { + return nil, err + } + // make sure we get EOF for a second message + if err := s.stream.RecvMsg(resp); err != io.EOF { + if err == nil { + s.cancel() + return nil, fmt.Errorf("client-streaming method %q returned more than one response message", s.method.GetFullyQualifiedName()) + } else { + return nil, err + } + } + return resp, nil +} + +// BidiStream represents a bi-directional stream for sending messages to and receiving +// messages from a server. The header and trailer metadata sent by the server can also be +// queried. +type BidiStream struct { + stream grpc.ClientStream + reqType *desc.MessageDescriptor + respType *desc.MessageDescriptor + mf *dynamic.MessageFactory +} + +// Header returns any header metadata sent by the server (blocks if necessary until headers are +// received). +func (s *BidiStream) Header() (metadata.MD, error) { + return s.stream.Header() +} + +// Trailer returns the trailer metadata sent by the server. It must only be called after +// RecvMsg returns a non-nil error (which may be EOF for normal completion of stream). +func (s *BidiStream) Trailer() metadata.MD { + return s.stream.Trailer() +} + +// Context returns the context associated with this streaming operation. +func (s *BidiStream) Context() context.Context { + return s.stream.Context() +} + +// SendMsg sends a request message to the server. +func (s *BidiStream) SendMsg(m proto.Message) error { + if err := checkMessageType(s.reqType, m); err != nil { + return err + } + return s.stream.SendMsg(m) +} + +// CloseSend indicates the request stream has ended. Invoke this after all request messages +// are sent (even if there are zero such messages). +func (s *BidiStream) CloseSend() error { + return s.stream.CloseSend() +} + +// RecvMsg returns the next message in the response stream or an error. If the stream +// has completed normally, the error is io.EOF. Otherwise, the error indicates the +// nature of the abnormal termination of the stream. +func (s *BidiStream) RecvMsg() (proto.Message, error) { + resp := s.mf.NewMessage(s.respType) + if err := s.stream.RecvMsg(resp); err != nil { + return nil, err + } else { + return resp, nil + } +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/indent.go b/vendor/github.com/jhump/protoreflect/dynamic/indent.go new file mode 100644 index 0000000..bd7fcaa --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/indent.go @@ -0,0 +1,76 @@ +package dynamic + +import "bytes" + +type indentBuffer struct { + bytes.Buffer + indent string + indentCount int + comma bool +} + +func (b *indentBuffer) start() error { + if b.indentCount >= 0 { + b.indentCount++ + return b.newLine(false) + } + return nil +} + +func (b *indentBuffer) sep() error { + if b.indentCount >= 0 { + _, err := b.WriteString(": ") + return err + } else { + return b.WriteByte(':') + } +} + +func (b *indentBuffer) end() error { + if b.indentCount >= 0 { + b.indentCount-- + return b.newLine(false) + } + return nil +} + +func (b *indentBuffer) maybeNext(first *bool) error { + if *first { + *first = false + return nil + } else { + return b.next() + } +} + +func (b *indentBuffer) next() error { + if b.indentCount >= 0 { + return b.newLine(b.comma) + } else if b.comma { + return b.WriteByte(',') + } else { + return b.WriteByte(' ') + } +} + +func (b *indentBuffer) newLine(comma bool) error { + if comma { + err := b.WriteByte(',') + if err != nil { + return err + } + } + + err := b.WriteByte('\n') + if err != nil { + return err + } + + for i := 0; i < b.indentCount; i++ { + _, err := b.WriteString(b.indent) + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/json.go b/vendor/github.com/jhump/protoreflect/dynamic/json.go new file mode 100644 index 0000000..9081965 --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/json.go @@ -0,0 +1,1256 @@ +package dynamic + +// JSON marshalling and unmarshalling for dynamic messages + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "math" + "reflect" + "sort" + "strconv" + "strings" + + "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" + // link in the well-known-types that have a special JSON format + _ "google.golang.org/protobuf/types/known/anypb" + _ "google.golang.org/protobuf/types/known/durationpb" + _ "google.golang.org/protobuf/types/known/emptypb" + _ "google.golang.org/protobuf/types/known/structpb" + _ "google.golang.org/protobuf/types/known/timestamppb" + _ "google.golang.org/protobuf/types/known/wrapperspb" + + "github.com/jhump/protoreflect/desc" +) + +var wellKnownTypeNames = map[string]struct{}{ + "google.protobuf.Any": {}, + "google.protobuf.Empty": {}, + "google.protobuf.Duration": {}, + "google.protobuf.Timestamp": {}, + // struct.proto + "google.protobuf.Struct": {}, + "google.protobuf.Value": {}, + "google.protobuf.ListValue": {}, + // wrappers.proto + "google.protobuf.DoubleValue": {}, + "google.protobuf.FloatValue": {}, + "google.protobuf.Int64Value": {}, + "google.protobuf.UInt64Value": {}, + "google.protobuf.Int32Value": {}, + "google.protobuf.UInt32Value": {}, + "google.protobuf.BoolValue": {}, + "google.protobuf.StringValue": {}, + "google.protobuf.BytesValue": {}, +} + +// MarshalJSON serializes this message to bytes in JSON format, returning an +// error if the operation fails. The resulting bytes will be a valid UTF8 +// string. +// +// This method uses a compact form: no newlines, and spaces between fields and +// between field identifiers and values are elided. +// +// This method is convenient shorthand for invoking MarshalJSONPB with a default +// (zero value) marshaler: +// +// m.MarshalJSONPB(&jsonpb.Marshaler{}) +// +// So enums are serialized using enum value name strings, and values that are +// not present (including those with default/zero value for messages defined in +// "proto3" syntax) are omitted. +func (m *Message) MarshalJSON() ([]byte, error) { + return m.MarshalJSONPB(&jsonpb.Marshaler{}) +} + +// MarshalJSONIndent serializes this message to bytes in JSON format, returning +// an error if the operation fails. The resulting bytes will be a valid UTF8 +// string. +// +// This method uses a "pretty-printed" form, with each field on its own line and +// spaces between field identifiers and values. Indentation of two spaces is +// used. +// +// This method is convenient shorthand for invoking MarshalJSONPB with a default +// (zero value) marshaler: +// +// m.MarshalJSONPB(&jsonpb.Marshaler{Indent: " "}) +// +// So enums are serialized using enum value name strings, and values that are +// not present (including those with default/zero value for messages defined in +// "proto3" syntax) are omitted. +func (m *Message) MarshalJSONIndent() ([]byte, error) { + return m.MarshalJSONPB(&jsonpb.Marshaler{Indent: " "}) +} + +// MarshalJSONPB serializes this message to bytes in JSON format, returning an +// error if the operation fails. The resulting bytes will be a valid UTF8 +// string. The given marshaler is used to convey options used during marshaling. +// +// If this message contains nested messages that are generated message types (as +// opposed to dynamic messages), the given marshaler is used to marshal it. +// +// When marshaling any nested messages, any jsonpb.AnyResolver configured in the +// given marshaler is augmented with knowledge of message types known to this +// message's descriptor (and its enclosing file and set of transitive +// dependencies). +func (m *Message) MarshalJSONPB(opts *jsonpb.Marshaler) ([]byte, error) { + var b indentBuffer + b.indent = opts.Indent + if len(opts.Indent) == 0 { + b.indentCount = -1 + } + b.comma = true + if err := m.marshalJSON(&b, opts); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +func (m *Message) marshalJSON(b *indentBuffer, opts *jsonpb.Marshaler) error { + if m == nil { + _, err := b.WriteString("null") + return err + } + if r, changed := wrapResolver(opts.AnyResolver, m.mf, m.md.GetFile()); changed { + newOpts := *opts + newOpts.AnyResolver = r + opts = &newOpts + } + + if ok, err := marshalWellKnownType(m, b, opts); ok { + return err + } + + err := b.WriteByte('{') + if err != nil { + return err + } + err = b.start() + if err != nil { + return err + } + + var tags []int + if opts.EmitDefaults { + tags = m.allKnownFieldTags() + } else { + tags = m.knownFieldTags() + } + + first := true + + for _, tag := range tags { + itag := int32(tag) + fd := m.FindFieldDescriptor(itag) + + v, ok := m.values[itag] + if !ok { + if fd.GetOneOf() != nil { + // don't print defaults for fields in a oneof + continue + } + v = fd.GetDefaultValue() + } + + err := b.maybeNext(&first) + if err != nil { + return err + } + err = marshalKnownFieldJSON(b, fd, v, opts) + if err != nil { + return err + } + } + + err = b.end() + if err != nil { + return err + } + err = b.WriteByte('}') + if err != nil { + return err + } + + return nil +} + +func marshalWellKnownType(m *Message, b *indentBuffer, opts *jsonpb.Marshaler) (bool, error) { + fqn := m.md.GetFullyQualifiedName() + if _, ok := wellKnownTypeNames[fqn]; !ok { + return false, nil + } + + msgType := proto.MessageType(fqn) + if msgType == nil { + // wtf? + panic(fmt.Sprintf("could not find registered message type for %q", fqn)) + } + + // convert dynamic message to well-known type and let jsonpb marshal it + msg := reflect.New(msgType.Elem()).Interface().(proto.Message) + if err := m.MergeInto(msg); err != nil { + return true, err + } + return true, opts.Marshal(b, msg) +} + +func marshalKnownFieldJSON(b *indentBuffer, fd *desc.FieldDescriptor, v interface{}, opts *jsonpb.Marshaler) error { + var jsonName string + if opts.OrigName { + jsonName = fd.GetName() + } else { + jsonName = fd.AsFieldDescriptorProto().GetJsonName() + if jsonName == "" { + jsonName = fd.GetName() + } + } + if fd.IsExtension() { + var scope string + switch parent := fd.GetParent().(type) { + case *desc.FileDescriptor: + scope = parent.GetPackage() + default: + scope = parent.GetFullyQualifiedName() + } + if scope == "" { + jsonName = fmt.Sprintf("[%s]", jsonName) + } else { + jsonName = fmt.Sprintf("[%s.%s]", scope, jsonName) + } + } + err := writeJsonString(b, jsonName) + if err != nil { + return err + } + err = b.sep() + if err != nil { + return err + } + + if isNil(v) { + _, err := b.WriteString("null") + return err + } + + if fd.IsMap() { + err = b.WriteByte('{') + if err != nil { + return err + } + err = b.start() + if err != nil { + return err + } + + md := fd.GetMessageType() + vfd := md.FindFieldByNumber(2) + + mp := v.(map[interface{}]interface{}) + keys := make([]interface{}, 0, len(mp)) + for k := range mp { + keys = append(keys, k) + } + sort.Sort(sortable(keys)) + first := true + for _, mk := range keys { + mv := mp[mk] + err := b.maybeNext(&first) + if err != nil { + return err + } + + err = marshalKnownFieldMapEntryJSON(b, mk, vfd, mv, opts) + if err != nil { + return err + } + } + + err = b.end() + if err != nil { + return err + } + return b.WriteByte('}') + + } else if fd.IsRepeated() { + err = b.WriteByte('[') + if err != nil { + return err + } + err = b.start() + if err != nil { + return err + } + + sl := v.([]interface{}) + first := true + for _, slv := range sl { + err := b.maybeNext(&first) + if err != nil { + return err + } + err = marshalKnownFieldValueJSON(b, fd, slv, opts) + if err != nil { + return err + } + } + + err = b.end() + if err != nil { + return err + } + return b.WriteByte(']') + + } else { + return marshalKnownFieldValueJSON(b, fd, v, opts) + } +} + +// sortable is used to sort map keys. Values will be integers (int32, int64, uint32, and uint64), +// bools, or strings. +type sortable []interface{} + +func (s sortable) Len() int { + return len(s) +} + +func (s sortable) Less(i, j int) bool { + vi := s[i] + vj := s[j] + switch reflect.TypeOf(vi).Kind() { + case reflect.Int32: + return vi.(int32) < vj.(int32) + case reflect.Int64: + return vi.(int64) < vj.(int64) + case reflect.Uint32: + return vi.(uint32) < vj.(uint32) + case reflect.Uint64: + return vi.(uint64) < vj.(uint64) + case reflect.String: + return vi.(string) < vj.(string) + case reflect.Bool: + return !vi.(bool) && vj.(bool) + default: + panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(vi))) + } +} + +func (s sortable) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func isNil(v interface{}) bool { + if v == nil { + return true + } + rv := reflect.ValueOf(v) + return rv.Kind() == reflect.Ptr && rv.IsNil() +} + +func marshalKnownFieldMapEntryJSON(b *indentBuffer, mk interface{}, vfd *desc.FieldDescriptor, mv interface{}, opts *jsonpb.Marshaler) error { + rk := reflect.ValueOf(mk) + var strkey string + switch rk.Kind() { + case reflect.Bool: + strkey = strconv.FormatBool(rk.Bool()) + case reflect.Int32, reflect.Int64: + strkey = strconv.FormatInt(rk.Int(), 10) + case reflect.Uint32, reflect.Uint64: + strkey = strconv.FormatUint(rk.Uint(), 10) + case reflect.String: + strkey = rk.String() + default: + return fmt.Errorf("invalid map key value: %v (%v)", mk, rk.Type()) + } + err := writeJsonString(b, strkey) + if err != nil { + return err + } + err = b.sep() + if err != nil { + return err + } + return marshalKnownFieldValueJSON(b, vfd, mv, opts) +} + +func marshalKnownFieldValueJSON(b *indentBuffer, fd *desc.FieldDescriptor, v interface{}, opts *jsonpb.Marshaler) error { + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Int64: + return writeJsonString(b, strconv.FormatInt(rv.Int(), 10)) + case reflect.Int32: + ed := fd.GetEnumType() + if !opts.EnumsAsInts && ed != nil { + n := int32(rv.Int()) + vd := ed.FindValueByNumber(n) + if vd == nil { + _, err := b.WriteString(strconv.FormatInt(rv.Int(), 10)) + return err + } else { + return writeJsonString(b, vd.GetName()) + } + } else { + _, err := b.WriteString(strconv.FormatInt(rv.Int(), 10)) + return err + } + case reflect.Uint64: + return writeJsonString(b, strconv.FormatUint(rv.Uint(), 10)) + case reflect.Uint32: + _, err := b.WriteString(strconv.FormatUint(rv.Uint(), 10)) + return err + case reflect.Float32, reflect.Float64: + f := rv.Float() + var str string + if math.IsNaN(f) { + str = `"NaN"` + } else if math.IsInf(f, 1) { + str = `"Infinity"` + } else if math.IsInf(f, -1) { + str = `"-Infinity"` + } else { + var bits int + if rv.Kind() == reflect.Float32 { + bits = 32 + } else { + bits = 64 + } + str = strconv.FormatFloat(rv.Float(), 'g', -1, bits) + } + _, err := b.WriteString(str) + return err + case reflect.Bool: + _, err := b.WriteString(strconv.FormatBool(rv.Bool())) + return err + case reflect.Slice: + bstr := base64.StdEncoding.EncodeToString(rv.Bytes()) + return writeJsonString(b, bstr) + case reflect.String: + return writeJsonString(b, rv.String()) + default: + // must be a message + if isNil(v) { + _, err := b.WriteString("null") + return err + } + + if dm, ok := v.(*Message); ok { + return dm.marshalJSON(b, opts) + } + + var err error + if b.indentCount <= 0 || len(b.indent) == 0 { + err = opts.Marshal(b, v.(proto.Message)) + } else { + str, err := opts.MarshalToString(v.(proto.Message)) + if err != nil { + return err + } + indent := strings.Repeat(b.indent, b.indentCount) + pos := 0 + // add indention prefix to each line + for pos < len(str) { + start := pos + nextPos := strings.Index(str[pos:], "\n") + if nextPos == -1 { + nextPos = len(str) + } else { + nextPos = pos + nextPos + 1 // include newline + } + line := str[start:nextPos] + if pos > 0 { + _, err = b.WriteString(indent) + if err != nil { + return err + } + } + _, err = b.WriteString(line) + if err != nil { + return err + } + pos = nextPos + } + } + return err + } +} + +func writeJsonString(b *indentBuffer, s string) error { + if sbytes, err := json.Marshal(s); err != nil { + return err + } else { + _, err := b.Write(sbytes) + return err + } +} + +// UnmarshalJSON de-serializes the message that is present, in JSON format, in +// the given bytes into this message. It first resets the current message. It +// returns an error if the given bytes do not contain a valid encoding of this +// message type in JSON format. +// +// This method is shorthand for invoking UnmarshalJSONPB with a default (zero +// value) unmarshaler: +// +// m.UnmarshalMergeJSONPB(&jsonpb.Unmarshaler{}, js) +// +// So unknown fields will result in an error, and no provided jsonpb.AnyResolver +// will be used when parsing google.protobuf.Any messages. +func (m *Message) UnmarshalJSON(js []byte) error { + return m.UnmarshalJSONPB(&jsonpb.Unmarshaler{}, js) +} + +// UnmarshalMergeJSON de-serializes the message that is present, in JSON format, +// in the given bytes into this message. Unlike UnmarshalJSON, it does not first +// reset the message, instead merging the data in the given bytes into the +// existing data in this message. +func (m *Message) UnmarshalMergeJSON(js []byte) error { + return m.UnmarshalMergeJSONPB(&jsonpb.Unmarshaler{}, js) +} + +// UnmarshalJSONPB de-serializes the message that is present, in JSON format, in +// the given bytes into this message. The given unmarshaler conveys options used +// when parsing the JSON. This function first resets the current message. It +// returns an error if the given bytes do not contain a valid encoding of this +// message type in JSON format. +// +// The decoding is lenient: +// 1. The JSON can refer to fields either by their JSON name or by their +// declared name. +// 2. The JSON can use either numeric values or string names for enum values. +// +// When instantiating nested messages, if this message's associated factory +// returns a generated message type (as opposed to a dynamic message), the given +// unmarshaler is used to unmarshal it. +// +// When unmarshaling any nested messages, any jsonpb.AnyResolver configured in +// the given unmarshaler is augmented with knowledge of message types known to +// this message's descriptor (and its enclosing file and set of transitive +// dependencies). +func (m *Message) UnmarshalJSONPB(opts *jsonpb.Unmarshaler, js []byte) error { + m.Reset() + if err := m.UnmarshalMergeJSONPB(opts, js); err != nil { + return err + } + return m.Validate() +} + +// UnmarshalMergeJSONPB de-serializes the message that is present, in JSON +// format, in the given bytes into this message. The given unmarshaler conveys +// options used when parsing the JSON. Unlike UnmarshalJSONPB, it does not first +// reset the message, instead merging the data in the given bytes into the +// existing data in this message. +func (m *Message) UnmarshalMergeJSONPB(opts *jsonpb.Unmarshaler, js []byte) error { + r := newJsReader(js) + err := m.unmarshalJson(r, opts) + if err != nil { + return err + } + if t, err := r.poll(); err != io.EOF { + b, _ := ioutil.ReadAll(r.unread()) + s := fmt.Sprintf("%v%s", t, string(b)) + return fmt.Errorf("superfluous data found after JSON object: %q", s) + } + return nil +} + +func unmarshalWellKnownType(m *Message, r *jsReader, opts *jsonpb.Unmarshaler) (bool, error) { + fqn := m.md.GetFullyQualifiedName() + if _, ok := wellKnownTypeNames[fqn]; !ok { + return false, nil + } + + msgType := proto.MessageType(fqn) + if msgType == nil { + // wtf? + panic(fmt.Sprintf("could not find registered message type for %q", fqn)) + } + + // extract json value from r + var js json.RawMessage + if err := json.NewDecoder(r.unread()).Decode(&js); err != nil { + return true, err + } + if err := r.skip(); err != nil { + return true, err + } + + // unmarshal into well-known type and then convert to dynamic message + msg := reflect.New(msgType.Elem()).Interface().(proto.Message) + if err := opts.Unmarshal(bytes.NewReader(js), msg); err != nil { + return true, err + } + return true, m.MergeFrom(msg) +} + +func (m *Message) unmarshalJson(r *jsReader, opts *jsonpb.Unmarshaler) error { + if r, changed := wrapResolver(opts.AnyResolver, m.mf, m.md.GetFile()); changed { + newOpts := *opts + newOpts.AnyResolver = r + opts = &newOpts + } + + if ok, err := unmarshalWellKnownType(m, r, opts); ok { + return err + } + + t, err := r.peek() + if err != nil { + return err + } + if t == nil { + // if json is simply "null" we do nothing + r.poll() + return nil + } + + if err := r.beginObject(); err != nil { + return err + } + + for r.hasNext() { + f, err := r.nextObjectKey() + if err != nil { + return err + } + fd := m.FindFieldDescriptorByJSONName(f) + if fd == nil { + if opts.AllowUnknownFields { + r.skip() + continue + } + return fmt.Errorf("message type %s has no known field named %s", m.md.GetFullyQualifiedName(), f) + } + v, err := unmarshalJsField(fd, r, m.mf, opts) + if err != nil { + return err + } + if v != nil { + if err := mergeField(m, fd, v); err != nil { + return err + } + } else if fd.GetOneOf() != nil { + // preserve explicit null for oneof fields (this is a little odd but + // mimics the behavior of jsonpb with oneofs in generated message types) + if fd.GetMessageType() != nil { + typ := m.mf.GetKnownTypeRegistry().GetKnownType(fd.GetMessageType().GetFullyQualifiedName()) + if typ != nil { + // typed nil + if typ.Kind() != reflect.Ptr { + typ = reflect.PtrTo(typ) + } + v = reflect.Zero(typ).Interface() + } else { + // can't use nil dynamic message, so we just use empty one instead + v = m.mf.NewDynamicMessage(fd.GetMessageType()) + } + if err := m.setField(fd, v); err != nil { + return err + } + } else { + // not a message... explicit null makes no sense + return fmt.Errorf("message type %s cannot set field %s to null: it is not a message type", m.md.GetFullyQualifiedName(), f) + } + } else { + m.clearField(fd) + } + } + + if err := r.endObject(); err != nil { + return err + } + + return nil +} + +func isWellKnownValue(fd *desc.FieldDescriptor) bool { + return !fd.IsRepeated() && fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE && + fd.GetMessageType().GetFullyQualifiedName() == "google.protobuf.Value" +} + +func isWellKnownListValue(fd *desc.FieldDescriptor) bool { + // we look for ListValue; but we also look for Value, which can be assigned a ListValue + return !fd.IsRepeated() && fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE && + (fd.GetMessageType().GetFullyQualifiedName() == "google.protobuf.ListValue" || + fd.GetMessageType().GetFullyQualifiedName() == "google.protobuf.Value") +} + +func unmarshalJsField(fd *desc.FieldDescriptor, r *jsReader, mf *MessageFactory, opts *jsonpb.Unmarshaler) (interface{}, error) { + t, err := r.peek() + if err != nil { + return nil, err + } + if t == nil && !isWellKnownValue(fd) { + // if value is null, just return nil + // (unless field is google.protobuf.Value, in which case + // we fall through to parse it as an instance where its + // underlying value is set to a NullValue) + r.poll() + return nil, nil + } + + if t == json.Delim('{') && fd.IsMap() { + entryType := fd.GetMessageType() + keyType := entryType.FindFieldByNumber(1) + valueType := entryType.FindFieldByNumber(2) + mp := map[interface{}]interface{}{} + + // TODO: if there are just two map keys "key" and "value" and they have the right type of values, + // treat this JSON object as a single map entry message. (In keeping with support of map fields as + // if they were normal repeated field of entry messages as well as supporting a transition from + // optional to repeated...) + + if err := r.beginObject(); err != nil { + return nil, err + } + for r.hasNext() { + kk, err := unmarshalJsFieldElement(keyType, r, mf, opts, false) + if err != nil { + return nil, err + } + vv, err := unmarshalJsFieldElement(valueType, r, mf, opts, true) + if err != nil { + return nil, err + } + mp[kk] = vv + } + if err := r.endObject(); err != nil { + return nil, err + } + + return mp, nil + } else if t == json.Delim('[') && !isWellKnownListValue(fd) { + // We support parsing an array, even if field is not repeated, to mimic support in proto + // binary wire format that supports changing an optional field to repeated and vice versa. + // If the field is not repeated, we only keep the last value in the array. + + if err := r.beginArray(); err != nil { + return nil, err + } + var sl []interface{} + var v interface{} + for r.hasNext() { + var err error + v, err = unmarshalJsFieldElement(fd, r, mf, opts, false) + if err != nil { + return nil, err + } + if fd.IsRepeated() && v != nil { + sl = append(sl, v) + } + } + if err := r.endArray(); err != nil { + return nil, err + } + if fd.IsMap() { + mp := map[interface{}]interface{}{} + for _, m := range sl { + msg := m.(*Message) + kk, err := msg.TryGetFieldByNumber(1) + if err != nil { + return nil, err + } + vv, err := msg.TryGetFieldByNumber(2) + if err != nil { + return nil, err + } + mp[kk] = vv + } + return mp, nil + } else if fd.IsRepeated() { + return sl, nil + } else { + return v, nil + } + } else { + // We support parsing a singular value, even if field is repeated, to mimic support in proto + // binary wire format that supports changing an optional field to repeated and vice versa. + // If the field is repeated, we store value as singleton slice of that one value. + + v, err := unmarshalJsFieldElement(fd, r, mf, opts, false) + if err != nil { + return nil, err + } + if v == nil { + return nil, nil + } + if fd.IsRepeated() { + return []interface{}{v}, nil + } else { + return v, nil + } + } +} + +func unmarshalJsFieldElement(fd *desc.FieldDescriptor, r *jsReader, mf *MessageFactory, opts *jsonpb.Unmarshaler, allowNilMessage bool) (interface{}, error) { + t, err := r.peek() + if err != nil { + return nil, err + } + + switch fd.GetType() { + case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, + descriptorpb.FieldDescriptorProto_TYPE_GROUP: + + if t == nil && allowNilMessage { + // if json is simply "null" return a nil pointer + r.poll() + return nilMessage(fd.GetMessageType()), nil + } + + m := mf.NewMessage(fd.GetMessageType()) + if dm, ok := m.(*Message); ok { + if err := dm.unmarshalJson(r, opts); err != nil { + return nil, err + } + } else { + var msg json.RawMessage + if err := json.NewDecoder(r.unread()).Decode(&msg); err != nil { + return nil, err + } + if err := r.skip(); err != nil { + return nil, err + } + if err := opts.Unmarshal(bytes.NewReader([]byte(msg)), m); err != nil { + return nil, err + } + } + return m, nil + + case descriptorpb.FieldDescriptorProto_TYPE_ENUM: + if e, err := r.nextNumber(); err != nil { + return nil, err + } else { + // value could be string or number + if i, err := e.Int64(); err != nil { + // number cannot be parsed, so see if it's an enum value name + vd := fd.GetEnumType().FindValueByName(string(e)) + if vd != nil { + return vd.GetNumber(), nil + } else { + return nil, fmt.Errorf("enum %q does not have value named %q", fd.GetEnumType().GetFullyQualifiedName(), e) + } + } else if i > math.MaxInt32 || i < math.MinInt32 { + return nil, NumericOverflowError + } else { + return int32(i), err + } + } + + case descriptorpb.FieldDescriptorProto_TYPE_INT32, + descriptorpb.FieldDescriptorProto_TYPE_SINT32, + descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: + if i, err := r.nextInt(); err != nil { + return nil, err + } else if i > math.MaxInt32 || i < math.MinInt32 { + return nil, NumericOverflowError + } else { + return int32(i), err + } + + case descriptorpb.FieldDescriptorProto_TYPE_INT64, + descriptorpb.FieldDescriptorProto_TYPE_SINT64, + descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: + return r.nextInt() + + case descriptorpb.FieldDescriptorProto_TYPE_UINT32, + descriptorpb.FieldDescriptorProto_TYPE_FIXED32: + if i, err := r.nextUint(); err != nil { + return nil, err + } else if i > math.MaxUint32 { + return nil, NumericOverflowError + } else { + return uint32(i), err + } + + case descriptorpb.FieldDescriptorProto_TYPE_UINT64, + descriptorpb.FieldDescriptorProto_TYPE_FIXED64: + return r.nextUint() + + case descriptorpb.FieldDescriptorProto_TYPE_BOOL: + if str, ok := t.(string); ok { + if str == "true" { + r.poll() // consume token + return true, err + } else if str == "false" { + r.poll() // consume token + return false, err + } + } + return r.nextBool() + + case descriptorpb.FieldDescriptorProto_TYPE_FLOAT: + if f, err := r.nextFloat(); err != nil { + return nil, err + } else { + return float32(f), nil + } + + case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: + return r.nextFloat() + + case descriptorpb.FieldDescriptorProto_TYPE_BYTES: + return r.nextBytes() + + case descriptorpb.FieldDescriptorProto_TYPE_STRING: + return r.nextString() + + default: + return nil, fmt.Errorf("unknown field type: %v", fd.GetType()) + } +} + +type jsReader struct { + reader *bytes.Reader + dec *json.Decoder + current json.Token + peeked bool +} + +func newJsReader(b []byte) *jsReader { + reader := bytes.NewReader(b) + dec := json.NewDecoder(reader) + dec.UseNumber() + return &jsReader{reader: reader, dec: dec} +} + +func (r *jsReader) unread() io.Reader { + bufs := make([]io.Reader, 3) + var peeked []byte + if r.peeked { + if _, ok := r.current.(json.Delim); ok { + peeked = []byte(fmt.Sprintf("%v", r.current)) + } else { + peeked, _ = json.Marshal(r.current) + } + } + readerCopy := *r.reader + decCopy := *r.dec + + bufs[0] = bytes.NewReader(peeked) + bufs[1] = decCopy.Buffered() + bufs[2] = &readerCopy + return &concatReader{bufs: bufs} +} + +func (r *jsReader) hasNext() bool { + return r.dec.More() +} + +func (r *jsReader) peek() (json.Token, error) { + if r.peeked { + return r.current, nil + } + t, err := r.dec.Token() + if err != nil { + return nil, err + } + r.peeked = true + r.current = t + return t, nil +} + +func (r *jsReader) poll() (json.Token, error) { + if r.peeked { + ret := r.current + r.current = nil + r.peeked = false + return ret, nil + } + return r.dec.Token() +} + +func (r *jsReader) beginObject() error { + _, err := r.expect(func(t json.Token) bool { return t == json.Delim('{') }, nil, "start of JSON object: '{'") + return err +} + +func (r *jsReader) endObject() error { + _, err := r.expect(func(t json.Token) bool { return t == json.Delim('}') }, nil, "end of JSON object: '}'") + return err +} + +func (r *jsReader) beginArray() error { + _, err := r.expect(func(t json.Token) bool { return t == json.Delim('[') }, nil, "start of array: '['") + return err +} + +func (r *jsReader) endArray() error { + _, err := r.expect(func(t json.Token) bool { return t == json.Delim(']') }, nil, "end of array: ']'") + return err +} + +func (r *jsReader) nextObjectKey() (string, error) { + return r.nextString() +} + +func (r *jsReader) nextString() (string, error) { + t, err := r.expect(func(t json.Token) bool { _, ok := t.(string); return ok }, "", "string") + if err != nil { + return "", err + } + return t.(string), nil +} + +func (r *jsReader) nextBytes() ([]byte, error) { + str, err := r.nextString() + if err != nil { + return nil, err + } + return base64.StdEncoding.DecodeString(str) +} + +func (r *jsReader) nextBool() (bool, error) { + t, err := r.expect(func(t json.Token) bool { _, ok := t.(bool); return ok }, false, "boolean") + if err != nil { + return false, err + } + return t.(bool), nil +} + +func (r *jsReader) nextInt() (int64, error) { + n, err := r.nextNumber() + if err != nil { + return 0, err + } + return n.Int64() +} + +func (r *jsReader) nextUint() (uint64, error) { + n, err := r.nextNumber() + if err != nil { + return 0, err + } + return strconv.ParseUint(string(n), 10, 64) +} + +func (r *jsReader) nextFloat() (float64, error) { + n, err := r.nextNumber() + if err != nil { + return 0, err + } + return n.Float64() +} + +func (r *jsReader) nextNumber() (json.Number, error) { + t, err := r.expect(func(t json.Token) bool { return reflect.TypeOf(t).Kind() == reflect.String }, "0", "number") + if err != nil { + return "", err + } + switch t := t.(type) { + case json.Number: + return t, nil + case string: + return json.Number(t), nil + } + return "", fmt.Errorf("expecting a number but got %v", t) +} + +func (r *jsReader) skip() error { + t, err := r.poll() + if err != nil { + return err + } + if t == json.Delim('[') { + if err := r.skipArray(); err != nil { + return err + } + } else if t == json.Delim('{') { + if err := r.skipObject(); err != nil { + return err + } + } + return nil +} + +func (r *jsReader) skipArray() error { + for r.hasNext() { + if err := r.skip(); err != nil { + return err + } + } + if err := r.endArray(); err != nil { + return err + } + return nil +} + +func (r *jsReader) skipObject() error { + for r.hasNext() { + // skip object key + if err := r.skip(); err != nil { + return err + } + // and value + if err := r.skip(); err != nil { + return err + } + } + if err := r.endObject(); err != nil { + return err + } + return nil +} + +func (r *jsReader) expect(predicate func(json.Token) bool, ifNil interface{}, expected string) (interface{}, error) { + t, err := r.poll() + if err != nil { + return nil, err + } + if t == nil && ifNil != nil { + return ifNil, nil + } + if !predicate(t) { + return t, fmt.Errorf("bad input: expecting %s ; instead got %v", expected, t) + } + return t, nil +} + +type concatReader struct { + bufs []io.Reader + curr int +} + +func (r *concatReader) Read(p []byte) (n int, err error) { + for { + if r.curr >= len(r.bufs) { + err = io.EOF + return + } + var c int + c, err = r.bufs[r.curr].Read(p) + n += c + if err != io.EOF { + return + } + r.curr++ + p = p[c:] + } +} + +// AnyResolver returns a jsonpb.AnyResolver that uses the given file descriptors +// to resolve message names. It uses the given factory, which may be nil, to +// instantiate messages. The messages that it returns when resolving a type name +// may often be dynamic messages. +func AnyResolver(mf *MessageFactory, files ...*desc.FileDescriptor) jsonpb.AnyResolver { + return &anyResolver{mf: mf, files: files} +} + +type anyResolver struct { + mf *MessageFactory + files []*desc.FileDescriptor + ignored map[*desc.FileDescriptor]struct{} + other jsonpb.AnyResolver +} + +func wrapResolver(r jsonpb.AnyResolver, mf *MessageFactory, f *desc.FileDescriptor) (jsonpb.AnyResolver, bool) { + if r, ok := r.(*anyResolver); ok { + if _, ok := r.ignored[f]; ok { + // if the current resolver is ignoring this file, it's because another + // (upstream) resolver is already handling it, so nothing to do + return r, false + } + for _, file := range r.files { + if file == f { + // no need to wrap! + return r, false + } + } + // ignore files that will be checked by the resolver we're wrapping + // (we'll just delegate and let it search those files) + ignored := map[*desc.FileDescriptor]struct{}{} + for i := range r.ignored { + ignored[i] = struct{}{} + } + ignore(r.files, ignored) + return &anyResolver{mf: mf, files: []*desc.FileDescriptor{f}, ignored: ignored, other: r}, true + } + return &anyResolver{mf: mf, files: []*desc.FileDescriptor{f}, other: r}, true +} + +func ignore(files []*desc.FileDescriptor, ignored map[*desc.FileDescriptor]struct{}) { + for _, f := range files { + if _, ok := ignored[f]; ok { + continue + } + ignored[f] = struct{}{} + ignore(f.GetDependencies(), ignored) + } +} + +func (r *anyResolver) Resolve(typeUrl string) (proto.Message, error) { + mname := typeUrl + if slash := strings.LastIndex(mname, "/"); slash >= 0 { + mname = mname[slash+1:] + } + + // see if the user-specified resolver is able to do the job + if r.other != nil { + msg, err := r.other.Resolve(typeUrl) + if err == nil { + return msg, nil + } + } + + // try to find the message in our known set of files + checked := map[*desc.FileDescriptor]struct{}{} + for _, f := range r.files { + md := r.findMessage(f, mname, checked) + if md != nil { + return r.mf.NewMessage(md), nil + } + } + // failing that, see if the message factory knows about this type + var ktr *KnownTypeRegistry + if r.mf != nil { + ktr = r.mf.ktr + } else { + ktr = (*KnownTypeRegistry)(nil) + } + m := ktr.CreateIfKnown(mname) + if m != nil { + return m, nil + } + + // no other resolver to fallback to? mimic default behavior + mt := proto.MessageType(mname) + if mt == nil { + return nil, fmt.Errorf("unknown message type %q", mname) + } + return reflect.New(mt.Elem()).Interface().(proto.Message), nil +} + +func (r *anyResolver) findMessage(fd *desc.FileDescriptor, msgName string, checked map[*desc.FileDescriptor]struct{}) *desc.MessageDescriptor { + // if this is an ignored descriptor, skip + if _, ok := r.ignored[fd]; ok { + return nil + } + + // bail if we've already checked this file + if _, ok := checked[fd]; ok { + return nil + } + checked[fd] = struct{}{} + + // see if this file has the message + md := fd.FindMessage(msgName) + if md != nil { + return md + } + + // if not, recursively search the file's imports + for _, dep := range fd.GetDependencies() { + md = r.findMessage(dep, msgName, checked) + if md != nil { + return md + } + } + return nil +} + +var _ jsonpb.AnyResolver = (*anyResolver)(nil) diff --git a/vendor/github.com/jhump/protoreflect/dynamic/maps_1.11.go b/vendor/github.com/jhump/protoreflect/dynamic/maps_1.11.go new file mode 100644 index 0000000..69969fc --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/maps_1.11.go @@ -0,0 +1,131 @@ +//go:build !go1.12 +// +build !go1.12 + +package dynamic + +import ( + "reflect" + + "github.com/jhump/protoreflect/desc" +) + +// Pre-Go-1.12, we must use reflect.Value.MapKeys to reflectively +// iterate a map. (We can be more efficient in Go 1.12 and up...) + +func mapsEqual(a, b reflect.Value) bool { + if a.Len() != b.Len() { + return false + } + if a.Len() == 0 && b.Len() == 0 { + // Optimize the case where maps are frequently empty because MapKeys() + // function allocates heavily. + return true + } + + for _, k := range a.MapKeys() { + av := a.MapIndex(k) + bv := b.MapIndex(k) + if !bv.IsValid() { + return false + } + if !fieldsEqual(av.Interface(), bv.Interface()) { + return false + } + } + return true +} + +func validFieldValueForMapField(fd *desc.FieldDescriptor, val reflect.Value) (interface{}, error) { + // make a defensive copy while we check the contents + // (also converts to map[interface{}]interface{} if it's some other type) + keyField := fd.GetMessageType().GetFields()[0] + valField := fd.GetMessageType().GetFields()[1] + m := map[interface{}]interface{}{} + for _, k := range val.MapKeys() { + if k.Kind() == reflect.Interface { + // unwrap it + k = reflect.ValueOf(k.Interface()) + } + kk, err := validElementFieldValueForRv(keyField, k, false) + if err != nil { + return nil, err + } + v := val.MapIndex(k) + if v.Kind() == reflect.Interface { + // unwrap it + v = reflect.ValueOf(v.Interface()) + } + vv, err := validElementFieldValueForRv(valField, v, true) + if err != nil { + return nil, err + } + m[kk] = vv + } + return m, nil +} + +func canConvertMap(src reflect.Value, target reflect.Type) bool { + kt := target.Key() + vt := target.Elem() + for _, k := range src.MapKeys() { + if !canConvert(k, kt) { + return false + } + if !canConvert(src.MapIndex(k), vt) { + return false + } + } + return true +} + +func mergeMapVal(src, target reflect.Value, targetType reflect.Type, deterministic bool) error { + tkt := targetType.Key() + tvt := targetType.Elem() + for _, k := range src.MapKeys() { + v := src.MapIndex(k) + skt := k.Type() + svt := v.Type() + var nk, nv reflect.Value + if tkt == skt { + nk = k + } else if tkt.Kind() == reflect.Ptr && tkt.Elem() == skt { + nk = k.Addr() + } else { + nk = reflect.New(tkt).Elem() + if err := mergeVal(k, nk, deterministic); err != nil { + return err + } + } + if tvt == svt { + nv = v + } else if tvt.Kind() == reflect.Ptr && tvt.Elem() == svt { + nv = v.Addr() + } else { + nv = reflect.New(tvt).Elem() + if err := mergeVal(v, nv, deterministic); err != nil { + return err + } + } + if target.IsNil() { + target.Set(reflect.MakeMap(targetType)) + } + target.SetMapIndex(nk, nv) + } + return nil +} + +func mergeMapField(m *Message, fd *desc.FieldDescriptor, rv reflect.Value) error { + for _, k := range rv.MapKeys() { + if k.Kind() == reflect.Interface && !k.IsNil() { + k = k.Elem() + } + v := rv.MapIndex(k) + if v.Kind() == reflect.Interface && !v.IsNil() { + v = v.Elem() + } + if err := m.putMapField(fd, k.Interface(), v.Interface()); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/maps_1.12.go b/vendor/github.com/jhump/protoreflect/dynamic/maps_1.12.go new file mode 100644 index 0000000..fb353cf --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/maps_1.12.go @@ -0,0 +1,139 @@ +//go:build go1.12 +// +build go1.12 + +package dynamic + +import ( + "reflect" + + "github.com/jhump/protoreflect/desc" +) + +// With Go 1.12 and above, we can use reflect.Value.MapRange to iterate +// over maps more efficiently than using reflect.Value.MapKeys. + +func mapsEqual(a, b reflect.Value) bool { + if a.Len() != b.Len() { + return false + } + if a.Len() == 0 && b.Len() == 0 { + // Optimize the case where maps are frequently empty + return true + } + + iter := a.MapRange() + for iter.Next() { + k := iter.Key() + av := iter.Value() + bv := b.MapIndex(k) + if !bv.IsValid() { + return false + } + if !fieldsEqual(av.Interface(), bv.Interface()) { + return false + } + } + return true +} + +func validFieldValueForMapField(fd *desc.FieldDescriptor, val reflect.Value) (interface{}, error) { + // make a defensive copy while we check the contents + // (also converts to map[interface{}]interface{} if it's some other type) + keyField := fd.GetMessageType().GetFields()[0] + valField := fd.GetMessageType().GetFields()[1] + m := map[interface{}]interface{}{} + iter := val.MapRange() + for iter.Next() { + k := iter.Key() + if k.Kind() == reflect.Interface { + // unwrap it + k = reflect.ValueOf(k.Interface()) + } + kk, err := validElementFieldValueForRv(keyField, k, false) + if err != nil { + return nil, err + } + v := iter.Value() + if v.Kind() == reflect.Interface { + // unwrap it + v = reflect.ValueOf(v.Interface()) + } + vv, err := validElementFieldValueForRv(valField, v, true) + if err != nil { + return nil, err + } + m[kk] = vv + } + return m, nil +} + +func canConvertMap(src reflect.Value, target reflect.Type) bool { + kt := target.Key() + vt := target.Elem() + iter := src.MapRange() + for iter.Next() { + if !canConvert(iter.Key(), kt) { + return false + } + if !canConvert(iter.Value(), vt) { + return false + } + } + return true +} + +func mergeMapVal(src, target reflect.Value, targetType reflect.Type, deterministic bool) error { + tkt := targetType.Key() + tvt := targetType.Elem() + iter := src.MapRange() + for iter.Next() { + k := iter.Key() + v := iter.Value() + skt := k.Type() + svt := v.Type() + var nk, nv reflect.Value + if tkt == skt { + nk = k + } else if tkt.Kind() == reflect.Ptr && tkt.Elem() == skt { + nk = k.Addr() + } else { + nk = reflect.New(tkt).Elem() + if err := mergeVal(k, nk, deterministic); err != nil { + return err + } + } + if tvt == svt { + nv = v + } else if tvt.Kind() == reflect.Ptr && tvt.Elem() == svt { + nv = v.Addr() + } else { + nv = reflect.New(tvt).Elem() + if err := mergeVal(v, nv, deterministic); err != nil { + return err + } + } + if target.IsNil() { + target.Set(reflect.MakeMap(targetType)) + } + target.SetMapIndex(nk, nv) + } + return nil +} + +func mergeMapField(m *Message, fd *desc.FieldDescriptor, rv reflect.Value) error { + iter := rv.MapRange() + for iter.Next() { + k := iter.Key() + v := iter.Value() + if k.Kind() == reflect.Interface && !k.IsNil() { + k = k.Elem() + } + if v.Kind() == reflect.Interface && !v.IsNil() { + v = v.Elem() + } + if err := m.putMapField(fd, k.Interface(), v.Interface()); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/merge.go b/vendor/github.com/jhump/protoreflect/dynamic/merge.go new file mode 100644 index 0000000..ce727fd --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/merge.go @@ -0,0 +1,100 @@ +package dynamic + +import ( + "errors" + "reflect" + + "github.com/golang/protobuf/proto" + + "github.com/jhump/protoreflect/desc" +) + +// Merge merges the given source message into the given destination message. Use +// use this instead of proto.Merge when one or both of the messages might be a +// a dynamic message. If there is a problem merging the messages, such as the +// two messages having different types, then this method will panic (just as +// proto.Merges does). +func Merge(dst, src proto.Message) { + if dm, ok := dst.(*Message); ok { + if err := dm.MergeFrom(src); err != nil { + panic(err.Error()) + } + } else if dm, ok := src.(*Message); ok { + if err := dm.MergeInto(dst); err != nil { + panic(err.Error()) + } + } else { + proto.Merge(dst, src) + } +} + +// TryMerge merges the given source message into the given destination message. +// You can use this instead of proto.Merge when one or both of the messages +// might be a dynamic message. Unlike proto.Merge, this method will return an +// error on failure instead of panic'ing. +func TryMerge(dst, src proto.Message) error { + if dm, ok := dst.(*Message); ok { + if err := dm.MergeFrom(src); err != nil { + return err + } + } else if dm, ok := src.(*Message); ok { + if err := dm.MergeInto(dst); err != nil { + return err + } + } else { + // proto.Merge panics on bad input, so we first verify + // inputs and return error instead of panic + out := reflect.ValueOf(dst) + if out.IsNil() { + return errors.New("proto: nil destination") + } + in := reflect.ValueOf(src) + if in.Type() != out.Type() { + return errors.New("proto: type mismatch") + } + proto.Merge(dst, src) + } + return nil +} + +func mergeField(m *Message, fd *desc.FieldDescriptor, val interface{}) error { + rv := reflect.ValueOf(val) + + if fd.IsMap() && rv.Kind() == reflect.Map { + return mergeMapField(m, fd, rv) + } + + if fd.IsRepeated() && rv.Kind() == reflect.Slice && rv.Type() != typeOfBytes { + for i := 0; i < rv.Len(); i++ { + e := rv.Index(i) + if e.Kind() == reflect.Interface && !e.IsNil() { + e = e.Elem() + } + if err := m.addRepeatedField(fd, e.Interface()); err != nil { + return err + } + } + return nil + } + + if fd.IsRepeated() { + return m.addRepeatedField(fd, val) + } else if fd.GetMessageType() == nil { + return m.setField(fd, val) + } + + // it's a message type, so we want to merge contents + var err error + if val, err = validFieldValue(fd, val); err != nil { + return err + } + + existing, _ := m.doGetField(fd, true) + if existing != nil && !reflect.ValueOf(existing).IsNil() { + return TryMerge(existing.(proto.Message), val.(proto.Message)) + } + + // no existing message, so just set field + m.internalSetField(fd, val) + return nil +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/message_factory.go b/vendor/github.com/jhump/protoreflect/dynamic/message_factory.go new file mode 100644 index 0000000..683e7b3 --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/message_factory.go @@ -0,0 +1,207 @@ +package dynamic + +import ( + "reflect" + "sync" + + "github.com/golang/protobuf/proto" + + "github.com/jhump/protoreflect/desc" +) + +// MessageFactory can be used to create new empty message objects. A default instance +// (without extension registry or known-type registry specified) will always return +// dynamic messages (e.g. type will be *dynamic.Message) except for "well-known" types. +// The well-known types include primitive wrapper types and a handful of other special +// types defined in standard protobuf definitions, like Any, Duration, and Timestamp. +type MessageFactory struct { + er *ExtensionRegistry + ktr *KnownTypeRegistry +} + +// NewMessageFactoryWithExtensionRegistry creates a new message factory where any +// dynamic messages produced will use the given extension registry to recognize and +// parse extension fields. +func NewMessageFactoryWithExtensionRegistry(er *ExtensionRegistry) *MessageFactory { + return NewMessageFactoryWithRegistries(er, nil) +} + +// NewMessageFactoryWithKnownTypeRegistry creates a new message factory where the +// known types, per the given registry, will be returned as normal protobuf messages +// (e.g. generated structs, instead of dynamic messages). +func NewMessageFactoryWithKnownTypeRegistry(ktr *KnownTypeRegistry) *MessageFactory { + return NewMessageFactoryWithRegistries(nil, ktr) +} + +// NewMessageFactoryWithDefaults creates a new message factory where all "default" types +// (those for which protoc-generated code is statically linked into the Go program) are +// known types. If any dynamic messages are produced, they will recognize and parse all +// "default" extension fields. This is the equivalent of: +// +// NewMessageFactoryWithRegistries( +// NewExtensionRegistryWithDefaults(), +// NewKnownTypeRegistryWithDefaults()) +func NewMessageFactoryWithDefaults() *MessageFactory { + return NewMessageFactoryWithRegistries(NewExtensionRegistryWithDefaults(), NewKnownTypeRegistryWithDefaults()) +} + +// NewMessageFactoryWithRegistries creates a new message factory with the given extension +// and known type registries. +func NewMessageFactoryWithRegistries(er *ExtensionRegistry, ktr *KnownTypeRegistry) *MessageFactory { + return &MessageFactory{ + er: er, + ktr: ktr, + } +} + +// NewMessage creates a new empty message that corresponds to the given descriptor. +// If the given descriptor describes a "known type" then that type is instantiated. +// Otherwise, an empty dynamic message is returned. +func (f *MessageFactory) NewMessage(md *desc.MessageDescriptor) proto.Message { + var ktr *KnownTypeRegistry + if f != nil { + ktr = f.ktr + } + if m := ktr.CreateIfKnown(md.GetFullyQualifiedName()); m != nil { + return m + } + return NewMessageWithMessageFactory(md, f) +} + +// NewDynamicMessage creates a new empty dynamic message that corresponds to the given +// descriptor. This is like f.NewMessage(md) except the known type registry is not +// consulted so the return value is always a dynamic message. +// +// This is also like dynamic.NewMessage(md) except that the returned message will use +// this factory when creating other messages, like during de-serialization of fields +// that are themselves message types. +func (f *MessageFactory) NewDynamicMessage(md *desc.MessageDescriptor) *Message { + return NewMessageWithMessageFactory(md, f) +} + +// GetKnownTypeRegistry returns the known type registry that this factory uses to +// instantiate known (e.g. generated) message types. +func (f *MessageFactory) GetKnownTypeRegistry() *KnownTypeRegistry { + if f == nil { + return nil + } + return f.ktr +} + +// GetExtensionRegistry returns the extension registry that this factory uses to +// create dynamic messages. The registry is used by dynamic messages to recognize +// and parse extension fields during de-serialization. +func (f *MessageFactory) GetExtensionRegistry() *ExtensionRegistry { + if f == nil { + return nil + } + return f.er +} + +type wkt interface { + XXX_WellKnownType() string +} + +var typeOfWkt = reflect.TypeOf((*wkt)(nil)).Elem() + +// KnownTypeRegistry is a registry of known message types, as identified by their +// fully-qualified name. A known message type is one for which a protoc-generated +// struct exists, so a dynamic message is not necessary to represent it. A +// MessageFactory uses a KnownTypeRegistry to decide whether to create a generated +// struct or a dynamic message. The zero-value registry (including the behavior of +// a nil pointer) only knows about the "well-known types" in protobuf. These +// include only the wrapper types and a handful of other special types like Any, +// Duration, and Timestamp. +type KnownTypeRegistry struct { + excludeWkt bool + includeDefault bool + mu sync.RWMutex + types map[string]reflect.Type +} + +// NewKnownTypeRegistryWithDefaults creates a new registry that knows about all +// "default" types (those for which protoc-generated code is statically linked +// into the Go program). +func NewKnownTypeRegistryWithDefaults() *KnownTypeRegistry { + return &KnownTypeRegistry{includeDefault: true} +} + +// NewKnownTypeRegistryWithoutWellKnownTypes creates a new registry that does *not* +// include the "well-known types" in protobuf. So even well-known types would be +// represented by a dynamic message. +func NewKnownTypeRegistryWithoutWellKnownTypes() *KnownTypeRegistry { + return &KnownTypeRegistry{excludeWkt: true} +} + +// AddKnownType adds the types of the given messages as known types. +func (r *KnownTypeRegistry) AddKnownType(kts ...proto.Message) { + r.mu.Lock() + defer r.mu.Unlock() + if r.types == nil { + r.types = map[string]reflect.Type{} + } + for _, kt := range kts { + r.types[proto.MessageName(kt)] = reflect.TypeOf(kt) + } +} + +// CreateIfKnown will construct an instance of the given message if it is a known type. +// If the given name is unknown, nil is returned. +func (r *KnownTypeRegistry) CreateIfKnown(messageName string) proto.Message { + msgType := r.GetKnownType(messageName) + if msgType == nil { + return nil + } + + if msgType.Kind() == reflect.Ptr { + return reflect.New(msgType.Elem()).Interface().(proto.Message) + } else { + return reflect.New(msgType).Elem().Interface().(proto.Message) + } +} + +func isWellKnownType(t reflect.Type) bool { + if t.Implements(typeOfWkt) { + return true + } + if msg, ok := reflect.Zero(t).Interface().(proto.Message); ok { + name := proto.MessageName(msg) + _, ok := wellKnownTypeNames[name] + return ok + } + return false +} + +// GetKnownType will return the reflect.Type for the given message name if it is +// known. If it is not known, nil is returned. +func (r *KnownTypeRegistry) GetKnownType(messageName string) reflect.Type { + if r == nil { + // a nil registry behaves the same as zero value instance: only know of well-known types + t := proto.MessageType(messageName) + if t != nil && isWellKnownType(t) { + return t + } + return nil + } + + if r.includeDefault { + t := proto.MessageType(messageName) + if t != nil && isMessage(t) { + return t + } + } else if !r.excludeWkt { + t := proto.MessageType(messageName) + if t != nil && isWellKnownType(t) { + return t + } + } + + r.mu.RLock() + defer r.mu.RUnlock() + return r.types[messageName] +} + +func isMessage(t reflect.Type) bool { + _, ok := reflect.Zero(t).Interface().(proto.Message) + return ok +} diff --git a/vendor/github.com/jhump/protoreflect/dynamic/text.go b/vendor/github.com/jhump/protoreflect/dynamic/text.go new file mode 100644 index 0000000..5680dc2 --- /dev/null +++ b/vendor/github.com/jhump/protoreflect/dynamic/text.go @@ -0,0 +1,1177 @@ +package dynamic + +// Marshalling and unmarshalling of dynamic messages to/from proto's standard text format + +import ( + "bytes" + "fmt" + "io" + "math" + "reflect" + "sort" + "strconv" + "strings" + "text/scanner" + "unicode" + + "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" + + "github.com/jhump/protoreflect/codec" + "github.com/jhump/protoreflect/desc" +) + +// MarshalText serializes this message to bytes in the standard text format, +// returning an error if the operation fails. The resulting bytes will be a +// valid UTF8 string. +// +// This method uses a compact form: no newlines, and spaces between field +// identifiers and values are elided. +func (m *Message) MarshalText() ([]byte, error) { + var b indentBuffer + b.indentCount = -1 // no indentation + if err := m.marshalText(&b); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// MarshalTextIndent serializes this message to bytes in the standard text +// format, returning an error if the operation fails. The resulting bytes will +// be a valid UTF8 string. +// +// This method uses a "pretty-printed" form, with each field on its own line and +// spaces between field identifiers and values. +func (m *Message) MarshalTextIndent() ([]byte, error) { + var b indentBuffer + b.indent = " " // TODO: option for indent? + if err := m.marshalText(&b); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +func (m *Message) marshalText(b *indentBuffer) error { + // TODO: option for emitting extended Any format? + first := true + // first the known fields + for _, tag := range m.knownFieldTags() { + itag := int32(tag) + v := m.values[itag] + fd := m.FindFieldDescriptor(itag) + if fd.IsMap() { + md := fd.GetMessageType() + kfd := md.FindFieldByNumber(1) + vfd := md.FindFieldByNumber(2) + mp := v.(map[interface{}]interface{}) + keys := make([]interface{}, 0, len(mp)) + for k := range mp { + keys = append(keys, k) + } + sort.Sort(sortable(keys)) + for _, mk := range keys { + mv := mp[mk] + err := b.maybeNext(&first) + if err != nil { + return err + } + err = marshalKnownFieldMapEntryText(b, fd, kfd, mk, vfd, mv) + if err != nil { + return err + } + } + } else if fd.IsRepeated() { + sl := v.([]interface{}) + for _, slv := range sl { + err := b.maybeNext(&first) + if err != nil { + return err + } + err = marshalKnownFieldText(b, fd, slv) + if err != nil { + return err + } + } + } else { + err := b.maybeNext(&first) + if err != nil { + return err + } + err = marshalKnownFieldText(b, fd, v) + if err != nil { + return err + } + } + } + // then the unknown fields + for _, tag := range m.unknownFieldTags() { + itag := int32(tag) + ufs := m.unknownFields[itag] + for _, uf := range ufs { + err := b.maybeNext(&first) + if err != nil { + return err + } + _, err = fmt.Fprintf(b, "%d", tag) + if err != nil { + return err + } + if uf.Encoding == proto.WireStartGroup { + err = b.WriteByte('{') + if err != nil { + return err + } + err = b.start() + if err != nil { + return err + } + in := codec.NewBuffer(uf.Contents) + err = marshalUnknownGroupText(b, in, true) + if err != nil { + return err + } + err = b.end() + if err != nil { + return err + } + err = b.WriteByte('}') + if err != nil { + return err + } + } else { + err = b.sep() + if err != nil { + return err + } + if uf.Encoding == proto.WireBytes { + err = writeString(b, string(uf.Contents)) + if err != nil { + return err + } + } else { + _, err = b.WriteString(strconv.FormatUint(uf.Value, 10)) + if err != nil { + return err + } + } + } + } + } + return nil +} + +func marshalKnownFieldMapEntryText(b *indentBuffer, fd *desc.FieldDescriptor, kfd *desc.FieldDescriptor, mk interface{}, vfd *desc.FieldDescriptor, mv interface{}) error { + var name string + if fd.IsExtension() { + name = fmt.Sprintf("[%s]", fd.GetFullyQualifiedName()) + } else { + name = fd.GetName() + } + _, err := b.WriteString(name) + if err != nil { + return err + } + err = b.sep() + if err != nil { + return err + } + + err = b.WriteByte('<') + if err != nil { + return err + } + err = b.start() + if err != nil { + return err + } + + err = marshalKnownFieldText(b, kfd, mk) + if err != nil { + return err + } + err = b.next() + if err != nil { + return err + } + if !isNil(mv) { + err = marshalKnownFieldText(b, vfd, mv) + if err != nil { + return err + } + } + + err = b.end() + if err != nil { + return err + } + return b.WriteByte('>') +} + +func marshalKnownFieldText(b *indentBuffer, fd *desc.FieldDescriptor, v interface{}) error { + group := fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_GROUP + if group { + var name string + if fd.IsExtension() { + name = fmt.Sprintf("[%s]", fd.GetMessageType().GetFullyQualifiedName()) + } else { + name = fd.GetMessageType().GetName() + } + _, err := b.WriteString(name) + if err != nil { + return err + } + } else { + var name string + if fd.IsExtension() { + name = fmt.Sprintf("[%s]", fd.GetFullyQualifiedName()) + } else { + name = fd.GetName() + } + _, err := b.WriteString(name) + if err != nil { + return err + } + err = b.sep() + if err != nil { + return err + } + } + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Int32, reflect.Int64: + ed := fd.GetEnumType() + if ed != nil { + n := int32(rv.Int()) + vd := ed.FindValueByNumber(n) + if vd == nil { + _, err := b.WriteString(strconv.FormatInt(rv.Int(), 10)) + return err + } else { + _, err := b.WriteString(vd.GetName()) + return err + } + } else { + _, err := b.WriteString(strconv.FormatInt(rv.Int(), 10)) + return err + } + case reflect.Uint32, reflect.Uint64: + _, err := b.WriteString(strconv.FormatUint(rv.Uint(), 10)) + return err + case reflect.Float32, reflect.Float64: + f := rv.Float() + var str string + if math.IsNaN(f) { + str = "nan" + } else if math.IsInf(f, 1) { + str = "inf" + } else if math.IsInf(f, -1) { + str = "-inf" + } else { + var bits int + if rv.Kind() == reflect.Float32 { + bits = 32 + } else { + bits = 64 + } + str = strconv.FormatFloat(rv.Float(), 'g', -1, bits) + } + _, err := b.WriteString(str) + return err + case reflect.Bool: + _, err := b.WriteString(strconv.FormatBool(rv.Bool())) + return err + case reflect.Slice: + return writeString(b, string(rv.Bytes())) + case reflect.String: + return writeString(b, rv.String()) + default: + var err error + if group { + err = b.WriteByte('{') + } else { + err = b.WriteByte('<') + } + if err != nil { + return err + } + err = b.start() + if err != nil { + return err + } + // must be a message + if dm, ok := v.(*Message); ok { + err = dm.marshalText(b) + if err != nil { + return err + } + } else { + err = proto.CompactText(b, v.(proto.Message)) + if err != nil { + return err + } + } + err = b.end() + if err != nil { + return err + } + if group { + return b.WriteByte('}') + } else { + return b.WriteByte('>') + } + } +} + +// writeString writes a string in the protocol buffer text format. +// It is similar to strconv.Quote except we don't use Go escape sequences, +// we treat the string as a byte sequence, and we use octal escapes. +// These differences are to maintain interoperability with the other +// languages' implementations of the text format. +func writeString(b *indentBuffer, s string) error { + // use WriteByte here to get any needed indent + if err := b.WriteByte('"'); err != nil { + return err + } + // Loop over the bytes, not the runes. + for i := 0; i < len(s); i++ { + var err error + // Divergence from C++: we don't escape apostrophes. + // There's no need to escape them, and the C++ parser + // copes with a naked apostrophe. + switch c := s[i]; c { + case '\n': + _, err = b.WriteString("\\n") + case '\r': + _, err = b.WriteString("\\r") + case '\t': + _, err = b.WriteString("\\t") + case '"': + _, err = b.WriteString("\\\"") + case '\\': + _, err = b.WriteString("\\\\") + default: + if c >= 0x20 && c < 0x7f { + err = b.WriteByte(c) + } else { + _, err = fmt.Fprintf(b, "\\%03o", c) + } + } + if err != nil { + return err + } + } + return b.WriteByte('"') +} + +func marshalUnknownGroupText(b *indentBuffer, in *codec.Buffer, topLevel bool) error { + first := true + for { + if in.EOF() { + if topLevel { + return nil + } + // this is a nested message: we are expecting an end-group tag, not EOF! + return io.ErrUnexpectedEOF + } + tag, wireType, err := in.DecodeTagAndWireType() + if err != nil { + return err + } + if wireType == proto.WireEndGroup { + return nil + } + err = b.maybeNext(&first) + if err != nil { + return err + } + _, err = fmt.Fprintf(b, "%d", tag) + if err != nil { + return err + } + if wireType == proto.WireStartGroup { + err = b.WriteByte('{') + if err != nil { + return err + } + err = b.start() + if err != nil { + return err + } + err = marshalUnknownGroupText(b, in, false) + if err != nil { + return err + } + err = b.end() + if err != nil { + return err + } + err = b.WriteByte('}') + if err != nil { + return err + } + continue + } else { + err = b.sep() + if err != nil { + return err + } + if wireType == proto.WireBytes { + contents, err := in.DecodeRawBytes(false) + if err != nil { + return err + } + err = writeString(b, string(contents)) + if err != nil { + return err + } + } else { + var v uint64 + switch wireType { + case proto.WireVarint: + v, err = in.DecodeVarint() + case proto.WireFixed32: + v, err = in.DecodeFixed32() + case proto.WireFixed64: + v, err = in.DecodeFixed64() + default: + return proto.ErrInternalBadWireType + } + if err != nil { + return err + } + _, err = b.WriteString(strconv.FormatUint(v, 10)) + if err != nil { + return err + } + } + } + } +} + +// UnmarshalText de-serializes the message that is present, in text format, in +// the given bytes into this message. It first resets the current message. It +// returns an error if the given bytes do not contain a valid encoding of this +// message type in the standard text format +func (m *Message) UnmarshalText(text []byte) error { + m.Reset() + if err := m.UnmarshalMergeText(text); err != nil { + return err + } + return m.Validate() +} + +// UnmarshalMergeText de-serializes the message that is present, in text format, +// in the given bytes into this message. Unlike UnmarshalText, it does not first +// reset the message, instead merging the data in the given bytes into the +// existing data in this message. +func (m *Message) UnmarshalMergeText(text []byte) error { + return m.unmarshalText(newReader(text), tokenEOF) +} + +func (m *Message) unmarshalText(tr *txtReader, end tokenType) error { + for { + tok := tr.next() + if tok.tokTyp == end { + return nil + } + if tok.tokTyp == tokenEOF { + return io.ErrUnexpectedEOF + } + var fd *desc.FieldDescriptor + var extendedAnyType *desc.MessageDescriptor + if tok.tokTyp == tokenInt { + // tag number (indicates unknown field) + tag, err := strconv.ParseInt(tok.val.(string), 10, 32) + if err != nil { + return err + } + itag := int32(tag) + fd = m.FindFieldDescriptor(itag) + if fd == nil { + // can't parse the value w/out field descriptor, so skip it + tok = tr.next() + if tok.tokTyp == tokenEOF { + return io.ErrUnexpectedEOF + } else if tok.tokTyp == tokenOpenBrace { + if err := skipMessageText(tr, true); err != nil { + return err + } + } else if tok.tokTyp == tokenColon { + if err := skipFieldValueText(tr); err != nil { + return err + } + } else { + return textError(tok, "Expecting a colon ':' or brace '{'; instead got %q", tok.txt) + } + tok = tr.peek() + if tok.tokTyp.IsSep() { + tr.next() // consume separator + } + continue + } + } else { + fieldName, err := unmarshalFieldNameText(tr, tok) + if err != nil { + return err + } + fd = m.FindFieldDescriptorByName(fieldName) + if fd == nil { + // See if it's a group name + for _, field := range m.md.GetFields() { + if field.GetType() == descriptorpb.FieldDescriptorProto_TYPE_GROUP && field.GetMessageType().GetName() == fieldName { + fd = field + break + } + } + if fd == nil { + // maybe this is an extended Any + if m.md.GetFullyQualifiedName() == "google.protobuf.Any" && fieldName[0] == '[' && strings.Contains(fieldName, "/") { + // strip surrounding "[" and "]" and extract type name from URL + typeUrl := fieldName[1 : len(fieldName)-1] + mname := typeUrl + if slash := strings.LastIndex(mname, "/"); slash >= 0 { + mname = mname[slash+1:] + } + // TODO: add a way to weave an AnyResolver to this point + extendedAnyType = findMessageDescriptor(mname, m.md.GetFile()) + if extendedAnyType == nil { + return textError(tok, "could not parse Any with unknown type URL %q", fieldName) + } + // field 1 is "type_url" + typeUrlField := m.md.FindFieldByNumber(1) + if err := m.TrySetField(typeUrlField, typeUrl); err != nil { + return err + } + } else { + // TODO: add a flag to just ignore unrecognized field names + return textError(tok, "%q is not a recognized field name of %q", fieldName, m.md.GetFullyQualifiedName()) + } + } + } + } + tok = tr.next() + if tok.tokTyp == tokenEOF { + return io.ErrUnexpectedEOF + } + if extendedAnyType != nil { + // consume optional colon; make sure this is a "start message" token + if tok.tokTyp == tokenColon { + tok = tr.next() + if tok.tokTyp == tokenEOF { + return io.ErrUnexpectedEOF + } + } + if tok.tokTyp.EndToken() == tokenError { + return textError(tok, "Expecting a '<' or '{'; instead got %q", tok.txt) + } + + // TODO: use mf.NewMessage and, if not a dynamic message, use proto.UnmarshalText to unmarshal it + g := m.mf.NewDynamicMessage(extendedAnyType) + if err := g.unmarshalText(tr, tok.tokTyp.EndToken()); err != nil { + return err + } + // now we marshal the message to bytes and store in the Any + b, err := g.Marshal() + if err != nil { + return err + } + // field 2 is "value" + anyValueField := m.md.FindFieldByNumber(2) + if err := m.TrySetField(anyValueField, b); err != nil { + return err + } + + } else if (fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_GROUP || + fd.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) && + tok.tokTyp.EndToken() != tokenError { + + // TODO: use mf.NewMessage and, if not a dynamic message, use proto.UnmarshalText to unmarshal it + g := m.mf.NewDynamicMessage(fd.GetMessageType()) + if err := g.unmarshalText(tr, tok.tokTyp.EndToken()); err != nil { + return err + } + if fd.IsRepeated() { + if err := m.TryAddRepeatedField(fd, g); err != nil { + return err + } + } else { + if err := m.TrySetField(fd, g); err != nil { + return err + } + } + } else { + if tok.tokTyp != tokenColon { + return textError(tok, "Expecting a colon ':'; instead got %q", tok.txt) + } + if err := m.unmarshalFieldValueText(fd, tr); err != nil { + return err + } + } + tok = tr.peek() + if tok.tokTyp.IsSep() { + tr.next() // consume separator + } + } +} +func findMessageDescriptor(name string, fd *desc.FileDescriptor) *desc.MessageDescriptor { + md := findMessageInTransitiveDeps(name, fd, map[*desc.FileDescriptor]struct{}{}) + if md == nil { + // couldn't find it; see if we have this message linked in + md, _ = desc.LoadMessageDescriptor(name) + } + return md +} + +func findMessageInTransitiveDeps(name string, fd *desc.FileDescriptor, seen map[*desc.FileDescriptor]struct{}) *desc.MessageDescriptor { + if _, ok := seen[fd]; ok { + // already checked this file + return nil + } + seen[fd] = struct{}{} + md := fd.FindMessage(name) + if md != nil { + return md + } + // not in this file so recursively search its deps + for _, dep := range fd.GetDependencies() { + md = findMessageInTransitiveDeps(name, dep, seen) + if md != nil { + return md + } + } + // couldn't find it + return nil +} + +func textError(tok *token, format string, args ...interface{}) error { + var msg string + if tok.tokTyp == tokenError { + msg = tok.val.(error).Error() + } else { + msg = fmt.Sprintf(format, args...) + } + return fmt.Errorf("line %d, col %d: %s", tok.pos.Line, tok.pos.Column, msg) +} + +type setFunction func(*Message, *desc.FieldDescriptor, interface{}) error + +func (m *Message) unmarshalFieldValueText(fd *desc.FieldDescriptor, tr *txtReader) error { + var set setFunction + if fd.IsRepeated() { + set = (*Message).addRepeatedField + } else { + set = mergeField + } + tok := tr.peek() + if tok.tokTyp == tokenOpenBracket { + tr.next() // consume tok + for { + if err := m.unmarshalFieldElementText(fd, tr, set); err != nil { + return err + } + tok = tr.peek() + if tok.tokTyp == tokenCloseBracket { + tr.next() // consume tok + return nil + } else if tok.tokTyp.IsSep() { + tr.next() // consume separator + } + } + } + return m.unmarshalFieldElementText(fd, tr, set) +} + +func (m *Message) unmarshalFieldElementText(fd *desc.FieldDescriptor, tr *txtReader, set setFunction) error { + tok := tr.next() + if tok.tokTyp == tokenEOF { + return io.ErrUnexpectedEOF + } + + var expected string + switch fd.GetType() { + case descriptorpb.FieldDescriptorProto_TYPE_BOOL: + if tok.tokTyp == tokenIdent { + if tok.val.(string) == "true" { + return set(m, fd, true) + } else if tok.val.(string) == "false" { + return set(m, fd, false) + } + } + expected = "boolean value" + case descriptorpb.FieldDescriptorProto_TYPE_BYTES: + if tok.tokTyp == tokenString { + return set(m, fd, []byte(tok.val.(string))) + } + expected = "bytes string value" + case descriptorpb.FieldDescriptorProto_TYPE_STRING: + if tok.tokTyp == tokenString { + return set(m, fd, tok.val) + } + expected = "string value" + case descriptorpb.FieldDescriptorProto_TYPE_FLOAT: + switch tok.tokTyp { + case tokenFloat: + return set(m, fd, float32(tok.val.(float64))) + case tokenInt: + if f, err := strconv.ParseFloat(tok.val.(string), 32); err != nil { + return err + } else { + return set(m, fd, float32(f)) + } + case tokenIdent: + ident := strings.ToLower(tok.val.(string)) + if ident == "inf" { + return set(m, fd, float32(math.Inf(1))) + } else if ident == "nan" { + return set(m, fd, float32(math.NaN())) + } + case tokenMinus: + peeked := tr.peek() + if peeked.tokTyp == tokenIdent { + ident := strings.ToLower(peeked.val.(string)) + if ident == "inf" { + tr.next() // consume peeked token + return set(m, fd, float32(math.Inf(-1))) + } + } + } + expected = "float value" + case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: + switch tok.tokTyp { + case tokenFloat: + return set(m, fd, tok.val) + case tokenInt: + if f, err := strconv.ParseFloat(tok.val.(string), 64); err != nil { + return err + } else { + return set(m, fd, f) + } + case tokenIdent: + ident := strings.ToLower(tok.val.(string)) + if ident == "inf" { + return set(m, fd, math.Inf(1)) + } else if ident == "nan" { + return set(m, fd, math.NaN()) + } + case tokenMinus: + peeked := tr.peek() + if peeked.tokTyp == tokenIdent { + ident := strings.ToLower(peeked.val.(string)) + if ident == "inf" { + tr.next() // consume peeked token + return set(m, fd, math.Inf(-1)) + } + } + } + expected = "float value" + case descriptorpb.FieldDescriptorProto_TYPE_INT32, + descriptorpb.FieldDescriptorProto_TYPE_SINT32, + descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: + if tok.tokTyp == tokenInt { + if i, err := strconv.ParseInt(tok.val.(string), 10, 32); err != nil { + return err + } else { + return set(m, fd, int32(i)) + } + } + expected = "int value" + case descriptorpb.FieldDescriptorProto_TYPE_INT64, + descriptorpb.FieldDescriptorProto_TYPE_SINT64, + descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: + if tok.tokTyp == tokenInt { + if i, err := strconv.ParseInt(tok.val.(string), 10, 64); err != nil { + return err + } else { + return set(m, fd, i) + } + } + expected = "int value" + case descriptorpb.FieldDescriptorProto_TYPE_UINT32, + descriptorpb.FieldDescriptorProto_TYPE_FIXED32: + if tok.tokTyp == tokenInt { + if i, err := strconv.ParseUint(tok.val.(string), 10, 32); err != nil { + return err + } else { + return set(m, fd, uint32(i)) + } + } + expected = "unsigned int value" + case descriptorpb.FieldDescriptorProto_TYPE_UINT64, + descriptorpb.FieldDescriptorProto_TYPE_FIXED64: + if tok.tokTyp == tokenInt { + if i, err := strconv.ParseUint(tok.val.(string), 10, 64); err != nil { + return err + } else { + return set(m, fd, i) + } + } + expected = "unsigned int value" + case descriptorpb.FieldDescriptorProto_TYPE_ENUM: + if tok.tokTyp == tokenIdent { + // TODO: add a flag to just ignore unrecognized enum value names? + vd := fd.GetEnumType().FindValueByName(tok.val.(string)) + if vd != nil { + return set(m, fd, vd.GetNumber()) + } + } else if tok.tokTyp == tokenInt { + if i, err := strconv.ParseInt(tok.val.(string), 10, 32); err != nil { + return err + } else { + return set(m, fd, int32(i)) + } + } + expected = fmt.Sprintf("enum %s value", fd.GetEnumType().GetFullyQualifiedName()) + case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, + descriptorpb.FieldDescriptorProto_TYPE_GROUP: + + endTok := tok.tokTyp.EndToken() + if endTok != tokenError { + dm := m.mf.NewDynamicMessage(fd.GetMessageType()) + if err := dm.unmarshalText(tr, endTok); err != nil { + return err + } + // TODO: ideally we would use mf.NewMessage and, if not a dynamic message, use + // proto package to unmarshal it. But the text parser isn't particularly amenable + // to that, so we instead convert a dynamic message to a generated one if the + // known-type registry knows about the generated type... + var ktr *KnownTypeRegistry + if m.mf != nil { + ktr = m.mf.ktr + } + pm := ktr.CreateIfKnown(fd.GetMessageType().GetFullyQualifiedName()) + if pm != nil { + if err := dm.ConvertTo(pm); err != nil { + return set(m, fd, pm) + } + } + return set(m, fd, dm) + } + expected = fmt.Sprintf("message %s value", fd.GetMessageType().GetFullyQualifiedName()) + default: + return fmt.Errorf("field %q of message %q has unrecognized type: %v", fd.GetFullyQualifiedName(), m.md.GetFullyQualifiedName(), fd.GetType()) + } + + // if we get here, token was wrong type; create error message + var article string + if strings.Contains("aieou", expected[0:1]) { + article = "an" + } else { + article = "a" + } + return textError(tok, "Expecting %s %s; got %q", article, expected, tok.txt) +} + +func unmarshalFieldNameText(tr *txtReader, tok *token) (string, error) { + if tok.tokTyp == tokenOpenBracket || tok.tokTyp == tokenOpenParen { + // extension name + var closeType tokenType + var closeChar string + if tok.tokTyp == tokenOpenBracket { + closeType = tokenCloseBracket + closeChar = "close bracket ']'" + } else { + closeType = tokenCloseParen + closeChar = "close paren ')'" + } + // must be followed by an identifier + idents := make([]string, 0, 1) + for { + tok = tr.next() + if tok.tokTyp == tokenEOF { + return "", io.ErrUnexpectedEOF + } else if tok.tokTyp != tokenIdent { + return "", textError(tok, "Expecting an identifier; instead got %q", tok.txt) + } + idents = append(idents, tok.val.(string)) + // and then close bracket/paren, or "/" to keep adding URL elements to name + tok = tr.next() + if tok.tokTyp == tokenEOF { + return "", io.ErrUnexpectedEOF + } else if tok.tokTyp == closeType { + break + } else if tok.tokTyp != tokenSlash { + return "", textError(tok, "Expecting a %s; instead got %q", closeChar, tok.txt) + } + } + return "[" + strings.Join(idents, "/") + "]", nil + } else if tok.tokTyp == tokenIdent { + // normal field name + return tok.val.(string), nil + } else { + return "", textError(tok, "Expecting an identifier or tag number; instead got %q", tok.txt) + } +} + +func skipFieldNameText(tr *txtReader) error { + tok := tr.next() + if tok.tokTyp == tokenEOF { + return io.ErrUnexpectedEOF + } else if tok.tokTyp == tokenInt || tok.tokTyp == tokenIdent { + return nil + } else { + _, err := unmarshalFieldNameText(tr, tok) + return err + } +} + +func skipFieldValueText(tr *txtReader) error { + tok := tr.peek() + if tok.tokTyp == tokenOpenBracket { + tr.next() // consume tok + for { + if err := skipFieldElementText(tr); err != nil { + return err + } + tok = tr.peek() + if tok.tokTyp == tokenCloseBracket { + tr.next() // consume tok + return nil + } else if tok.tokTyp.IsSep() { + tr.next() // consume separator + } + + } + } + return skipFieldElementText(tr) +} + +func skipFieldElementText(tr *txtReader) error { + tok := tr.next() + switch tok.tokTyp { + case tokenEOF: + return io.ErrUnexpectedEOF + case tokenInt, tokenFloat, tokenString, tokenIdent: + return nil + case tokenOpenAngle: + return skipMessageText(tr, false) + default: + return textError(tok, "Expecting an angle bracket '<' or a value; instead got %q", tok.txt) + } +} + +func skipMessageText(tr *txtReader, isGroup bool) error { + for { + tok := tr.peek() + if tok.tokTyp == tokenEOF { + return io.ErrUnexpectedEOF + } else if isGroup && tok.tokTyp == tokenCloseBrace { + return nil + } else if !isGroup && tok.tokTyp == tokenCloseAngle { + return nil + } + + // field name or tag + if err := skipFieldNameText(tr); err != nil { + return err + } + + // field value + tok = tr.next() + if tok.tokTyp == tokenEOF { + return io.ErrUnexpectedEOF + } else if tok.tokTyp == tokenOpenBrace { + if err := skipMessageText(tr, true); err != nil { + return err + } + } else if tok.tokTyp == tokenColon { + if err := skipFieldValueText(tr); err != nil { + return err + } + } else { + return textError(tok, "Expecting a colon ':' or brace '{'; instead got %q", tok.txt) + } + + tok = tr.peek() + if tok.tokTyp.IsSep() { + tr.next() // consume separator + } + } +} + +type tokenType int + +const ( + tokenError tokenType = iota + tokenEOF + tokenIdent + tokenString + tokenInt + tokenFloat + tokenColon + tokenComma + tokenSemiColon + tokenOpenBrace + tokenCloseBrace + tokenOpenBracket + tokenCloseBracket + tokenOpenAngle + tokenCloseAngle + tokenOpenParen + tokenCloseParen + tokenSlash + tokenMinus +) + +func (t tokenType) IsSep() bool { + return t == tokenComma || t == tokenSemiColon +} + +func (t tokenType) EndToken() tokenType { + switch t { + case tokenOpenAngle: + return tokenCloseAngle + case tokenOpenBrace: + return tokenCloseBrace + default: + return tokenError + } +} + +type token struct { + tokTyp tokenType + val interface{} + txt string + pos scanner.Position +} + +type txtReader struct { + scanner scanner.Scanner + peeked token + havePeeked bool +} + +func newReader(text []byte) *txtReader { + sc := scanner.Scanner{} + sc.Init(bytes.NewReader(text)) + sc.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats | scanner.ScanChars | + scanner.ScanStrings | scanner.ScanComments | scanner.SkipComments + // identifiers are same restrictions as Go identifiers, except we also allow dots since + // we accept fully-qualified names + sc.IsIdentRune = func(ch rune, i int) bool { + return ch == '_' || unicode.IsLetter(ch) || + (i > 0 && unicode.IsDigit(ch)) || + (i > 0 && ch == '.') + } + // ignore errors; we handle them if/when we see malformed tokens + sc.Error = func(s *scanner.Scanner, msg string) {} + return &txtReader{scanner: sc} +} + +func (p *txtReader) peek() *token { + if p.havePeeked { + return &p.peeked + } + t := p.scanner.Scan() + if t == scanner.EOF { + p.peeked.tokTyp = tokenEOF + p.peeked.val = nil + p.peeked.txt = "" + p.peeked.pos = p.scanner.Position + } else if err := p.processToken(t, p.scanner.TokenText(), p.scanner.Position); err != nil { + p.peeked.tokTyp = tokenError + p.peeked.val = err + } + p.havePeeked = true + return &p.peeked +} + +func (p *txtReader) processToken(t rune, text string, pos scanner.Position) error { + p.peeked.pos = pos + p.peeked.txt = text + switch t { + case scanner.Ident: + p.peeked.tokTyp = tokenIdent + p.peeked.val = text + case scanner.Int: + p.peeked.tokTyp = tokenInt + p.peeked.val = text // can't parse the number because we don't know if it's signed or unsigned + case scanner.Float: + p.peeked.tokTyp = tokenFloat + var err error + if p.peeked.val, err = strconv.ParseFloat(text, 64); err != nil { + return err + } + case scanner.Char, scanner.String: + p.peeked.tokTyp = tokenString + var err error + if p.peeked.val, err = strconv.Unquote(text); err != nil { + return err + } + case '-': // unary minus, for negative ints and floats + ch := p.scanner.Peek() + if ch < '0' || ch > '9' { + p.peeked.tokTyp = tokenMinus + p.peeked.val = '-' + } else { + t := p.scanner.Scan() + if t == scanner.EOF { + return io.ErrUnexpectedEOF + } else if t == scanner.Float { + p.peeked.tokTyp = tokenFloat + text += p.scanner.TokenText() + p.peeked.txt = text + var err error + if p.peeked.val, err = strconv.ParseFloat(text, 64); err != nil { + p.peeked.pos = p.scanner.Position + return err + } + } else if t == scanner.Int { + p.peeked.tokTyp = tokenInt + text += p.scanner.TokenText() + p.peeked.txt = text + p.peeked.val = text // can't parse the number because we don't know if it's signed or unsigned + } else { + p.peeked.pos = p.scanner.Position + return fmt.Errorf("expecting an int or float but got %q", p.scanner.TokenText()) + } + } + case ':': + p.peeked.tokTyp = tokenColon + p.peeked.val = ':' + case ',': + p.peeked.tokTyp = tokenComma + p.peeked.val = ',' + case ';': + p.peeked.tokTyp = tokenSemiColon + p.peeked.val = ';' + case '{': + p.peeked.tokTyp = tokenOpenBrace + p.peeked.val = '{' + case '}': + p.peeked.tokTyp = tokenCloseBrace + p.peeked.val = '}' + case '<': + p.peeked.tokTyp = tokenOpenAngle + p.peeked.val = '<' + case '>': + p.peeked.tokTyp = tokenCloseAngle + p.peeked.val = '>' + case '[': + p.peeked.tokTyp = tokenOpenBracket + p.peeked.val = '[' + case ']': + p.peeked.tokTyp = tokenCloseBracket + p.peeked.val = ']' + case '(': + p.peeked.tokTyp = tokenOpenParen + p.peeked.val = '(' + case ')': + p.peeked.tokTyp = tokenCloseParen + p.peeked.val = ')' + case '/': + // only allowed to separate URL components in expanded Any format + p.peeked.tokTyp = tokenSlash + p.peeked.val = '/' + default: + return fmt.Errorf("invalid character: %c", t) + } + return nil +} + +func (p *txtReader) next() *token { + t := p.peek() + if t.tokTyp != tokenEOF && t.tokTyp != tokenError { + p.havePeeked = false + } + return t +} |
