blob: 4cece9afabd07451fab65d63ff4bafbe90fbbc98 [file] [log] [blame]
// Copyright 2016 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vdl
import (
"errors"
"fmt"
"reflect"
)
// Sentinel errors produced by Read, ReadReflect and their helpers.
var (
	// errReadIntoNilValue is returned when the destination is nil or an
	// invalid reflect.Value.
	errReadIntoNilValue = errors.New("vdl: read into nil value")
	// errReadReflectCantSet is returned when the destination can't be set,
	// even after dereferencing a single non-nil pointer.
	errReadReflectCantSet = errors.New("vdl: read into unsettable reflect.Value")
	// errReadAnyAlreadyStarted guards the invariant that readIntoAny is never
	// entered after StartValue has been called.
	errReadAnyAlreadyStarted = errors.New("vdl: read into any after StartValue called")
	// errReadAnyInterfaceOnly is returned when a VDL any value is decoded into
	// a Go value that is neither a fast-path type nor an interface.
	errReadAnyInterfaceOnly = errors.New("vdl: read into any only supported for interfaces")
	// errReadMustReflect is an internal signal from readNonReflect: no
	// non-reflect fast path applies, and the caller must use reflection.
	errReadMustReflect = errors.New("vdl: read must be handled via reflection")
)
// Read decodes a value from dec into v. It prefers VDLRead methods and fast
// compiled readers, falling back to reflection when neither applies, which
// makes it a general-purpose VDLRead implementation.
func Read(dec Decoder, v interface{}) error {
	if v == nil {
		return errReadIntoNilValue
	}
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		// Non-pointers and nil pointers skip the fast path; ReadReflect rejects
		// them with an error. Detecting the nil pointer case is the reason
		// reflection is needed even before trying the fast path.
		return ReadReflect(dec, rv)
	}
	// TODO(toddw): If reflection is too slow, add the nil pointer check to all
	// VDLRead methods, as well as other readNonReflect cases.
	if err := readNonReflect(dec, false, v); err != errReadMustReflect {
		return err
	}
	return ReadReflect(dec, rv)
}
func readNonReflect(dec Decoder, calledStart bool, v interface{}) error {
switch x := v.(type) {
case Reader:
// Reader handles the case where x has a code-generated decoder, and
// special-cases such as vdl.Value and vom.RawBytes.
if calledStart {
dec.IgnoreNextStartValue()
}
return x.VDLRead(dec)
case **Type:
// Special-case type decoding, since we must assign the hash-consed pointer
// for correctness, rather than filling in a newly-created Type.
if !calledStart {
if err := dec.StartValue(TypeObjectType); err != nil {
return err
}
}
var err error
if *x, err = dec.DecodeTypeObject(); err != nil {
return err
}
return dec.FinishValue()
// Cases after this point are purely performance optimizations.
// TODO(toddw): Handle other common cases.
case *[]byte:
if !calledStart {
if err := dec.StartValue(ttByteList); err != nil {
return err
}
}
if err := dec.DecodeBytes(-1, x); err != nil {
return err
}
return dec.FinishValue()
}
return errReadMustReflect
}
var ttByteList = ListType(ByteType)
// ReadReflect is like Read, but takes a reflect.Value argument. rv must be
// settable, or be a non-nil pointer whose pointee is settable.
func ReadReflect(dec Decoder, rv reflect.Value) error {
	if !rv.IsValid() {
		return errReadIntoNilValue
	}
	if !rv.CanSet() {
		// Values such as reflect.ValueOf(&x) aren't settable themselves, but
		// dereferencing the non-nil pointer once may yield a settable value.
		if rv.Kind() == reflect.Ptr && !rv.IsNil() {
			rv = rv.Elem()
		}
		if !rv.CanSet() {
			return errReadReflectCantSet
		}
	}
	tt, err := TypeFromReflect(rv.Type())
	if err != nil {
		return err
	}
	return readReflect(dec, false, rv, tt)
}
// readReflect decodes a value from dec into rv, whose VDL type is tt. On
// success, both StartValue and FinishValue are guaranteed to have been called
// on dec. calledStart reports whether the caller already called StartValue.
func readReflect(dec Decoder, calledStart bool, rv reflect.Value, tt *Type) error {
	// An any-typed rv is dispatched before StartValue, because
	// vom.RawBytes.VDLRead doesn't support IgnoreNextStartValue and requires
	// that StartValue hasn't been called yet. A dec value of type any read
	// into a non-any rv is not caught here and continues below.
	if tt == AnyType {
		return readIntoAny(dec, calledStart, rv)
	}
	if !calledStart {
		if err := dec.StartValue(tt); err != nil {
			return err
		}
	}
	// Nil values, either any(nil) or optional(nil), are handled up front so
	// the rest of this function only sees non-nil values.
	if dec.IsNil() {
		return readFromNil(dec, rv, tt)
	}
	rv = rvFlattenPointers(rv)
	if err := readNonReflect(dec, true, rv.Addr().Interface()); err != errReadMustReflect {
		return err
	}
	ni := nativeInfoFromNative(rv.Type())
	if ni == nil {
		// Regular non-nil wire value.
		if err := readNonNilValue(dec, rv, tt.NonOptional()); err != nil {
			return err
		}
		return dec.FinishValue()
	}
	// Native types are decoded into their wire type and then converted with
	// ToNative. rv is never a pointer here, so native pointer types aren't
	// supported; they would be complicated and likely slow everything down.
	//
	// TODO(toddw): Investigate support for native pointer types.
	rvWire := reflect.New(ni.WireType).Elem()
	if err := readReflect(dec, true, rvWire, tt); err != nil {
		return err
	}
	// The recursive readReflect call has already called FinishValue.
	return ni.ToNative(rvWire, rv.Addr())
}
// readIntoAny uses dec to decode a value into rv, which has VDL type any.
//
// It must be entered before StartValue has been called for this value
// (calledStart == false); otherwise it returns errReadAnyAlreadyStarted.
// After rvFlattenPointers is applied to rv, types with a non-reflect reader
// (see readNonReflect) are decoded directly. Otherwise rv must be a Go
// interface, or errReadAnyInterfaceOnly is returned. The concrete Go type
// stored into the interface is chosen from the type reported by dec.
func readIntoAny(dec Decoder, calledStart bool, rv reflect.Value) error {
	if calledStart {
		// The existing code ensures that calledStart is always false here, since
		// readReflect(dec, true, ...) is only called in situations where it's
		// impossible to call readIntoAny. E.g. it's called later in this function,
		// which never calls it with another any type. If we did, we'd have a vdl
		// any(any), which isn't allowed. This error tries to prevent future
		// changes that will break this requirement.
		//
		// The requirement is mandated by vom.RawBytes.VDLRead, which doesn't handle
		// IgnoreNextStartValue.
		return errReadAnyAlreadyStarted
	}
	// Flatten pointers and check for fast non-reflect support, which handles
	// vdl.Value and vom.RawBytes, and any other special-cases. calledStart is
	// passed as false, so the fast path starts the value itself.
	rv = rvFlattenPointers(rv)
	if err := readNonReflect(dec, false, rv.Addr().Interface()); err != errReadMustReflect {
		return err
	}
	// The only case left is to handle interfaces. We allow decoding into
	// all interfaces, including interface{}.
	if rv.Kind() != reflect.Interface {
		return errReadAnyInterfaceOnly
	}
	// Only now is the value started; every return below this point either
	// finishes it (directly or via readReflect) or reports an error.
	if err := dec.StartValue(AnyType); err != nil {
		return err
	}
	// Handle decoding any(nil) by setting the rv interface to nil. Note that the
	// only case where dec.Type() is AnyType is when the value is any(nil).
	if dec.Type() == AnyType {
		if !rv.IsNil() {
			rv.Set(reflect.Zero(rv.Type()))
		}
		return dec.FinishValue()
	}
	// Look up the reflect type based on the decoder type, and create a new value
	// to decode into. If the dec value is optional, ensure that we look up based
	// on an optional type. Note that if the dec value is nil, dec.Type() is
	// already optional, so rtDecode will already be a pointer.
	ttDecode := dec.Type()
	if dec.IsOptional() && !dec.IsNil() {
		ttDecode = OptionalType(ttDecode)
	}
	// rtDecode may be nil here if no Go type is known for ttDecode; that case
	// is reported below, after the error-type override gets a chance to run.
	rtDecode := typeToReflectNew(ttDecode)
	// Handle top-level "v.io/v23/vdl.WireError" types. TypeToReflect will find
	// vdl.WireError based on regular wire type registration, and will find the Go
	// error interface based on regular native type registration, and these are
	// fine for nested error types.
	//
	// But this is the case where we're decoding into a top-level Go interface,
	// and we'll lose type information if the dec value is nil. So instead we
	// return the registered verror.E type. Examples:
	//
	// ttDecode -> rtDecode
	// -----------------------
	// WireError verror.E
	// ?WireError *verror.E
	// []WireError []vdl.WireError (1)
	// []?WireError []error
	//
	// TODO(toddw): The (1) case above is weird; we would like to return verror.E,
	// but that's hard because the native conversion we've registered doesn't
	// currently include the verror.E type:
	//
	// ToNative(wire *vdl.WireError, native *error)
	// FromNative(wire **vdl.WireError, native error)
	//
	// We could make this more consistent by registering a pair of conversion
	// functions instead:
	//
	// ToNative(wire vdl.WireError, native *verror.E)
	// FromNative(wire *vdl.WireError, native verror.E)
	//
	// ToNative(wire *verror.E, native *error)
	// FromNative(wire **verror.E, native error)
	if ttDecode.NonOptional().Name() == ErrorType.Elem().Name() {
		// If no native error info is available, the error is deliberately
		// ignored and the rtDecode found above is kept.
		if ni, err := nativeInfoForError(); err == nil {
			if ttDecode.Kind() == Optional {
				rtDecode = reflect.PtrTo(ni.NativeType)
			} else {
				rtDecode = ni.NativeType
			}
		}
	}
	if rtDecode == nil {
		return fmt.Errorf("vdl: %v not registered, either call vdl.Register, or use vdl.Value or vom.RawBytes instead", dec.Type())
	}
	// The chosen concrete type must be storable in the destination interface.
	if !rtDecode.Implements(rv.Type()) {
		return fmt.Errorf("vdl: %v doesn't implement %v", rtDecode, rv.Type())
	}
	// Handle both nil and non-nil values by decoding into rvDecode, and setting
	// rv. Both nil and non-nil values are handled in the readReflect call.
	rvDecode := reflect.New(rtDecode).Elem()
	if err := readReflect(dec, true, rvDecode, ttDecode); err != nil {
		return err
	}
	rv.Set(rvDecode)
	// NOTE: readReflect guarantees that FinishValue has already been called.
	return nil
}
// readFromNil decodes the nil value currently held by dec, which is either
// any(nil) or optional(nil), into rv, whose VDL type is tt.
//
// REQUIRES: dec.IsNil() && tt != AnyType
func readFromNil(dec Decoder, rv reflect.Value, tt *Type) error {
	if tt.Kind() != Optional {
		return fmt.Errorf("vdl: can't decode nil into non-optional %v", tt)
	}
	// Walk down pointer-to-pointer chains, allocating intermediate pointers
	// when nil, until rv has at most one level of pointer left.
	for rv.Kind() == reflect.Ptr && rv.Type().Elem().Kind() == reflect.Ptr {
		if rv.IsNil() {
			rv.Set(reflect.New(rv.Type().Elem()))
		}
		rv = rv.Elem()
	}
	// Native types (e.g. the Go error interface) need special handling.
	if handled, err := readFromNilNative(dec, rv, tt); handled {
		return err
	}
	// Simple case: rv is a single pointer, which is set to nil.
	rt := rv.Type()
	if rt.Kind() != reflect.Ptr {
		return fmt.Errorf("vdl: can't decode nil into non-pointer %v optional %v", rt, tt)
	}
	if !rv.IsNil() {
		rv.Set(reflect.Zero(rt))
	}
	return dec.FinishValue()
}
// readFromNilNative handles the cases where rv is, or points to, a native
// type. The bool result reports whether rv was such a type and was therefore
// handled; the error is the outcome of handling it.
//
// REQUIRES: rv.Type() has at most one pointer.
func readFromNilNative(dec Decoder, rv reflect.Value, tt *Type) (bool, error) {
	rt := rv.Type()
	var ni *nativeInfo
	switch rt.Kind() {
	case reflect.Ptr:
		// rv points to a native type, e.g. *error. Native pointer types aren't
		// supported, so only the elem type is checked. A nil pointer is
		// allocated so the native value has somewhere to live.
		if ni = nativeInfoFromNative(rt.Elem()); ni != nil {
			if rv.IsNil() {
				rv.Set(reflect.New(rt.Elem()))
			}
			rv = rv.Elem()
		}
	default:
		// rv is itself a native type, e.g. the Go error interface.
		ni = nativeInfoFromNative(rt)
	}
	if ni == nil {
		return false, nil
	}
	// rv is now the native value itself and isn't a nil pointer. Decode the
	// nil into the wire type, then convert; readReflect has already called
	// FinishValue by the time ToNative runs.
	rvWire := reflect.New(ni.WireType).Elem()
	if err := readReflect(dec, true, rvWire, tt); err != nil {
		return true, err
	}
	return true, ni.ToNative(rvWire, rv.Addr())
}
// settable lets readNonNilValue invoke an enum's Set method through a type
// assertion rather than via reflect.Call(), which results in an allocation.
// readNonNilValue expects a pointer to each Go enum value to implement it.
type settable interface {
	Set(string) error
}

