blob: a1afa659b28df379090d19d89918dee1d3552f12 [file] [log] [blame]
// Copyright 2016 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vdl
import (
"errors"
"fmt"
"reflect"
)
// Sentinel errors returned by Read, ReadReflect and their helpers.
var (
	// errReadIntoNilValue is returned when the destination is nil (Read) or an
	// invalid reflect.Value (ReadReflect).
	errReadIntoNilValue = errors.New("vdl: read into nil value")
	// errReadReflectCantSet is returned by ReadReflect when the destination
	// cannot be made settable.
	errReadReflectCantSet = errors.New("vdl: read into unsettable reflect.Value")
	// errReadMustReflect is an internal signal from readNonReflect meaning the
	// value has no fastpath and must be decoded via reflection.
	errReadMustReflect = errors.New("vdl: read must be handled via reflection")
	// errReadAnyAlreadyStarted guards readAny against being entered after
	// StartValue has already been called.
	errReadAnyAlreadyStarted = errors.New("vdl: read into any after StartValue called")
	// errReadAnyInterfaceOnly is returned when an any value without a fastpath
	// is decoded into a Go value that is not an interface.
	errReadAnyInterfaceOnly = errors.New("vdl: read into any only supported for interfaces")
)
// Read decodes the next value from dec into v, which is normally a non-nil
// pointer. Values that implement Reader (VDLRead), as well as *Type and
// *[]byte, are decoded without reflection; everything else goes through
// ReadReflect. This makes Read an all-purpose VDLRead implementation.
//
// Read returns an error if v is nil.
func Read(dec Decoder, v interface{}) error {
	if v == nil {
		return errReadIntoNilValue
	}
	rv := reflect.ValueOf(v)
	// Only non-nil pointers may take the non-reflect fastpath. A nil pointer
	// must reach ReadReflect, which reports it as an error; unfortunately
	// detecting the nil pointer itself requires reflection.
	//
	// TODO(toddw): If reflection is too slow, add the nil pointer check to all
	// VDLRead methods, as well as other readNonReflect cases below.
	isNonNilPtr := rv.Kind() == reflect.Ptr && !rv.IsNil()
	if !isNonNilPtr {
		return ReadReflect(dec, rv)
	}
	err := readNonReflect(dec, false, v)
	if err == errReadMustReflect {
		return ReadReflect(dec, rv)
	}
	return err
}
// readNonReflect decodes into v without reflection when v's dynamic type
// allows it. calledStart reports whether dec.StartValue has already been
// called for this value. When v has no fastpath, it returns
// errReadMustReflect without calling any method on dec.
func readNonReflect(dec Decoder, calledStart bool, v interface{}) error {
	// startValue starts the decoder value unless the caller already has.
	startValue := func() error {
		if calledStart {
			return nil
		}
		return dec.StartValue()
	}
	switch target := v.(type) {
	case Reader:
		// Reader covers values with a code-generated decoder, as well as
		// special-cases such as vdl.Value and vom.RawBytes. If StartValue was
		// already called, ask dec to ignore the next one (presumably issued by
		// VDLRead itself).
		if calledStart {
			dec.IgnoreNextStartValue()
		}
		return target.VDLRead(dec)
	case **Type:
		// Types are hash-consed, so the decoded pointer itself must be
		// assigned; filling in a newly-created Type would be incorrect.
		if err := startValue(); err != nil {
			return err
		}
		var err error
		if *target, err = dec.DecodeTypeObject(); err != nil {
			return err
		}
		return dec.FinishValue()
	case *[]byte:
		// This and any later cases are purely performance optimizations.
		// TODO(toddw): Handle other common cases.
		if err := startValue(); err != nil {
			return err
		}
		if err := dec.DecodeBytes(-1, target); err != nil {
			return err
		}
		return dec.FinishValue()
	default:
		return errReadMustReflect
	}
}
// ReadReflect is like Read, but takes a reflect.Value argument. Use Read if
// performance is important and you have an interface{} handy.
//
// rv must be settable, or be a non-nil pointer. A pointer is dereferenced
// once, and the result must then be settable.
func ReadReflect(dec Decoder, rv reflect.Value) error {
	switch {
	case !rv.IsValid():
		return errReadIntoNilValue
	case rv.CanSet():
		// Already settable; decode directly into rv.
	case rv.Kind() == reflect.Ptr && !rv.IsNil():
		// Dereference the pointer a single time to make rv settable.
		rv = rv.Elem()
	}
	if !rv.CanSet() {
		return errReadReflectCantSet
	}
	tt, err := TypeFromReflect(rv.Type())
	if err != nil {
		return err
	}
	return readReflect(dec, false, rv, tt)
}
// readReflect decodes a value from dec into rv, whose VDL type is tt. If
// calledStart is true, dec.StartValue has already been called. On success,
// StartValue and FinishValue are guaranteed to have been called on dec.
//
// Go types registered as native are decoded into a fresh value of their wire
// type and then converted with the native ToNative conversion.
func readReflect(dec Decoder, calledStart bool, rv reflect.Value, tt *Type) error {
	ni := nativeInfoFromNative(rv.Type())
	if ni == nil {
		return readNonNative(dec, calledStart, rv, tt)
	}
	wire := reflect.New(ni.WireType).Elem()
	if err := readNonNative(dec, calledStart, wire, tt); err != nil {
		return err
	}
	return ni.ToNative(wire, rv.Addr())
}
// readNonNative decodes into rv, which has VDL type tt and is not a native
// Go type. If calledStart is false, it calls dec.StartValue and then
// decoderCompatible. On success, dec.FinishValue has been called, either here
// or by the helper the work was delegated to.
func readNonNative(dec Decoder, calledStart bool, rv reflect.Value, tt *Type) error {
	// Any goes first: any(nil) is handled differently from a ?T(nil) contained
	// in an any value, and readAny starts the decoder value itself.
	if tt == AnyType {
		return readAny(dec, calledStart, rv)
	}
	if !calledStart {
		err := dec.StartValue()
		if err == nil {
			err = decoderCompatible(dec, tt)
		}
		if err != nil {
			return err
		}
	}
	// Nil is special-cased so that pointers aren't created all the way down to
	// a value that will never be filled in.
	if dec.IsNil() {
		return readFromNil(dec, rv, tt)
	}
	// The value is non-nil: allocate through any pointers, then prefer the
	// non-reflect fastpath when it applies.
	target := readWalkPointers(rv)
	err := readNonReflect(dec, true, target.Addr().Interface())
	if err != errReadMustReflect {
		return err
	}
	if err = readNonNilValue(dec, target, tt.NonOptional()); err != nil {
		return err
	}
	return dec.FinishValue()
}
// readWalkPointers follows rv through any number of pointers and returns the
// first non-pointer value. Nil pointers along the way are set to newly
// allocated values. Walking stops early at *Type, since types are filled in by
// readNonReflect rather than by reflection.
func readWalkPointers(rv reflect.Value) reflect.Value {
	for rv.Kind() == reflect.Ptr && rv.Type() != rtPtrToType {
		if rv.IsNil() {
			rv.Set(reflect.New(rv.Type().Elem()))
		}
		rv = rv.Elem()
	}
	return rv
}
// readAny decodes a VDL any value into rv. calledStart must be false; readAny
// starts the decoder value itself, either directly or via a fastpath.
//
// Values with a non-reflect fastpath are decoded directly; this covers
// vdl.Value, vom.RawBytes and other special-cases. Otherwise rv must be an
// interface. In that case the Go type is looked up from the decoded VDL type,
// a new value of that type is decoded, and rv is set to it. any(nil) sets rv
// to a nil interface. optional(nil) sets rv to a nil value of the concrete Go
// type.
func readAny(dec Decoder, calledStart bool, rv reflect.Value) error {
	if calledStart {
		// Existing callers never pass true here: readReflect(dec, true, ...) is
		// only called below, and never with another any type, since a vdl
		// any(any) isn't allowed. This error guards against future changes
		// that would break that requirement. Note also that vom.RawBytes.VDLRead
		// requires that StartValue has not been called yet.
		return errReadAnyAlreadyStarted
	}
	rv = readWalkPointers(rv)
	if err := readNonReflect(dec, false, rv.Addr().Interface()); err != errReadMustReflect {
		return err
	}
	// All that remains is decoding into an interface; interface{} and any
	// other interface type are accepted.
	if rv.Kind() != reflect.Interface {
		return errReadAnyInterfaceOnly
	}
	if err := dec.StartValue(); err != nil {
		return err
	}
	ttDecoded := dec.Type()
	// The decoded type is AnyType only when the value is any(nil).
	if ttDecoded == AnyType {
		rv.Set(reflect.Zero(rv.Type()))
		return dec.FinishValue()
	}
	rtDecode := TypeToReflect(ttDecoded)
	if rtDecode == nil {
		return fmt.Errorf("vdl: %v not registered, call vdl.Register, or use vdl.Value or vom.RawBytes instead", ttDecoded)
	}
	// A non-nil optional value is held through a pointer. For a nil value the
	// decoded type is already optional, so rtDecode is already a pointer.
	if dec.IsOptional() && !dec.IsNil() {
		rtDecode = reflect.PtrTo(rtDecode)
	}
	if rtIface := rv.Type(); !rtDecode.Implements(rtIface) {
		return fmt.Errorf("vdl: %v doesn't implement %v", rtDecode, rtIface)
	}
	// optional(nil): store a nil value of the concrete pointer type in rv.
	if ttDecoded.Kind() == Optional {
		rv.Set(reflect.Zero(rtDecode))
		return dec.FinishValue()
	}
	rvDecode := reflect.New(rtDecode).Elem()
	if err := readReflect(dec, true, rvDecode, ttDecoded); err != nil {
		return err
	}
	// readReflect has already called dec.FinishValue.
	rv.Set(rvDecode)
	return nil
}
// readFromNil handles a nil decoded value. The value is accepted only when tt
// is optional: rv is set to its zero value and the decoder value is finished.
func readFromNil(dec Decoder, rv reflect.Value, tt *Type) error {
	if tt.Kind() == Optional {
		// With an optional tt, rv is a pointer or the special-case error
		// interface, so its zero value represents nil.
		rv.Set(reflect.Zero(rv.Type()))
		return dec.FinishValue()
	}
	return fmt.Errorf("vdl: can't decode nil into non-optional %v", tt)
}
// readNonNilValue decodes a non-nil value of VDL type tt into rv. The caller
// has already called dec.StartValue, walked rv through any Go pointers, and
// stripped Optional from tt. On success, the caller calls dec.FinishValue.
func readNonNilValue(dec Decoder, rv reflect.Value, tt *Type) error {
	// Handle named and unnamed []byte and [N]byte, where the element type is the
	// unnamed byte type. Cases like []MyByte fall through and are handled as
	// regular lists, since we can't easily convert []MyByte to []byte.
	switch {
	case tt.Kind() == Array && tt.Elem() == ByteType:
		// bytes aliases the array's storage, so DecodeBytes presumably fills
		// the array in place.
		bytes := rv.Slice(0, tt.Len()).Interface().([]byte)
		return dec.DecodeBytes(tt.Len(), &bytes)
	case tt.Kind() == List && tt.Elem() == ByteType:
		var bytes []byte
		if err := dec.DecodeBytes(-1, &bytes); err != nil {
			return err
		}
		rv.Set(reflect.ValueOf(bytes))
		return nil
	}
	// Handle regular non-nil values.
	switch kind := tt.Kind(); kind {
	case Bool:
		val, err := dec.DecodeBool()
		if err != nil {
			return err
		}
		rv.SetBool(val)
		return nil
	case String:
		val, err := dec.DecodeString()
		if err != nil {
			return err
		}
		rv.SetString(val)
		return nil
	case Enum:
		val, err := dec.DecodeString()
		if err != nil {
			return err
		}
		// Enums are decoded as a string and assigned through the settable
		// interface. A Go type without it is reported as an error rather
		// than panicking on an unchecked type assertion.
		enum, ok := rv.Addr().Interface().(settable)
		if !ok {
			return fmt.Errorf("vdl: can't decode enum %v into %v, which has no suitable Set method", tt, rv.Type())
		}
		return enum.Set(val)
	case Byte, Uint16, Uint32, Uint64:
		val, err := dec.DecodeUint(kind.BitLen())
		if err != nil {
			return err
		}
		rv.SetUint(val)
		return nil
	case Int8, Int16, Int32, Int64:
		val, err := dec.DecodeInt(kind.BitLen())
		if err != nil {
			return err
		}
		rv.SetInt(val)
		return nil
	case Float32, Float64:
		val, err := dec.DecodeFloat(kind.BitLen())
		if err != nil {
			return err
		}
		rv.SetFloat(val)
		return nil
	case Array:
		return readArray(dec, rv, tt)
	case List:
		return readList(dec, rv, tt)
	case Set:
		return readSet(dec, rv, tt)
	case Map:
		return readMap(dec, rv, tt)
	case Struct:
		return readStruct(dec, rv, tt)
	case Union:
		return readUnion(dec, rv, tt)
	}
	// Note that Any was already handled via readAny, Optional was handled via
	// readFromNil (or stripped off for non-nil values), and TypeObject was
	// handled via the readNonReflect special-case.
	return fmt.Errorf("vdl: Read unhandled type %v %v", rv.Type(), tt)
}
// readArray decodes entries from dec into the Go array rv, whose VDL type is
// tt. Exactly rv.Len() entries must be present; a shorter or longer encoded
// array is reported as a length mismatch.
func readArray(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	for i := 0; ; i++ {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		if full := i >= rv.Len(); done != full {
			return fmt.Errorf("array len mismatch, done:%v index:%d len:%d %v", done, i, rt.Len(), rt)
		}
		if done {
			return nil
		}
		if err := readReflect(dec, false, rv.Index(i), tt.Elem()); err != nil {
			return err
		}
	}
}
// readList decodes list entries from dec into the slice rv, whose VDL type is
// tt. rv is replaced rather than appended to. An empty list decoded without a
// positive length hint leaves rv nil.
func readList(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	if hint := dec.LenHint(); hint > 0 {
		// NOTE(review): the hint presumably reflects a length read from the
		// encoded input, which may be untrusted or corrupt. Bound the up-front
		// allocation to roughly maxPreallocBytes. reflect.Append below grows the
		// slice if the real list is longer, so only the initial capacity is
		// affected, never the decoded result.
		const maxPreallocBytes = 1 << 20
		if size := rt.Elem().Size(); size > 0 && uintptr(hint) > maxPreallocBytes/size {
			hint = int(maxPreallocBytes / size)
		}
		rv.Set(reflect.MakeSlice(rt, 0, hint))
	} else {
		rv.Set(reflect.Zero(rt))
	}
	for {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		if done {
			return nil
		}
		elem := reflect.New(rt.Elem()).Elem()
		if err := readReflect(dec, false, elem, tt.Elem()); err != nil {
			return err
		}
		rv.Set(reflect.Append(rv, elem))
	}
}
// rvEmptyStruct is the value stored for every key of a decoded set.
var rvEmptyStruct = reflect.ValueOf(struct{}{})

