// Copyright 2016 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package vdl

import (
	"errors"
	"fmt"
	"reflect"
)

var (
	// errReadMustReflect is returned by readNonReflect when no fast path
	// applies.  Callers in this file treat it as a signal to fall back to
	// reflection rather than propagating it.
	errReadMustReflect = errors.New("vdl: read must be handled via reflection")

	// errReadIntoNilValue is returned when the read target is nil, or is an
	// invalid reflect.Value.
	errReadIntoNilValue = errors.New("vdl: read into nil value")

	// errReadReflectCantSet is returned when the reflect.Value target can't be
	// set, even after dereferencing a non-nil pointer once.
	errReadReflectCantSet = errors.New("vdl: read into unsettable reflect.Value")

	// errReadAnyAlreadyStarted guards the invariant that readIntoAny is never
	// reached after StartValue has been called.
	errReadAnyAlreadyStarted = errors.New("vdl: read into any after StartValue called")

	// errReadAnyInterfaceOnly is returned when a value of VDL type any has no
	// fast path and its Go target isn't an interface.
	errReadAnyInterfaceOnly = errors.New("vdl: read into any only supported for interfaces")
)

// Read decodes the next value from dec into v, and is essentially an
// all-purpose VDLRead implementation.  v must be a non-nil pointer to the
// destination; anything else is rejected by ReadReflect.  Types implementing
// Reader decode themselves through VDLRead, a few common types take
// hand-written fast paths, and everything else is decoded via reflection.
func Read(dec Decoder, v interface{}) error {
	if v == nil {
		return errReadIntoNilValue
	}
	rv := reflect.ValueOf(v)
	// Only a non-nil pointer may use the non-reflect fast path.  A nil pointer
	// has to reach ReadReflect so that it reports an error, and unfortunately
	// detecting it needs reflection.
	//
	// TODO(toddw): If reflection is too slow, add the nil pointer check to all
	// VDLRead methods, as well as other readNonReflect cases.
	if rv.Kind() != reflect.Ptr || rv.IsNil() {
		return ReadReflect(dec, rv)
	}
	err := readNonReflect(dec, false, v)
	if err != errReadMustReflect {
		return err
	}
	return ReadReflect(dec, rv)
}

// readNonReflect decodes into v without reflection when a fast path exists.
// Otherwise it returns errReadMustReflect so the caller can fall back to
// reflection.  calledStart reports whether dec.StartValue was already called
// for this value.  The handled cases are:
//
//   - Reader: the value decodes itself via VDLRead.  This covers
//     code-generated decoders and special types such as vdl.Value and
//     vom.RawBytes.
//   - **Type: the hash-consed *Type returned by the decoder must be assigned
//     for correctness, rather than filling in a newly-created Type.
//   - *[]byte: purely a performance optimization.
//
// TODO(toddw): Handle other common cases.
func readNonReflect(dec Decoder, calledStart bool, v interface{}) error {
	// startValue begins the decoder value unless the caller already did.
	startValue := func(tt *Type) error {
		if calledStart {
			return nil
		}
		return dec.StartValue(tt)
	}
	switch x := v.(type) {
	case Reader:
		// VDLRead is expected to call StartValue itself; if that has already
		// happened, tell the decoder to ignore the next call.
		if calledStart {
			dec.IgnoreNextStartValue()
		}
		return x.VDLRead(dec)
	case **Type:
		if err := startValue(TypeObjectType); err != nil {
			return err
		}
		t, err := dec.DecodeTypeObject()
		*x = t
		if err != nil {
			return err
		}
		return dec.FinishValue()
	case *[]byte:
		if err := startValue(ttByteList); err != nil {
			return err
		}
		if err := dec.DecodeBytes(-1, x); err != nil {
			return err
		}
		return dec.FinishValue()
	}
	return errReadMustReflect
}

// ttByteList is the VDL list-of-byte type.  It is passed to dec.StartValue by
// the *[]byte fast path in readNonReflect.
var ttByteList = ListType(ByteType)