// readNonNilValue decodes the non-nil value already started on dec into rv,
// whose non-optional VDL type is tt. It does not call FinishValue for the
// value; its caller in this file, readReflect, does that afterwards.
func readNonNilValue(dec Decoder, rv reflect.Value, tt *Type) error {
	// Handle named and unnamed []byte and [N]byte, where the element type is the
	// unnamed byte type. Cases like []MyByte fall through and are handled as
	// regular lists, since we can't easily convert []MyByte to []byte.
	switch {
	case tt.Kind() == Array && tt.Elem() == ByteType:
		// Slicing the addressable array yields a []byte that shares its
		// storage, so the bytes land directly in rv (assuming DecodeBytes
		// fills the given slice in place for a fixed length; confirm).
		bytes := rv.Slice(0, tt.Len()).Interface().([]byte)
		return dec.DecodeBytes(tt.Len(), &bytes)
	case tt.Kind() == List && tt.Elem() == ByteType:
		var bytes []byte
		if err := dec.DecodeBytes(-1, &bytes); err != nil {
			return err
		}
		// []byte is unnamed, so it's assignable to named byte-list types too.
		rv.Set(reflect.ValueOf(bytes))
		return nil
	}
	// Handle regular non-nil values.
	switch kind := tt.Kind(); kind {
	case Bool:
		val, err := dec.DecodeBool()
		if err != nil {
			return err
		}
		rv.SetBool(val)
		return nil
	case String:
		val, err := dec.DecodeString()
		if err != nil {
			return err
		}
		rv.SetString(val)
		return nil
	case Enum:
		val, err := dec.DecodeString()
		if err != nil {
			return err
		}
		// Use a checked assertion: a Go enum type without Set(string) error
		// on its pointer is reported as an error instead of panicking.
		enum, ok := rv.Addr().Interface().(settable)
		if !ok {
			return fmt.Errorf("vdl: enum %v doesn't implement Set(string) error", rv.Type())
		}
		return enum.Set(val)
	case Byte, Uint16, Uint32, Uint64:
		val, err := dec.DecodeUint(kind.BitLen())
		if err != nil {
			return err
		}
		rv.SetUint(val)
		return nil
	case Int8, Int16, Int32, Int64:
		val, err := dec.DecodeInt(kind.BitLen())
		if err != nil {
			return err
		}
		rv.SetInt(val)
		return nil
	case Float32, Float64:
		val, err := dec.DecodeFloat(kind.BitLen())
		if err != nil {
			return err
		}
		rv.SetFloat(val)
		return nil
	case Array:
		return readArray(dec, rv, tt)
	case List:
		return readList(dec, rv, tt)
	case Set:
		return readSet(dec, rv, tt)
	case Map:
		return readMap(dec, rv, tt)
	case Struct:
		return readStruct(dec, rv, tt)
	case Union:
		return readUnion(dec, rv, tt)
	}
	// Note that Any was already handled via readIntoAny, Optional was handled
	// via readFromNil (or stripped off for non-nil values), and TypeObject was
	// handled via the readNonReflect special-case.
	return fmt.Errorf("vdl: Read unhandled type %v %v", rv.Type(), tt)
}
// readArray decodes list entries from dec into the Go array rv, whose VDL type
// is tt. The number of decoded entries must equal the array length exactly.
func readArray(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	n := rv.Len()
	for index := 0; ; index++ {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		// The input must end exactly when the array is full.
		if done != (index >= n) {
			return fmt.Errorf("array len mismatch, done:%v index:%d len:%d %v", done, index, rt.Len(), rt)
		}
		if done {
			return nil
		}
		if err := readReflect(dec, false, rv.Index(index), tt.Elem()); err != nil {
			return err
		}
	}
}
// readList decodes list entries from dec into the Go slice rv, whose VDL type
// is tt. rv is reset first: pre-sized from dec's length hint when it's
// positive, otherwise set to a nil slice.
func readList(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	if hint := dec.LenHint(); hint > 0 {
		rv.Set(reflect.MakeSlice(rt, 0, hint))
	} else {
		rv.Set(reflect.Zero(rt))
	}
	for {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		if done {
			return nil
		}
		item := reflect.New(rt.Elem()).Elem()
		if err := readReflect(dec, false, item, tt.Elem()); err != nil {
			return err
		}
		rv.Set(reflect.Append(rv, item))
	}
}
// rvEmptyStruct is the struct{}{} value stored for every key of a decoded set.
var rvEmptyStruct = reflect.ValueOf(struct{}{})

