// blob: 5b971679d9a756ca43221efe8d1d854a73932b9b [file] [log] [blame]
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vom
import (
"io"
"os"
"reflect"
"v.io/v23/vdl"
"v.io/v23/verror"
)
// Error IDs for the failure modes of the decoder. Each is registered with
// verror under an ID prefixed by pkgPath, using the verror.NoRetry action.
// Where a message contains {3}, the caller supplies the offending value as
// the first extra argument to verror.New.
var (
// Decode was called with a nil interface{} or a nil *RawValue.
errDecodeNil = verror.Register(pkgPath+".errDecodeNil", verror.NoRetry, "{1:}{2:} vom: invalid decode into nil interface{}{:_}")
errDecodeNilRawValue = verror.Register(pkgPath+".errDecodeNilRawValue", verror.NoRetry, "{1:}{2:} vom: invalid decode into nil *RawValue{:_}")
errDecodeZeroTypeID = verror.Register(pkgPath+".errDecodeZeroTypeID", verror.NoRetry, "{1:}{2:} vom: zero type id{:_}")
// An enum label index or a struct/union field index read from the wire is
// not valid for the type being decoded.
errIndexOutOfRange = verror.Register(pkgPath+".errIndexOutOfRange", verror.NoRetry, "{1:}{2:} vom: index out of range{:_}")
errLeftOverBytes = verror.Register(pkgPath+".errLeftOverBytes", verror.NoRetry, "{1:}{2:} vom: {3} leftover bytes{:_}")
// A control byte other than the one(s) permitted at that position was read.
errUnexpectedControlByte = verror.Register(pkgPath+".errUnexpectedControlByte", verror.NoRetry, "{1:}{2:} vom: unexpected control byte {3}{:_}")
// Used as panic values when decodeValue / ignoreValue meet a kind that they
// have no case for.
errDecodeValueUnhandledType = verror.Register(pkgPath+".errDecodeValueUnhandledType", verror.NoRetry, "{1:}{2:} vom: decodeValue unhandled type {3}{:_}")
errIgnoreValueUnhandledType = verror.Register(pkgPath+".errIgnoreValueUnhandledType", verror.NoRetry, "{1:}{2:} vom: ignoreValue unhandled type {3}{:_}")
)
// Decoder manages the receipt and unmarshalling of typed values from the other
// side of a connection.
type Decoder struct {
// mr is the message reader from which value messages (and, via
// decodeWireType, type messages) are read.
mr *messageReader
// typeDec resolves the type ids found in messages to *vdl.Type.
typeDec *TypeDecoder
}
// teeDump wraps r so that every byte read through the returned reader is also
// written to a Dumper that prints to os.Stdout.
//
// It exists purely as a debugging aid. To dump formatted vom bytes to stdout,
// insert the following as the first line of NewDecoder:
//
//	r = teeDump(r)
func teeDump(r io.Reader) io.Reader {
	dumper := NewDumper(NewDumpWriter(os.Stdout))
	return io.TeeReader(r, dumper)
}
// NewDecoder returns a new Decoder that reads from the given reader. The
// Decoder understands all formats generated by the Encoder.
func NewDecoder(r io.Reader) *Decoder {
	// The TypeDecoder is private to this Decoder and is built on top of the
	// same message reader that value messages are read from. The reader is
	// handed two callbacks into the TypeDecoder: one to look up a type by id
	// and one to read a single type definition.
	mr := newMessageReader(newDecbuf(r))
	dec := &Decoder{
		mr:      mr,
		typeDec: newTypeDecoderInternal(mr),
	}
	mr.SetCallbacks(dec.typeDec.lookupType, dec.typeDec.readSingleType)
	return dec
}
// NewDecoderWithTypeDecoder returns a new Decoder that reads values from the
// given reader, while types are decoded separately through the given typeDec.
//
// The message reader is only given a type-lookup callback; no callback for
// reading type definitions is installed.
func NewDecoderWithTypeDecoder(r io.Reader, typeDec *TypeDecoder) *Decoder {
	dec := &Decoder{
		mr:      newMessageReader(newDecbuf(r)),
		typeDec: typeDec,
	}
	dec.mr.SetCallbacks(typeDec.lookupType, nil)
	return dec
}
// Decode reads the next value from the reader(s) and stores it in value v.
// The type of v need not exactly match the type of the originally encoded
// value; decoding succeeds as long as the values are compatible.
//
// Types that are special-cased, only for v:
//
//	*RawValue - Store raw (uninterpreted) bytes in v.
//
// Types that are special-cased, recursively throughout v:
//
//	*vdl.Value    - Decode into v.
//	reflect.Value - Decode into v, which must be settable.
//
// Decoding into a RawValue captures the value in a raw form, which may be
// subsequently passed to an Encoder for transcoding.
//
// Decode(nil) always returns an error, as does decoding into a nil *RawValue.
// Use Ignore() to ignore the next value.
func (d *Decoder) Decode(v interface{}) error {
	// Reject an untyped nil up front: there is nothing to decode into.
	if v == nil {
		return verror.New(errDecodeNil, nil)
	}
	// Raw values bypass the regular target-based decoding entirely.
	if raw, isRaw := v.(*RawValue); isRaw {
		if raw == nil {
			return verror.New(errDecodeNilRawValue, nil)
		}
		return d.decodeRaw(raw)
	}

	id, err := d.mr.StartValueMessage()
	if err != nil {
		return err
	}
	wireType, err := d.typeDec.lookupType(id)
	if err != nil {
		return err
	}
	if err = d.decodeValueMsg(wireType, v); err != nil {
		return err
	}
	return d.mr.EndMessage()
}
// Ignore skips over the next value from the reader without decoding it.
//
// It starts the next value message, resolves the value's type, computes the
// number of bytes the value occupies, skips that many bytes and ends the
// message.
func (d *Decoder) Ignore() error {
	id, err := d.mr.StartValueMessage()
	if err != nil {
		return err
	}
	wireType, err := d.typeDec.lookupType(id)
	if err != nil {
		return err
	}
	numBytes, err := d.decodeValueByteLen(wireType)
	if err != nil {
		return err
	}
	err = d.mr.Skip(numBytes)
	if err != nil {
		return err
	}
	return d.mr.EndMessage()
}
// decodeRaw reads the next value message into raw without interpreting the
// value: it records the value's type, the raw value bytes, and the types
// referenced by the message (resolved through the type decoder).
func (d *Decoder) decodeRaw(raw *RawValue) error {
	id, err := d.mr.StartValueMessage()
	if err != nil {
		return err
	}
	raw.t, err = d.typeDec.lookupType(id)
	if err != nil {
		return err
	}
	raw.value, err = d.mr.ReadAllValueBytes()
	if err != nil {
		return err
	}

	// Reuse the existing refTypes backing array when it is large enough;
	// every slot is overwritten below.
	if want := len(d.mr.AllReferencedTypes()); want <= cap(raw.refTypes) {
		raw.refTypes = raw.refTypes[:want]
	} else {
		raw.refTypes = make([]*vdl.Type, want)
	}
	for pos, refID := range d.mr.AllReferencedTypes() {
		raw.refTypes[pos], err = d.typeDec.lookupType(refID)
		if err != nil {
			return err
		}
	}
	return d.mr.EndMessage()
}
// decodeWireType reads the next type definition message into wt and returns
// the type id carried by that message.
func (d *Decoder) decodeWireType(wt *wireType) (typeId, error) {
	id, err := d.mr.StartTypeMessage()
	if err != nil {
		return 0, err
	}
	// A wire type is itself encoded like any other value, so the regular
	// value decoding path is used with wireTypeType as the expected type.
	err = d.decodeValueMsg(wireTypeType, wt)
	if err != nil {
		return 0, err
	}
	return id, d.mr.EndMessage()
}
// decodeValueByteLen returns the byte length of the next value, which is of
// type tt. It is used by Ignore to skip a whole value without decoding it.
//
// For types with an explicit chunk length the length is taken from the
// message; this is not yet implemented for Version81, in which case an error
// is returned. For all other types the length is computed from the type and,
// where needed, by peeking at the length header or uint at the front of the
// buffer.
func (d *Decoder) decodeValueByteLen(tt *vdl.Type) (int, error) {
	if hasChunkLen(tt) {
		// Use the explicit message length.
		if d.mr.version == Version81 {
			// TODO(bprosnitz) Implement this for version 81
			//
			// Report the gap as an error rather than panicking, so that a
			// stream in this version cannot crash the process.
			return 0, verror.New(errUnsupportedInVOMVersion, nil, "ignoring values with a chunk length", d.mr.version)
		}
		return d.mr.buf.lim, nil
	}
	// No explicit message length, but the length can be computed.
	switch {
	case tt.Kind() == vdl.Byte && d.mr.version == Version80:
		// Only Version80 encodes a byte as exactly 1 raw byte. In other
		// versions decodeValue and ignoreValue read a byte as a uint, so the
		// length is determined by the default case below.
		return 1, nil
	case tt.Kind() == vdl.Array && tt.IsBytes():
		// Byte arrays are exactly their length and encoded with 1-byte header.
		return tt.Len() + 1, nil
	case tt.Kind() == vdl.String || tt.IsBytes():
		// Strings and byte lists are encoded with a length header; the total is
		// the payload length plus the size of the header itself.
		strlen, bytelen, err := binaryPeekUint(d.mr)
		switch {
		case err != nil:
			return 0, err
		case strlen > maxBinaryMsgLen:
			return 0, verror.New(errMsgLen, nil)
		}
		return int(strlen) + bytelen, nil
	default:
		// Must be a primitive, which is encoded as an underlying uint.
		return binaryPeekUintByteLen(d.mr)
	}
}
// decodeValueMsg decodes the rest of the message, assuming the encoded value
// has type tt, into the Go value v. It wraps v in a reflection-based
// vdl.Target and hands off to decodeValue.
func (d *Decoder) decodeValueMsg(tt *vdl.Type, v interface{}) error {
	rv := reflect.ValueOf(v)
	target, err := vdl.ReflectTarget(rv)
	if err != nil {
		return err
	}
	return d.decodeValue(tt, target)
}
// decodeValue decodes the rest of the message assuming type tt, feeding the
// decoded data into target.
//
// The original type (before any optional is unwrapped) is kept in ttFrom and
// is what gets reported to target as the source type. Composite kinds recurse
// into decodeValue for their elements, keys and fields; fields or map entries
// that the target reports as vdl.ErrFieldNoExist are skipped with ignoreValue.
// Struct fields that were not present on the wire are filled in with their
// zero values once the end control byte is seen.
//
// It returns any error from reading the wire or from target, and panics with
// errDecodeValueUnhandledType if tt has a kind with no case below.
func (d *Decoder) decodeValue(tt *vdl.Type, target vdl.Target) error {
	ttFrom := tt
	if tt.Kind() == vdl.Optional {
		// If the type is optional, we expect to see either WireCtrlNil or the actual
		// value, but not both. And thus, we can just peek for the WireCtrlNil here.
		switch ctrl, err := binaryPeekControl(d.mr); {
		case err != nil:
			return err
		case ctrl == WireCtrlNil:
			// Consume the control byte that was just peeked; a failure to do so
			// must be reported rather than silently dropped.
			if err := d.mr.Skip(1); err != nil {
				return err
			}
			return target.FromNil(ttFrom)
		}
		tt = tt.Elem()
	}
	if tt.IsBytes() {
		n, err := binaryDecodeLenOrArrayLen(d.mr, tt)
		if err != nil {
			return err
		}
		// TODO(toddw): remove allocation
		buf := make([]byte, n)
		if err := d.mr.ReadIntoBuf(buf); err != nil {
			return err
		}
		return target.FromBytes(buf, ttFrom)
	}
	switch kind := tt.Kind(); kind {
	case vdl.Bool:
		v, err := binaryDecodeBool(d.mr)
		if err != nil {
			return err
		}
		return target.FromBool(v, ttFrom)
	case vdl.Byte, vdl.Uint16, vdl.Uint32, vdl.Uint64:
		var v uint64
		if kind == vdl.Byte && d.mr.version == Version80 {
			// Version80 encodes a byte as a single raw byte rather than a uint.
			b, err := d.mr.ReadByte()
			if err != nil {
				return err
			}
			v = uint64(b)
		} else {
			var err error
			v, err = binaryDecodeUint(d.mr)
			if err != nil {
				return err
			}
		}
		return target.FromUint(v, ttFrom)
	case vdl.Int8, vdl.Int16, vdl.Int32, vdl.Int64:
		if d.mr.version == Version80 && kind == vdl.Int8 {
			return verror.New(errUnsupportedInVOMVersion, nil, "int8", d.mr.version)
		}
		v, err := binaryDecodeInt(d.mr)
		if err != nil {
			return err
		}
		return target.FromInt(v, ttFrom)
	case vdl.Float32, vdl.Float64:
		v, err := binaryDecodeFloat(d.mr)
		if err != nil {
			return err
		}
		return target.FromFloat(v, ttFrom)
	case vdl.Complex64, vdl.Complex128:
		// A complex value is the real part followed by the imaginary part.
		re, err := binaryDecodeFloat(d.mr)
		if err != nil {
			return err
		}
		im, err := binaryDecodeFloat(d.mr)
		if err != nil {
			return err
		}
		return target.FromComplex(complex(re, im), ttFrom)
	case vdl.String:
		v, err := binaryDecodeString(d.mr)
		if err != nil {
			return err
		}
		return target.FromString(v, ttFrom)
	case vdl.Enum:
		index, err := binaryDecodeUint(d.mr)
		switch {
		case err != nil:
			return err
		case index >= uint64(tt.NumEnumLabel()):
			return verror.New(errIndexOutOfRange, nil)
		}
		return target.FromEnumLabel(tt.EnumLabel(int(index)), ttFrom)
	case vdl.TypeObject:
		x, err := binaryDecodeUint(d.mr)
		if err != nil {
			return err
		}
		// In Version80 x is a type id; otherwise it is resolved through the
		// message reader's ReferencedType.
		var typeobject *vdl.Type
		if d.mr.version == Version80 {
			typeobject, err = d.typeDec.lookupType(typeId(x))
		} else {
			typeobject, err = d.mr.ReferencedType(x)
		}
		if err != nil {
			return err
		}
		return target.FromTypeObject(typeobject)
	case vdl.Array, vdl.List:
		n, err := binaryDecodeLenOrArrayLen(d.mr, tt)
		if err != nil {
			return err
		}
		listTarget, err := target.StartList(ttFrom, n)
		if err != nil {
			return err
		}
		for ix := 0; ix < n; ix++ {
			elem, err := listTarget.StartElem(ix)
			if err != nil {
				return err
			}
			if err := d.decodeValue(tt.Elem(), elem); err != nil {
				return err
			}
			if err := listTarget.FinishElem(elem); err != nil {
				return err
			}
		}
		return target.FinishList(listTarget)
	case vdl.Set:
		n, err := binaryDecodeLen(d.mr)
		if err != nil {
			return err
		}
		setTarget, err := target.StartSet(ttFrom, n)
		if err != nil {
			return err
		}
		for ix := 0; ix < n; ix++ {
			key, err := setTarget.StartKey()
			if err != nil {
				return err
			}
			if err := d.decodeValue(tt.Key(), key); err != nil {
				return err
			}
			switch err := setTarget.FinishKey(key); {
			case err == vdl.ErrFieldNoExist:
				// The target has no place for this key; drop it and move on.
				continue
			case err != nil:
				return err
			}
		}
		return target.FinishSet(setTarget)
	case vdl.Map:
		n, err := binaryDecodeLen(d.mr)
		if err != nil {
			return err
		}
		mapTarget, err := target.StartMap(ttFrom, n)
		if err != nil {
			return err
		}
		for ix := 0; ix < n; ix++ {
			key, err := mapTarget.StartKey()
			if err != nil {
				return err
			}
			if err := d.decodeValue(tt.Key(), key); err != nil {
				return err
			}
			switch field, err := mapTarget.FinishKeyStartField(key); {
			case err == vdl.ErrFieldNoExist:
				// The target has no place for this entry; the encoded element
				// must still be consumed to stay in sync with the stream.
				if err := d.ignoreValue(tt.Elem()); err != nil {
					return err
				}
			case err != nil:
				return err
			default:
				if err := d.decodeValue(tt.Elem(), field); err != nil {
					return err
				}
				if err := mapTarget.FinishField(key, field); err != nil {
					return err
				}
			}
		}
		return target.FinishMap(mapTarget)
	case vdl.Struct:
		fieldsTarget, err := target.StartFields(ttFrom)
		if err != nil {
			return err
		}
		// Loop through decoding the 0-based field index and corresponding field.
		decodedFields := make([]bool, tt.NumField())
		for {
			index, ctrl, err := binaryDecodeUintWithControl(d.mr)
			switch {
			case err != nil:
				return err
			case ctrl == WireCtrlEnd:
				// Fill not-yet-decoded fields with their zero values.
				for fieldIx, decoded := range decodedFields {
					if decoded {
						continue
					}
					ttfield := tt.Field(fieldIx)
					switch key, field, err := fieldsTarget.StartField(ttfield.Name); {
					case err == vdl.ErrFieldNoExist:
						// Ignore it.
					case err != nil:
						return err
					default:
						if err := vdl.FromValue(field, vdl.ZeroValue(ttfield.Type)); err != nil {
							return err
						}
						if err := fieldsTarget.FinishField(key, field); err != nil {
							return err
						}
					}
				}
				return target.FinishFields(fieldsTarget)
			case ctrl != 0:
				return verror.New(errUnexpectedControlByte, nil, ctrl)
			case index >= uint64(tt.NumField()):
				return verror.New(errIndexOutOfRange, nil)
			}
			ttfield := tt.Field(int(index))
			switch key, field, err := fieldsTarget.StartField(ttfield.Name); {
			case err == vdl.ErrFieldNoExist:
				// The target does not have this field; consume and discard it.
				if err := d.ignoreValue(ttfield.Type); err != nil {
					return err
				}
			case err != nil:
				return err
			default:
				if err := d.decodeValue(ttfield.Type, field); err != nil {
					return err
				}
				if err := fieldsTarget.FinishField(key, field); err != nil {
					return err
				}
			}
			decodedFields[index] = true
		}
	case vdl.Union:
		fieldsTarget, err := target.StartFields(ttFrom)
		if err != nil {
			return err
		}
		// A union is encoded as the index of its single set field followed by
		// that field's value.
		index, err := binaryDecodeUint(d.mr)
		switch {
		case err != nil:
			return err
		case index >= uint64(tt.NumField()):
			return verror.New(errIndexOutOfRange, nil)
		}
		ttfield := tt.Field(int(index))
		key, field, err := fieldsTarget.StartField(ttfield.Name)
		if err != nil {
			return err
		}
		if err := d.decodeValue(ttfield.Type, field); err != nil {
			return err
		}
		if err := fieldsTarget.FinishField(key, field); err != nil {
			return err
		}
		return target.FinishFields(fieldsTarget)
	case vdl.Any:
		// An any value is either WireCtrlNil or a reference to the concrete
		// element type followed by a value of that type.
		var elemType *vdl.Type
		switch x, ctrl, err := binaryDecodeUintWithControl(d.mr); {
		case err != nil:
			return err
		case ctrl == WireCtrlNil:
			return target.FromNil(tt)
		case ctrl != 0:
			return verror.New(errUnexpectedControlByte, nil, ctrl)
		case d.mr.version == Version80:
			if elemType, err = d.typeDec.lookupType(typeId(x)); err != nil {
				return err
			}
		default:
			if elemType, err = d.mr.ReferencedType(x); err != nil {
				return err
			}
		}
		return d.decodeValue(elemType, target)
	default:
		panic(verror.New(errDecodeValueUnhandledType, nil, tt))
	}
}
// ignoreValue ignores the rest of the value of type tt. This is used to ignore
// unknown struct fields and map entries that the decode target cannot hold.
//
// It consumes exactly the bytes that decodeValue would have consumed for a
// value of type tt, recursing into elements, keys and fields of composite
// kinds. An optional value is either a WireCtrlNil control byte or the
// encoding of its element type, mirroring decodeValue.
//
// It returns any error from reading the wire, and panics with
// errIgnoreValueUnhandledType if tt has a kind with no case below.
func (d *Decoder) ignoreValue(tt *vdl.Type) error {
	if tt.Kind() == vdl.Optional {
		// An optional is encoded as either WireCtrlNil or the actual value, but
		// not both, so peeking for WireCtrlNil is enough to tell them apart.
		switch ctrl, err := binaryPeekControl(d.mr); {
		case err != nil:
			return err
		case ctrl == WireCtrlNil:
			// Consume the nil control byte; there is nothing else to skip.
			return d.mr.Skip(1)
		}
		tt = tt.Elem()
	}
	if tt.IsBytes() {
		n, err := binaryDecodeLenOrArrayLen(d.mr, tt)
		if err != nil {
			return err
		}
		return d.mr.Skip(n)
	}
	switch kind := tt.Kind(); kind {
	case vdl.Bool:
		return d.mr.Skip(1)
	case vdl.Byte:
		// Version80 encodes a byte as a single raw byte rather than a uint.
		if d.mr.version == Version80 {
			return d.mr.Skip(1)
		}
		return binaryIgnoreUint(d.mr)
	case vdl.Uint16, vdl.Uint32, vdl.Uint64, vdl.Int8, vdl.Int16, vdl.Int32, vdl.Int64, vdl.Float32, vdl.Float64, vdl.Enum, vdl.TypeObject:
		if d.mr.version == Version80 && kind == vdl.Int8 {
			return verror.New(errUnsupportedInVOMVersion, nil, "int8", d.mr.version)
		}
		// The underlying encoding of all these types is based on uint.
		return binaryIgnoreUint(d.mr)
	case vdl.Complex64, vdl.Complex128:
		// Complex is encoded as two floats, so we can simply ignore two uints.
		if err := binaryIgnoreUint(d.mr); err != nil {
			return err
		}
		return binaryIgnoreUint(d.mr)
	case vdl.String:
		return binaryIgnoreString(d.mr)
	case vdl.Array, vdl.List, vdl.Set, vdl.Map:
		n, err := binaryDecodeLenOrArrayLen(d.mr, tt)
		if err != nil {
			return err
		}
		// Sets and maps carry a key per entry; arrays, lists and maps carry an
		// element per entry.
		hasKey := kind == vdl.Set || kind == vdl.Map
		hasElem := kind == vdl.Array || kind == vdl.List || kind == vdl.Map
		for ix := 0; ix < n; ix++ {
			if hasKey {
				if err := d.ignoreValue(tt.Key()); err != nil {
					return err
				}
			}
			if hasElem {
				if err := d.ignoreValue(tt.Elem()); err != nil {
					return err
				}
			}
		}
		return nil
	case vdl.Struct:
		// Loop through decoding the 0-based field index and corresponding field.
		for {
			switch index, ctrl, err := binaryDecodeUintWithControl(d.mr); {
			case err != nil:
				return err
			case ctrl == WireCtrlEnd:
				return nil
			case ctrl != 0:
				return verror.New(errUnexpectedControlByte, nil, ctrl)
			case index >= uint64(tt.NumField()):
				return verror.New(errIndexOutOfRange, nil)
			default:
				ttfield := tt.Field(int(index))
				if err := d.ignoreValue(ttfield.Type); err != nil {
					return err
				}
			}
		}
	case vdl.Union:
		// A union is the index of its single set field followed by that field.
		switch index, err := binaryDecodeUint(d.mr); {
		case err != nil:
			return err
		case index >= uint64(tt.NumField()):
			return verror.New(errIndexOutOfRange, nil)
		default:
			ttfield := tt.Field(int(index))
			return d.ignoreValue(ttfield.Type)
		}
	case vdl.Any:
		// An any value is either WireCtrlNil or a reference to the concrete
		// element type followed by a value of that type.
		var elemType *vdl.Type
		switch x, ctrl, err := binaryDecodeUintWithControl(d.mr); {
		case err != nil:
			return err
		case ctrl == WireCtrlNil:
			return nil
		case ctrl != 0:
			return verror.New(errUnexpectedControlByte, nil, ctrl)
		case d.mr.version == Version80:
			if elemType, err = d.typeDec.lookupType(typeId(x)); err != nil {
				return err
			}
		default:
			if elemType, err = d.mr.ReferencedType(x); err != nil {
				return err
			}
		}
		return d.ignoreValue(elemType)
	default:
		panic(verror.New(errIgnoreValueUnhandledType, nil, tt))
	}
}