// ReadReflect is like Read, but decodes into the reflect.Value rv.  rv must
// either be settable, or be a non-nil pointer whose element is settable.  In
// the latter case the pointer is dereferenced exactly once.
func ReadReflect(dec Decoder, rv reflect.Value) error {
	switch {
	case !rv.IsValid():
		return errReadIntoNilValue
	case !rv.CanSet() && rv.Kind() == reflect.Ptr && !rv.IsNil():
		// E.g. reflect.ValueOf(&x) isn't settable itself, but its element is.
		rv = rv.Elem()
	}
	if !rv.CanSet() {
		return errReadReflectCantSet
	}
	tt, err := TypeFromReflect(rv.Type())
	if err != nil {
		return err
	}
	return readReflect(dec, false, rv, tt)
}

// readReflect decodes the next dec value into rv, whose VDL type is tt.  If
// calledStart is true, the caller has already called dec.StartValue.  On
// success, both StartValue and FinishValue are guaranteed to have been called
// for this value.
func readReflect(dec Decoder, calledStart bool, rv reflect.Value, tt *Type) error {
	// Targets of type any are dispatched before StartValue is called.
	// vom.RawBytes.VDLRead doesn't support IgnoreNextStartValue and requires
	// that StartValue hasn't been called yet.  A dec value of type any decoded
	// into a non-any target continues below.
	if tt == AnyType {
		return readIntoAny(dec, calledStart, rv)
	}
	if !calledStart {
		if err := dec.StartValue(tt); err != nil {
			return err
		}
	}
	// Nil dec values (any(nil) or optional(nil)) are handled separately, so
	// the rest of this function can assume a non-nil value.
	if dec.IsNil() {
		return readFromNil(dec, rv, tt)
	}
	target := rvFlattenPointers(rv)
	if err := readNonReflect(dec, true, target.Addr().Interface()); err != errReadMustReflect {
		return err
	}
	// Native types are decoded into their wire type and then converted via
	// ToNative.  target is never a pointer here, so native pointer types are
	// not supported; they're complicated and would probably slow everything
	// down.
	//
	// TODO(toddw): Investigate support for native pointer types.
	if ni := nativeInfoFromNative(target.Type()); ni != nil {
		wire := reflect.New(ni.WireType).Elem()
		// The recursive readReflect calls FinishValue on success.
		if err := readReflect(dec, true, wire, tt); err != nil {
			return err
		}
		return ni.ToNative(wire, target.Addr())
	}
	if err := readNonNilValue(dec, target, tt.NonOptional()); err != nil {
		return err
	}
	return dec.FinishValue()
}

