blob: 8ce9dcec52e315e1089b0d40b323644d1f38f92a [file] [log] [blame]
package vom
import (
"errors"
"fmt"
"reflect"
"v.io/v23/vdl"
)
var (
errDecodeZeroTypeID = errors.New("vom: type ID 0")
errIndexOutOfRange = errors.New("vom: index out of range")
)
type binaryDecoder struct {
buf *decbuf
recvTypes *decoderTypes
}
func newBinaryDecoder(buf *decbuf, types *decoderTypes) *binaryDecoder {
return &binaryDecoder{buf, types}
}
func (d *binaryDecoder) Decode(target vdl.Target) error {
valType, err := d.decodeValueType()
if err != nil {
return err
}
return d.decodeValueMsg(valType, target)
}
func (d *binaryDecoder) DecodeRaw(raw *RawValue) error {
valType, err := d.decodeValueType()
if err != nil {
return err
}
valLen, err := d.decodeValueByteLen(valType)
if err != nil {
return err
}
raw.recvTypes = d.recvTypes
raw.valType = valType
if cap(raw.data) >= valLen {
raw.data = raw.data[:valLen]
} else {
raw.data = make([]byte, valLen)
}
return d.buf.ReadFull(raw.data)
}
func (d *binaryDecoder) Ignore() error {
valType, err := d.decodeValueType()
if err != nil {
return err
}
valLen, err := d.decodeValueByteLen(valType)
if err != nil {
return err
}
return d.buf.Skip(valLen)
}
// decodeValueType returns the type of the next value message. Any type
// definition messages it encounters along the way are decoded and added to
// recvTypes.
func (d *binaryDecoder) decodeValueType() (*vdl.Type, error) {
for {
id, err := binaryDecodeInt(d.buf)
if err != nil {
return nil, err
}
switch {
case id == 0:
return nil, errDecodeZeroTypeID
case id > 0:
// This is a value message, the typeID is +id.
tid := typeID(+id)
tt, err := d.recvTypes.LookupOrBuildType(tid)
if err != nil {
return nil, err
}
return tt, nil
}
// This is a type message, the typeID is -id.
tid := typeID(-id)
// Decode the wireType like a regular value, and store it in recvTypes. The
// type will actually be built when a value message arrives using this tid.
var wt wireType
target, err := vdl.ReflectTarget(reflect.ValueOf(&wt))
if err != nil {
return nil, err
}
if err := d.decodeValueMsg(wireTypeType, target); err != nil {
return nil, err
}
if err := d.recvTypes.AddWireType(tid, wt); err != nil {
return nil, err
}
}
}
// decodeValueByteLen returns the byte length of the next value.
func (d *binaryDecoder) decodeValueByteLen(tt *vdl.Type) (int, error) {
if hasBinaryMsgLen(tt) {
// Use the explicit message length.
msgLen, err := binaryDecodeLen(d.buf)
if err != nil {
return 0, err
}
return msgLen, nil
}
// No explicit message length, but the length can be computed.
switch {
case tt.Kind() == vdl.Byte:
// Single byte is always encoded as 1 byte.
return 1, nil
case tt.Kind() == vdl.Array && tt.IsBytes():
// Byte arrays are exactly their length and encoded with 1-byte header.
return tt.Len() + 1, nil
case tt.Kind() == vdl.String || tt.IsBytes():
// Strings and byte lists are encoded with a length header.
strlen, bytelen, err := binaryPeekUint(d.buf)
switch {
case err != nil:
return 0, err
case strlen > maxBinaryMsgLen:
return 0, errMsgLen
}
return int(strlen) + bytelen, nil
default:
// Must be a primitive, which is encoded as an underlying uint.
return binaryPeekUintByteLen(d.buf)
}
}
// decodeValueMsg decodes the rest of the message assuming type t, handling the
// optional message length.
func (d *binaryDecoder) decodeValueMsg(tt *vdl.Type, target vdl.Target) error {
if hasBinaryMsgLen(tt) {
msgLen, err := binaryDecodeLen(d.buf)
if err != nil {
return err
}
d.buf.SetLimit(msgLen)
}
err := d.decodeValue(tt, target)
leftover := d.buf.RemoveLimit()
switch {
case err != nil:
return err
case leftover > 0:
return fmt.Errorf("vom: %d leftover bytes", leftover)
}
return nil
}
// decodeValue decodes the rest of the message assuming type tt.
func (d *binaryDecoder) decodeValue(tt *vdl.Type, target vdl.Target) error {
ttFrom := tt
if tt.Kind() == vdl.Optional {
// If the type is optional, we expect to see either WireCtrlNil or the actual
// value, but not both. And thus, we can just peek for the WireCtrlNil here.
switch ctrl, err := binaryPeekControl(d.buf); {
case err != nil:
return err
case ctrl == WireCtrlNil:
d.buf.Skip(1)
return target.FromNil(ttFrom)
}
tt = tt.Elem()
}
if tt.IsBytes() {
len, err := binaryDecodeLenOrArrayLen(d.buf, tt)
if err != nil {
return err
}
bytes, err := d.buf.ReadBuf(len)
if err != nil {
return err
}
return target.FromBytes(bytes, ttFrom)
}
switch kind := tt.Kind(); kind {
case vdl.Bool:
v, err := binaryDecodeBool(d.buf)
if err != nil {
return err
}
return target.FromBool(v, ttFrom)
case vdl.Byte:
v, err := d.buf.ReadByte()
if err != nil {
return err
}
return target.FromUint(uint64(v), ttFrom)
case vdl.Uint16, vdl.Uint32, vdl.Uint64:
v, err := binaryDecodeUint(d.buf)
if err != nil {
return err
}
return target.FromUint(v, ttFrom)
case vdl.Int16, vdl.Int32, vdl.Int64:
v, err := binaryDecodeInt(d.buf)
if err != nil {
return err
}
return target.FromInt(v, ttFrom)
case vdl.Float32, vdl.Float64:
v, err := binaryDecodeFloat(d.buf)
if err != nil {
return err
}
return target.FromFloat(v, ttFrom)
case vdl.Complex64, vdl.Complex128:
re, err := binaryDecodeFloat(d.buf)
if err != nil {
return err
}
im, err := binaryDecodeFloat(d.buf)
if err != nil {
return err
}
return target.FromComplex(complex(re, im), ttFrom)
case vdl.String:
v, err := binaryDecodeString(d.buf)
if err != nil {
return err
}
return target.FromString(v, ttFrom)
case vdl.Enum:
index, err := binaryDecodeUint(d.buf)
switch {
case err != nil:
return err
case index >= uint64(tt.NumEnumLabel()):
return errIndexOutOfRange
}
return target.FromEnumLabel(tt.EnumLabel(int(index)), ttFrom)
case vdl.TypeObject:
id, err := binaryDecodeUint(d.buf)
if err != nil {
return err
}
typeobj, err := d.recvTypes.LookupOrBuildType(typeID(id))
if err != nil {
return err
}
return target.FromTypeObject(typeobj)
case vdl.Array, vdl.List:
len, err := binaryDecodeLenOrArrayLen(d.buf, tt)
if err != nil {
return err
}
listTarget, err := target.StartList(ttFrom, len)
if err != nil {
return err
}
for ix := 0; ix < len; ix++ {
elem, err := listTarget.StartElem(ix)
if err != nil {
return err
}
if err := d.decodeValue(tt.Elem(), elem); err != nil {
return err
}
if err := listTarget.FinishElem(elem); err != nil {
return err
}
}
return target.FinishList(listTarget)
case vdl.Set:
len, err := binaryDecodeLen(d.buf)
if err != nil {
return err
}
setTarget, err := target.StartSet(ttFrom, len)
if err != nil {
return err
}
for ix := 0; ix < len; ix++ {
key, err := setTarget.StartKey()
if err != nil {
return err
}
if err := d.decodeValue(tt.Key(), key); err != nil {
return err
}
switch err := setTarget.FinishKey(key); {
case err == vdl.ErrFieldNoExist:
continue
case err != nil:
return err
}
}
return target.FinishSet(setTarget)
case vdl.Map:
len, err := binaryDecodeLen(d.buf)
if err != nil {
return err
}
mapTarget, err := target.StartMap(ttFrom, len)
if err != nil {
return err
}
for ix := 0; ix < len; ix++ {
key, err := mapTarget.StartKey()
if err != nil {
return err
}
if err := d.decodeValue(tt.Key(), key); err != nil {
return err
}
switch field, err := mapTarget.FinishKeyStartField(key); {
case err == vdl.ErrFieldNoExist:
if err := d.ignoreValue(tt.Elem()); err != nil {
return err
}
case err != nil:
return err
default:
if err := d.decodeValue(tt.Elem(), field); err != nil {
return err
}
if err := mapTarget.FinishField(key, field); err != nil {
return err
}
}
}
return target.FinishMap(mapTarget)
case vdl.Struct:
fieldsTarget, err := target.StartFields(ttFrom)
if err != nil {
return err
}
// Loop through decoding the 0-based field index and corresponding field.
decodedFields := make([]bool, tt.NumField())
for {
index, ctrl, err := binaryDecodeUintWithControl(d.buf)
switch {
case err != nil:
return err
case ctrl == WireCtrlEOF:
// Fill not-yet-decoded fields with their zero values.
for index, decoded := range decodedFields {
if decoded {
continue
}
ttfield := tt.Field(index)
switch key, field, err := fieldsTarget.StartField(ttfield.Name); {
case err == vdl.ErrFieldNoExist:
// Ignore it.
case err != nil:
return err
default:
if err := vdl.FromValue(field, vdl.ZeroValue(ttfield.Type)); err != nil {
return err
}
if err := fieldsTarget.FinishField(key, field); err != nil {
return err
}
}
}
return target.FinishFields(fieldsTarget)
case ctrl != 0:
return fmt.Errorf("vom: unexpected control byte 0x%x", ctrl)
case index >= uint64(tt.NumField()):
return errIndexOutOfRange
}
ttfield := tt.Field(int(index))
switch key, field, err := fieldsTarget.StartField(ttfield.Name); {
case err == vdl.ErrFieldNoExist:
if err := d.ignoreValue(ttfield.Type); err != nil {
return err
}
case err != nil:
return err
default:
if err := d.decodeValue(ttfield.Type, field); err != nil {
return err
}
if err := fieldsTarget.FinishField(key, field); err != nil {
return err
}
}
decodedFields[index] = true
}
case vdl.Union:
fieldsTarget, err := target.StartFields(ttFrom)
if err != nil {
return err
}
index, err := binaryDecodeUint(d.buf)
switch {
case err != nil:
return err
case index >= uint64(tt.NumField()):
return errIndexOutOfRange
}
ttfield := tt.Field(int(index))
key, field, err := fieldsTarget.StartField(ttfield.Name)
if err != nil {
return err
}
if err := d.decodeValue(ttfield.Type, field); err != nil {
return err
}
if err := fieldsTarget.FinishField(key, field); err != nil {
return err
}
return target.FinishFields(fieldsTarget)
case vdl.Any:
switch id, ctrl, err := binaryDecodeUintWithControl(d.buf); {
case err != nil:
return err
case ctrl == WireCtrlNil:
return target.FromNil(vdl.AnyType)
case ctrl != 0:
return fmt.Errorf("vom: unexpected control byte 0x%x", ctrl)
default:
elemType, err := d.recvTypes.LookupOrBuildType(typeID(id))
if err != nil {
return err
}
return d.decodeValue(elemType, target)
}
default:
panic(fmt.Errorf("vom: decodeValue unhandled type %v", tt))
}
}
// ignoreValue ignores the rest of the value of type t. This is used to ignore
// unknown struct fields.
func (d *binaryDecoder) ignoreValue(tt *vdl.Type) error {
if tt.IsBytes() {
len, err := binaryDecodeLenOrArrayLen(d.buf, tt)
if err != nil {
return err
}
return d.buf.Skip(len)
}
switch kind := tt.Kind(); kind {
case vdl.Bool, vdl.Byte:
return d.buf.Skip(1)
case vdl.Uint16, vdl.Uint32, vdl.Uint64, vdl.Int16, vdl.Int32, vdl.Int64, vdl.Float32, vdl.Float64, vdl.Enum, vdl.TypeObject:
// The underlying encoding of all these types is based on uint.
return binaryIgnoreUint(d.buf)
case vdl.Complex64, vdl.Complex128:
// Complex is encoded as two floats, so we can simply ignore two uints.
if err := binaryIgnoreUint(d.buf); err != nil {
return err
}
return binaryIgnoreUint(d.buf)
case vdl.String:
return binaryIgnoreString(d.buf)
case vdl.Array, vdl.List, vdl.Set, vdl.Map:
len, err := binaryDecodeLenOrArrayLen(d.buf, tt)
if err != nil {
return err
}
for ix := 0; ix < len; ix++ {
if kind == vdl.Set || kind == vdl.Map {
if err := d.ignoreValue(tt.Key()); err != nil {
return err
}
}
if kind == vdl.Array || kind == vdl.List || kind == vdl.Map {
if err := d.ignoreValue(tt.Elem()); err != nil {
return err
}
}
}
return nil
case vdl.Struct:
// Loop through decoding the 0-based field index and corresponding field.
for {
switch index, ctrl, err := binaryDecodeUintWithControl(d.buf); {
case err != nil:
return err
case ctrl == WireCtrlEOF:
return nil
case ctrl != 0:
return fmt.Errorf("vom: unexpected control byte 0x%x", ctrl)
case index >= uint64(tt.NumField()):
return errIndexOutOfRange
default:
ttfield := tt.Field(int(index))
if err := d.ignoreValue(ttfield.Type); err != nil {
return err
}
}
}
case vdl.Union:
switch index, err := binaryDecodeUint(d.buf); {
case err != nil:
return err
case index >= uint64(tt.NumField()):
return errIndexOutOfRange
default:
ttfield := tt.Field(int(index))
return d.ignoreValue(ttfield.Type)
}
case vdl.Any:
switch id, ctrl, err := binaryDecodeUintWithControl(d.buf); {
case err != nil:
return err
case ctrl == WireCtrlNil:
return nil
case ctrl != 0:
return fmt.Errorf("vom: unexpected control byte 0x%x", ctrl)
default:
elemType, err := d.recvTypes.LookupOrBuildType(typeID(id))
if err != nil {
return err
}
return d.ignoreValue(elemType)
}
default:
panic(fmt.Errorf("vom: ignoreValue unhandled type %v", tt))
}
}