// readSet decodes set keys from dec into the Go map rv, whose VDL type is tt.
// The map is built separately and assigned to rv only once decoding finishes;
// an empty set sets rv to a nil map.
func readSet(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	result := reflect.Zero(rt)
	for {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		if done {
			rv.Set(result)
			return nil
		}
		key := reflect.New(rt.Key()).Elem()
		if err := readReflect(dec, false, key, tt.Key()); err != nil {
			return err
		}
		// Allocate the map lazily, when the first key arrives.
		if result.IsNil() {
			result = reflect.MakeMap(rt)
		}
		result.SetMapIndex(key, rvEmptyStruct)
	}
}
// readMap decodes key/elem pairs from dec into the Go map rv, whose VDL type
// is tt. The map is built separately and assigned to rv only once decoding
// finishes; an empty map sets rv to a nil map.
func readMap(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	result := reflect.Zero(rt)
	for {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		if done {
			rv.Set(result)
			return nil
		}
		key := reflect.New(rt.Key()).Elem()
		if err := readReflect(dec, false, key, tt.Key()); err != nil {
			return err
		}
		val := reflect.New(rt.Elem()).Elem()
		if err := readReflect(dec, false, val, tt.Elem()); err != nil {
			return err
		}
		// Allocate the map lazily, when the first entry arrives.
		if result.IsNil() {
			result = reflect.MakeMap(rt)
		}
		result.SetMapIndex(key, val)
	}
}
// readStruct decodes struct fields from dec into the Go struct rv, whose VDL
// type is tt. Fields that tt doesn't know about are skipped.
func readStruct(dec Decoder, rv reflect.Value, tt *Type) error {
	// Reset to the zero struct first, since the encoded value may omit fields.
	//
	// TODO(toddw): Avoid repeated zero-setting of nested structs.
	zero, err := rvZeroValue(rv.Type(), tt)
	if err != nil {
		return err
	}
	rv.Set(zero)
	for {
		name, err := dec.NextField()
		if err != nil {
			return err
		}
		if name == "" {
			// No more fields.
			return nil
		}
		ttField, index := tt.FieldByName(name)
		if index == -1 {
			if err := dec.SkipValue(); err != nil {
				return err
			}
			continue
		}
		if err := readReflect(dec, false, rv.FieldByName(name), ttField.Type); err != nil {
			return err
		}
	}
}
// readUnion decodes a union from dec into rv, a Go union interface whose VDL
// type is tt. Exactly one field must be present in the encoded value.
func readUnion(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	name, err := dec.NextField()
	if err != nil {
		return err
	}
	if name == "" {
		return fmt.Errorf("missing field in union %v, from %v", rt, dec.Type())
	}
	ttField, index := tt.FieldByName(name)
	if index == -1 {
		return fmt.Errorf("field %q not in union %v, from %v", name, rt, dec.Type())
	}
	// Build a value of the field's rep type, fill in its single field, and
	// store the whole rep value in the union interface.
	ri, _, err := deriveReflectInfo(rt)
	if err != nil {
		return err
	}
	rvField := reflect.New(ri.UnionFields[index].RepType).Elem()
	if err := readReflect(dec, false, rvField.Field(0), ttField.Type); err != nil {
		return err
	}
	rv.Set(rvField)
	// Any further field makes the encoded union invalid.
	extra, err := dec.NextField()
	if err != nil {
		return err
	}
	if extra != "" {
		return fmt.Errorf("extra field %q in union %v, from %v", extra, rt, dec.Type())
	}
	return nil
}