// readIntoAny decodes the next dec value into rv, whose VDL type is any.
// StartValue must not have been called yet for this value.
func readIntoAny(dec Decoder, calledStart bool, rv reflect.Value) error {
	if calledStart {
		// readReflect(dec, true, ...) is only ever invoked where readIntoAny
		// can't be reached.  For example, the call at the end of this function
		// never passes type any, since any(any) isn't a valid VDL type.  This
		// check protects that invariant, which vom.RawBytes.VDLRead relies on
		// because it doesn't handle IgnoreNextStartValue.
		return errReadAnyAlreadyStarted
	}
	target := rvFlattenPointers(rv)
	// Fast paths, including vdl.Value, vom.RawBytes and other special cases.
	if err := readNonReflect(dec, false, target.Addr().Interface()); err != errReadMustReflect {
		return err
	}
	// Anything else must be a Go interface; every interface is accepted,
	// including interface{}.
	if target.Kind() != reflect.Interface {
		return errReadAnyInterfaceOnly
	}
	if err := dec.StartValue(AnyType); err != nil {
		return err
	}
	// dec.Type() is AnyType only when the value is any(nil): clear the
	// interface and finish.
	if dec.Type() == AnyType {
		if !target.IsNil() {
			target.Set(reflect.Zero(target.Type()))
		}
		return dec.FinishValue()
	}
	// Choose the Go type to decode into based on the dec type.  A non-nil
	// optional value needs its optional type re-applied.  For a nil value,
	// dec.Type() is already optional, so the reflect type is already a pointer.
	ttDecode := dec.Type()
	if dec.IsOptional() && !dec.IsNil() {
		ttDecode = OptionalType(ttDecode)
	}
	rtDecode := typeToReflectNew(ttDecode)
	// Top-level "v.io/v23/vdl.WireError" values decoded into a Go interface
	// are special.  Regular registration would map the wire type to
	// vdl.WireError and the native type to the Go error interface, which is
	// fine for nested errors.  Here, though, type information would be lost if
	// the dec value is nil, so the registered verror.E type is used instead:
	//
	//   ttDecode  ->  rtDecode
	//   -----------------------
	//   WireError     verror.E
	//   ?WireError    *verror.E
	//   []WireError   []vdl.WireError (1)
	//   []?WireError  []error
	//
	// TODO(toddw): The (1) case above is weird; we would like to return verror.E,
	// but that's hard because the native conversion we've registered doesn't
	// currently include the verror.E type:
	//
	//    ToNative(wire *vdl.WireError, native *error)
	//    FromNative(wire **vdl.WireError, native error)
	//
	// We could make this more consistent by registering a pair of conversion
	// functions instead:
	//
	//    ToNative(wire vdl.WireError, native *verror.E)
	//    FromNative(wire *vdl.WireError, native verror.E)
	//
	//    ToNative(wire *verror.E, native *error)
	//    FromNative(wire **verror.E, native error)
	if ttDecode.NonOptional().Name() == ErrorType.Elem().Name() {
		if ni, err := nativeInfoForError(); err == nil {
			rtDecode = ni.NativeType
			if ttDecode.Kind() == Optional {
				rtDecode = reflect.PtrTo(rtDecode)
			}
		}
	}
	switch {
	case rtDecode == nil:
		return fmt.Errorf("vdl: %v not registered, either call vdl.Register, or use vdl.Value or vom.RawBytes instead", dec.Type())
	case !rtDecode.Implements(target.Type()):
		return fmt.Errorf("vdl: %v doesn't implement %v", rtDecode, target.Type())
	}
	// Decode into a fresh value of the chosen type, then store it in the
	// interface.  The recursive readReflect handles both nil and non-nil
	// values, and calls FinishValue on success.
	decoded := reflect.New(rtDecode).Elem()
	if err := readReflect(dec, true, decoded, ttDecode); err != nil {
		return err
	}
	target.Set(decoded)
	return nil
}

// readFromNil handles a nil dec value, either any(nil) or optional(nil), for
// the target rv whose VDL type is tt.  Only optional types can hold nil.
//
// REQUIRES: dec.IsNil() && tt != AnyType
func readFromNil(dec Decoder, rv reflect.Value, tt *Type) error {
	if tt.Kind() != Optional {
		return fmt.Errorf("vdl: can't decode nil into non-optional %v", tt)
	}
	// Walk down every pointer level except the last, allocating as needed.
	// Afterwards rv has at most one pointer.
	rt := rv.Type()
	for rt.Kind() == reflect.Ptr {
		inner := rt.Elem()
		if inner.Kind() != reflect.Ptr {
			break
		}
		if rv.IsNil() {
			rv.Set(reflect.New(inner))
		}
		rv, rt = rv.Elem(), inner
	}
	// Native types (e.g. the Go error interface) need their own handling.
	if handled, err := readFromNilNative(dec, rv, tt); handled {
		return err
	}
	// The remaining case is a single pointer, which is set to nil.
	if rt.Kind() != reflect.Ptr {
		return fmt.Errorf("vdl: can't decode nil into non-pointer %v optional %v", rt, tt)
	}
	if !rv.IsNil() {
		rv.Set(reflect.Zero(rt))
	}
	return dec.FinishValue()
}

// readFromNilNative decodes a nil dec value when rv is a native type, or a
// single pointer to one.  The bool result reports whether rv was such a type
// and was therefore handled here.  If it is false, rv was left untouched and
// the error is nil.
//
// Native pointer types themselves aren't supported; see the comments at other
// calls to nativeInfoFromNative.
//
// REQUIRES: rv.Type() has at most one pointer.
func readFromNilNative(dec Decoder, rv reflect.Value, tt *Type) (bool, error) {
	rt := rv.Type()
	target := rv
	var ni *nativeInfo
	switch rt.Kind() {
	case reflect.Ptr:
		// E.g. *error: allocate the native element if needed and decode into it.
		if ni = nativeInfoFromNative(rt.Elem()); ni == nil {
			return false, nil
		}
		if rv.IsNil() {
			rv.Set(reflect.New(rt.Elem()))
		}
		target = rv.Elem()
	default:
		// E.g. the Go error interface, which is a non-pointer native type.
		if ni = nativeInfoFromNative(rt); ni == nil {
			return false, nil
		}
	}
	// target now holds the native type itself and isn't a nil pointer.  The
	// wire value is decoded, then converted; readReflect calls FinishValue on
	// success.
	wire := reflect.New(ni.WireType).Elem()
	if err := readReflect(dec, true, wire, tt); err != nil {
		return true, err
	}
	return true, ni.ToNative(wire, target.Addr())
}