// readSet decodes set keys from dec into rv, a Go map whose VDL type is tt.
// Each key maps to rvEmptyStruct. An empty encoded set leaves rv as a nil map,
// and rv is only assigned once all keys decode successfully.
func readSet(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	var set reflect.Value // stays invalid until the first key is decoded
	for {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		if done {
			break
		}
		key := reflect.New(rt.Key()).Elem()
		if err := readReflect(dec, false, key, tt.Key()); err != nil {
			return err
		}
		if !set.IsValid() {
			set = reflect.MakeMap(rt)
		}
		set.SetMapIndex(key, rvEmptyStruct)
	}
	if !set.IsValid() {
		set = reflect.Zero(rt)
	}
	rv.Set(set)
	return nil
}
// readMap decodes key/value entries from dec into the Go map rv, whose VDL
// type is tt. An empty encoded map leaves rv as a nil map, and rv is only
// assigned once every entry decodes successfully.
func readMap(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	var m reflect.Value // stays invalid until the first entry is decoded
	for {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		if done {
			break
		}
		key := reflect.New(rt.Key()).Elem()
		if err := readReflect(dec, false, key, tt.Key()); err != nil {
			return err
		}
		val := reflect.New(rt.Elem()).Elem()
		if err := readReflect(dec, false, val, tt.Elem()); err != nil {
			return err
		}
		if !m.IsValid() {
			m = reflect.MakeMap(rt)
		}
		m.SetMapIndex(key, val)
	}
	if !m.IsValid() {
		m = reflect.Zero(rt)
	}
	rv.Set(m)
	return nil
}
// readStruct decodes the fields present in dec into the Go struct rv, whose
// VDL type is tt. rv is first reset to the value returned by rvZeroValue,
// because encoded fields may be missing. Encoded fields that tt does not
// contain are skipped.
func readStruct(dec Decoder, rv reflect.Value, tt *Type) error {
	// TODO(toddw): Avoid repeated zero-setting of nested structs.
	zero, err := rvZeroValue(rv.Type(), tt)
	if err != nil {
		return err
	}
	rv.Set(zero)
	for {
		name, err := dec.NextField()
		if err != nil {
			return err
		}
		if name == "" {
			return nil
		}
		ttField, index := tt.FieldByName(name)
		if index == -1 {
			if err := dec.SkipValue(); err != nil {
				return err
			}
			continue
		}
		if err := readReflect(dec, false, rv.FieldByName(name), ttField.Type); err != nil {
			return err
		}
	}
}
// readUnion decodes a union value of VDL type tt into rv, a union interface.
// Exactly one field must be present. Its value is decoded into a new value of
// the field's rep type (from deriveReflectInfo), which is then stored in rv.
func readUnion(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	name, err := dec.NextField()
	if err != nil {
		return err
	}
	if name == "" {
		return fmt.Errorf("missing field in union %v, from %v", rt, dec.Type())
	}
	ttField, index := tt.FieldByName(name)
	if index == -1 {
		return fmt.Errorf("field %q not in union %v, from %v", name, rt, dec.Type())
	}
	ri, _, err := deriveReflectInfo(rt)
	if err != nil {
		return err
	}
	rep := reflect.New(ri.UnionFields[index].RepType).Elem()
	if err := readReflect(dec, false, rep.Field(0), ttField.Type); err != nil {
		return err
	}
	rv.Set(rep)
	// A union carries a single field; anything further is malformed.
	extra, err := dec.NextField()
	if err != nil {
		return err
	}
	if extra != "" {
		return fmt.Errorf("extra field %q in union %v, from %v", extra, rt, dec.Type())
	}
	return nil
}