// readNonNilValue decodes a non-nil dec value into rv, whose VDL type tt is
// not optional.  In this file it is only called from readReflect, which calls
// StartValue before it and FinishValue after it.
func readNonNilValue(dec Decoder, rv reflect.Value, tt *Type) error {
	switch kind := tt.Kind(); kind {
	case Bool:
		b, err := dec.DecodeBool()
		if err == nil {
			rv.SetBool(b)
		}
		return err
	case String:
		s, err := dec.DecodeString()
		if err == nil {
			rv.SetString(s)
		}
		return err
	case Enum:
		label, err := dec.DecodeString()
		if err != nil {
			return err
		}
		return rv.Addr().Interface().(settable).Set(label)
	case Byte, Uint16, Uint32, Uint64:
		u, err := dec.DecodeUint(kind.BitLen())
		if err == nil {
			rv.SetUint(u)
		}
		return err
	case Int8, Int16, Int32, Int64:
		i, err := dec.DecodeInt(kind.BitLen())
		if err == nil {
			rv.SetInt(i)
		}
		return err
	case Float32, Float64:
		f, err := dec.DecodeFloat(kind.BitLen())
		if err == nil {
			rv.SetFloat(f)
		}
		return err
	case Array:
		// Named and unnamed [N]byte with the unnamed byte element type are
		// decoded in bulk, through a slice aliasing the array's storage.  Cases
		// like [N]MyByte are decoded element by element instead.
		if tt.Elem() == ByteType {
			buf := rv.Slice(0, tt.Len()).Interface().([]byte)
			return dec.DecodeBytes(tt.Len(), &buf)
		}
		return readArray(dec, rv, tt)
	case List:
		// Likewise for named and unnamed []byte.  Cases like []MyByte are
		// regular lists, since []MyByte can't easily be converted to []byte.
		if tt.Elem() == ByteType {
			var buf []byte
			if err := dec.DecodeBytes(-1, &buf); err != nil {
				return err
			}
			rv.Set(reflect.ValueOf(buf))
			return nil
		}
		return readList(dec, rv, tt)
	case Set:
		return readSet(dec, rv, tt)
	case Map:
		return readMap(dec, rv, tt)
	case Struct:
		return readStruct(dec, rv, tt)
	case Union:
		return readUnion(dec, rv, tt)
	}
	// Any is handled by readIntoAny, Optional by readFromNil (or stripped by
	// the caller for non-nil values), and TypeObject by the readNonReflect
	// fast path, so nothing else should reach this point.
	return fmt.Errorf("vdl: Read unhandled type %v %v", rv.Type(), tt)
}

// readArray decodes the entries of a VDL array into rv, whose VDL type is tt.
// The number of decoded entries must exactly match the length of rv.
func readArray(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	n := rv.Len()
	for i := 0; ; i++ {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		// Reaching the end early, or having entries left over, is an error.
		if done != (i >= n) {
			return fmt.Errorf("array len mismatch, done:%v index:%d len:%d %v", done, i, rt.Len(), rt)
		}
		if done {
			return nil
		}
		if err := readReflect(dec, false, rv.Index(i), tt.Elem()); err != nil {
			return err
		}
	}
}

// readList decodes the entries of a VDL list into the Go slice rv, whose VDL
// type is tt.  rv is replaced by a new slice holding the decoded entries.  An
// empty list decodes to a nil slice, unless the decoder reported a positive
// length hint.
func readList(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	// Preallocate using the decoder's length hint when one is available.
	// NOTE(review): the hint presumably comes from the encoded data; if the
	// decoder doesn't bound it, a hostile input could force a huge
	// preallocation.  Confirm the Decoder implementations validate it.
	if hint := dec.LenHint(); hint > 0 {
		rv.Set(reflect.MakeSlice(rt, 0, hint))
	} else {
		rv.Set(reflect.Zero(rt))
	}
	rtElem := rt.Elem()
	for {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		if done {
			return nil
		}
		elem := reflect.New(rtElem).Elem()
		if err := readReflect(dec, false, elem, tt.Elem()); err != nil {
			return err
		}
		rv.Set(reflect.Append(rv, elem))
	}
}

// rvEmptyStruct is the map value stored for every key added by readSet.
var rvEmptyStruct = reflect.ValueOf(struct{}{})

// readSet decodes the keys of a VDL set into the Go map rv, whose VDL type is
// tt.  Each key maps to rvEmptyStruct.  rv is assigned only once decoding
// completes, and an empty set leaves rv as a nil map.
func readSet(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	var set reflect.Value // Stays invalid until the first key is decoded.
	for {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		if done {
			if !set.IsValid() {
				set = reflect.Zero(rt)
			}
			rv.Set(set)
			return nil
		}
		key := reflect.New(rt.Key()).Elem()
		if err := readReflect(dec, false, key, tt.Key()); err != nil {
			return err
		}
		if !set.IsValid() {
			set = reflect.MakeMap(rt)
		}
		set.SetMapIndex(key, rvEmptyStruct)
	}
}

// readMap decodes the key/value entries of a VDL map into the Go map rv, whose
// VDL type is tt.  rv is assigned only once decoding completes, and an empty
// map leaves rv as a nil map.
func readMap(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	var m reflect.Value // Stays invalid until the first entry is decoded.
	for {
		done, err := dec.NextEntry()
		if err != nil {
			return err
		}
		if done {
			if !m.IsValid() {
				m = reflect.Zero(rt)
			}
			rv.Set(m)
			return nil
		}
		key := reflect.New(rt.Key()).Elem()
		if err := readReflect(dec, false, key, tt.Key()); err != nil {
			return err
		}
		val := reflect.New(rt.Elem()).Elem()
		if err := readReflect(dec, false, val, tt.Elem()); err != nil {
			return err
		}
		if !m.IsValid() {
			m = reflect.MakeMap(rt)
		}
		m.SetMapIndex(key, val)
	}
}

// readStruct decodes a VDL struct into the Go struct rv, whose VDL type is tt.
// rv is first reset to its zero value, because fields may be missing from the
// encoding.  Decoded fields that tt doesn't know about are skipped.
func readStruct(dec Decoder, rv reflect.Value, tt *Type) error {
	// TODO(toddw): Avoid repeated zero-setting of nested structs.
	zero, err := rvZeroValue(rv.Type(), tt)
	if err != nil {
		return err
	}
	rv.Set(zero)
	for {
		name, err := dec.NextField()
		if err != nil {
			return err
		}
		if name == "" {
			return nil
		}
		ttField, index := tt.FieldByName(name)
		if index == -1 {
			// Unknown field: discard its value.
			if err := dec.SkipValue(); err != nil {
				return err
			}
			continue
		}
		if err := readReflect(dec, false, rv.FieldByName(name), ttField.Type); err != nil {
			return err
		}
	}
}

// readUnion decodes a VDL union into rv, the Go union interface for tt.
// Exactly one field must be present.  Its value is decoded into the first Go
// field (Field(0)) of a new instance of that field's rep type, which is then
// stored in rv.
func readUnion(dec Decoder, rv reflect.Value, tt *Type) error {
	rt := rv.Type()
	name, err := dec.NextField()
	if err != nil {
		return err
	}
	if name == "" {
		return fmt.Errorf("missing field in union %v, from %v", rt, dec.Type())
	}
	ttField, index := tt.FieldByName(name)
	if index == -1 {
		return fmt.Errorf("field %q not in union %v, from %v", name, rt, dec.Type())
	}
	ri, _, err := deriveReflectInfo(rt)
	if err != nil {
		return err
	}
	field := reflect.New(ri.UnionFields[index].RepType).Elem()
	if err := readReflect(dec, false, field.Field(0), ttField.Type); err != nil {
		return err
	}
	rv.Set(field)
	// A union holds a single field, so the next NextField must report the end.
	extra, err := dec.NextField()
	if err != nil {
		return err
	}
	if extra != "" {
		return fmt.Errorf("extra field %q in union %v, from %v", extra, rt, dec.Type())
	}
	return nil
}
