blob: a6b1209ea254a7622e5afc69986a355797bee2e4 [file] [log] [blame]
// Copyright 2016 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package golang
import (
"fmt"
"v.io/v23/vdl"
"v.io/x/ref/lib/vdl/compile"
)
// defineRead returns the generated read code for def, as produced by
// genRead.Gen on a fresh generator that shares data.
func defineRead(data *goData, def *compile.TypeDef) string {
	reader := &genRead{goData: data}
	return reader.Gen(def)
}
// genRead carries the state used while generating the read code for a single
// type definition.
type genRead struct {
// goData is the generator state shared across definitions; its anonReaders
// entries hold the numeric ids used to name anonymous-type reader functions.
*goData
// anonReaders holds the anonymous types that we need to generate __VDLRead
// functions for. We only generate the function if we're the first one to add
// it to goData; otherwise someone else has already generated it.
anonReaders []*vdl.Type
}
// Gen returns the generated read code for def: the union read function when
// def is a union, the VDLRead method otherwise, followed by the reader
// functions for any anonymous types queued up while generating it.
func (g *genRead) Gen(def *compile.TypeDef) string {
	var code string
	switch def.Type.Kind() {
	case vdl.Union:
		code = g.genUnionDef(def)
	default:
		code = g.genDef(def)
	}
	// Must run after the definition above, which is what fills g.anonReaders.
	return code + g.genAnonDef()
}
// genDef returns the generated VDLRead method, with pointer receiver x, for
// the non-union type definition def.
func (g *genRead) genDef(def *compile.TypeDef) string {
	header := fmt.Sprintf(`
func (x *%[1]s) VDLRead(dec %[2]sDecoder) error {`, def.Name, g.Pkg("v.io/v23/vdl"))
	var reset string
	if def.Type.Kind() == vdl.Struct {
		// Fields missing from the input are never assigned by the generated
		// loop, so a struct is first reset to its zero value.
		reset = fmt.Sprintf(`
*x = %[1]s`, typedConstWire(g.goData, vdl.ZeroValue(def.Type)))
	}
	receiver := namedArg{Name: "x", IsPtr: true}
	return header + reset + g.body(def.Type, receiver, true) + `
}
`
}
// genUnionDef returns the generated package-level read function for the union
// definition def. Unions are represented as Go interfaces, and methods can't
// be attached to a pointer-to-interface receiver, so a function taking the
// pointer as an argument is generated instead of a VDLRead method.
func (g *genRead) genUnionDef(def *compile.TypeDef) string {
	prologue := fmt.Sprintf(`
func %[1]s(dec %[2]sDecoder, x *%[3]s) error {
if err := dec.StartValue(); err != nil {
return err
}`, unionReadFuncName(def), g.Pkg("v.io/v23/vdl"), def.Name)
	target := namedArg{Name: "x", IsPtr: true}
	const epilogue = `
return dec.FinishValue()
}
`
	return prologue + g.bodyUnion(def.Type, target) + epilogue
}
// unionReadFuncName returns the name of the package-level read function for
// the union definition def: "VDLRead"+Name when def is exported, and
// "__VDLRead_"+Name otherwise.
func unionReadFuncName(def *compile.TypeDef) string {
	prefix := "__VDLRead_"
	if def.Exported {
		prefix = "VDLRead"
	}
	return prefix + def.Name
}
// genAnonDef returns the generated __VDLReadAnon_* functions for every
// anonymous type queued in g.anonReaders, draining the queue.
func (g *genRead) genAnonDef() string {
	var out string
	// Generating the reader for one anonymous type may queue further anonymous
	// types (e.g. [][]Certificate needs []Certificate), so the queue is drained
	// in batches until a pass adds nothing new. This terminates because the VDL
	// type system disallows cyclic anonymous types.
	for len(g.anonReaders) > 0 {
		batch := g.anonReaders
		g.anonReaders = nil
		for i := 0; i < len(batch); i++ {
			tt := batch[i]
			header := fmt.Sprintf(`
func %[1]s(dec %[2]sDecoder, x *%[3]s) error {`, g.anonReaderName(tt), g.Pkg("v.io/v23/vdl"), typeGo(g.goData, tt))
			out += header + g.body(tt, namedArg{Name: "x", IsPtr: true}, true) + `
}
`
		}
	}
	return out
}
// body returns the generated statements that decode a value of type tt from
// dec into arg.
//
// topLevel is true when generating the body of the reader function for tt
// itself, rather than for a value nested inside another type. In that case the
// trailing FinishValue call is returned directly instead of being checked, and
// the native, named and anonymous shortcuts below are skipped so that a real
// body is generated.
//
// Cases that don't append fin rely on the helper they call to emit any
// required FinishValue itself (e.g. the array, list, set/map and struct
// helpers return dec.FinishValue() from inside their loops).
//
// Panics if tt's kind is not handled by any case. vdl.Union has no case here.
// NOTE(review): presumably unions only ever arrive via genUnionDef or the
// named-type case — confirm.
func (g *genRead) body(tt *vdl.Type, arg namedArg, topLevel bool) string {
kind := tt.Kind()
// sta and fin are the generated StartValue/FinishValue calls that bracket the
// decoding of a single value.
sta := `
if err := dec.StartValue(); err != nil {
return err
}`
fin := `
if err := dec.FinishValue(); err != nil {
return err
}`
if topLevel {
// At top level the FinishValue result is the function's return value.
fin = `
return dec.FinishValue()`
}
// Handle special cases. The ordering of the cases is very important.
switch {
case tt == vdl.ErrorType:
// Error types call verror.VDLRead directly, similar to named types, but
// even more special-cased. Appears before optional, since ErrorType is
// optional.
return g.bodyError(arg)
case kind == vdl.Optional:
// Optional types need special nil handling. Appears as early as possible,
// so that all subsequent cases have optionality handled correctly.
return sta + g.bodyOptional(tt, arg)
case !topLevel && isNativeType(g.Env, tt):
// Non-top-level native types need a final native conversion, while
// top-level native types use the regular logic to create VDLRead for the
// wire type. Appears as early as possible, so that all subsequent cases
// have nativity handled correctly.
return g.bodyNative(tt, arg)
case !topLevel && tt.Name() != "":
// Non-top-level named types call the VDLRead method defined on the arg.
// The top-level type is always named, and needs a real body generated.
// Appears before bytes, so that we use the defined method rather than
// re-generating extra code.
return g.bodyCallVDLRead(tt, arg)
case tt.IsBytes():
// Bytes use the special Decoder.DecodeBytes method. Appears before
// anonymous types, to avoid processing bytes as a list or array.
return sta + g.bodyBytes(tt, arg) + fin
case !topLevel && (kind == vdl.List || kind == vdl.Set || kind == vdl.Map):
// Non-top-level anonymous types call the unexported __VDLRead* functions
// generated in g.Gen, after the main VDLRead method has been generated.
// Top-level anonymous types use the regular logic, to generate the actual
// body of the __VDLRead* functions.
return g.bodyAnon(tt, arg)
}
// Handle each kind of type.
switch kind {
case vdl.Bool:
return sta + g.bodyScalar(tt, vdl.BoolType, arg, "Bool") + fin
case vdl.String:
return sta + g.bodyScalar(tt, vdl.StringType, arg, "String") + fin
case vdl.Byte, vdl.Uint16, vdl.Uint32, vdl.Uint64:
return sta + g.bodyScalar(tt, vdl.Uint64Type, arg, "Uint", bitlen(kind)) + fin
case vdl.Int8, vdl.Int16, vdl.Int32, vdl.Int64:
return sta + g.bodyScalar(tt, vdl.Int64Type, arg, "Int", bitlen(kind)) + fin
case vdl.Float32, vdl.Float64:
return sta + g.bodyScalar(tt, vdl.Float64Type, arg, "Float", bitlen(kind)) + fin
case vdl.TypeObject:
return sta + g.bodyScalar(tt, vdl.TypeObjectType, arg, "TypeObject") + fin
case vdl.Enum:
return sta + g.bodyEnum(arg) + fin
case vdl.Array:
return sta + g.bodyArray(tt, arg)
case vdl.List:
return sta + g.bodyList(tt, arg)
case vdl.Set, vdl.Map:
return sta + g.bodySetMap(tt, arg)
case vdl.Struct:
return sta + g.bodyStruct(tt, arg)
case vdl.Any:
// No sta/fin here: the value is read by calling VDLRead on arg.
return g.bodyAny(tt, arg)
default:
panic(fmt.Errorf("VDLRead unhandled type %s", tt))
}
}
// bodyError returns the generated call to the verror package's VDLRead
// function, passing the address of arg.
func (g *genRead) bodyError(arg namedArg) string {
	verrorPkg := g.Pkg("v.io/v23/verror")
	return fmt.Sprintf(`
if err := %[1]sVDLRead(dec, &%[2]s); err != nil {
return err
}`, verrorPkg, arg.Name)
}
// bodyNative returns generated code that reads the wire form of the native
// type tt into a local variable named wire, then converts it into arg with the
// wire type's ToNative function.
func (g *genRead) bodyNative(tt *vdl.Type, arg namedArg) string {
	wireType := typeGoWire(g.goData, tt)
	readWire := g.bodyCallVDLRead(tt, typedArg("wire", tt))
	return fmt.Sprintf(`
var wire %[1]s%[2]s
if err := %[1]sToNative(wire, %[3]s); err != nil {
return err
}`, wireType, readWire, arg.Ptr())
}
// bodyCallVDLRead returns a generated call that reads arg using code already
// defined for tt: the VDLRead method on arg, or for unions the package-level
// read function named by unionReadFuncName.
func (g *genRead) bodyCallVDLRead(tt *vdl.Type, arg namedArg) string {
	if tt.Kind() != vdl.Union {
		return fmt.Sprintf(`
if err := %[1]s.VDLRead(dec); err != nil {
return err
}`, arg.Name)
	}
	// Methods can't be attached to a pointer-to-interface receiver, so unions
	// are read through a function qualified by the package that defines them.
	def := g.Env.FindTypeDef(tt)
	pkg := g.Pkg(def.File.Package.GenPath)
	return fmt.Sprintf(`
if err := %[1]s%[2]s(dec, %[3]s); err != nil {
return err
}`, pkg, unionReadFuncName(def), arg.Ptr())
}
// anonReaderName returns the name of the reader function for the anonymous
// type tt, built from tt's kind and the id stored for it in goData.anonReaders.
func (g *genRead) anonReaderName(tt *vdl.Type) string {
	id := g.goData.anonReaders[tt]
	return fmt.Sprintf("__VDLReadAnon_%s_%d", tt.Kind(), id)
}
// bodyAnon returns a generated call to the reader function for the anonymous
// type tt. The first time tt is seen it is given the next id in
// goData.anonReaders and queued on g.anonReaders so that its reader function
// gets generated.
func (g *genRead) bodyAnon(tt *vdl.Type, arg namedArg) string {
	if g.goData.anonReaders[tt] == 0 {
		// An id of zero means no one has registered this type yet.
		next := len(g.goData.anonReaders) + 1
		g.goData.anonReaders[tt] = next
		g.anonReaders = append(g.anonReaders, tt)
	}
	reader := g.anonReaderName(tt)
	return fmt.Sprintf(`
if err := %[1]s(dec, %[2]s); err != nil {
return err
}`, reader, arg.Ptr())
}
// bodyScalar returns generated code that decodes a scalar by calling
// dec.Decode<method>(params...) and stores the result in arg. exact is the
// type the decode method is taken to return; when tt differs from it the
// result is converted to tt's Go type before being assigned.
func (g *genRead) bodyScalar(tt, exact *vdl.Type, arg namedArg, method string, params ...interface{}) string {
	// Render params as a comma-separated argument list.
	paramStr, sep := "", ""
	for _, p := range params {
		paramStr += sep + fmt.Sprint(p)
		sep = ", "
	}
	if tt != exact {
		// No exact match, so a conversion is needed. This happens for every
		// named type, and for numeric types whose bitlen differs. E.g.
		//
		//   type Foo uint16
		//
		//   tmp, err := dec.DecodeUint(16) // returns uint64
		//   if err != nil { return err }
		//   *x = Foo(tmp)
		return fmt.Sprintf(`
tmp, err := dec.Decode%[2]s(%[3]s)
if err != nil {
return err
}
%[1]s = %[4]s(tmp)`, arg.Ref(), method, paramStr, typeGo(g.goData, tt))
	}
	// Exact match: decode straight into the destination.
	return fmt.Sprintf(`
var err error
if %[1]s, err = dec.Decode%[2]s(%[3]s); err != nil {
return err
}`, arg.Ref(), method, paramStr)
}
// bodyEnum returns generated code that decodes a string label and applies it
// to arg through arg's Set method.
func (g *genRead) bodyEnum(arg namedArg) string {
	const tmpl = `
enum, err := dec.DecodeString()
if err != nil {
return err
}
if err := %[1]s.Set(enum); err != nil {
return err
}`
	return fmt.Sprintf(tmpl, arg.Name)
}
// bodyBytes returns generated code that reads a bytes-typed value into arg
// with dec.DecodeBytes. The first DecodeBytes argument is the array length for
// arrays and -1 for lists.
func (g *genRead) bodyBytes(tt *vdl.Type, arg namedArg) string {
	var (
		setup, finish string
		limit         int
		target        = arg.Ptr()
	)
	switch kind := tt.Kind(); {
	case kind == vdl.Array:
		// Decode through a slice that aliases the array.
		setup = fmt.Sprintf(`
bytes := %[1]s[:]`, arg.Name)
		target, limit = "&bytes", tt.Len()
	case kind == vdl.List && tt.Name() != "":
		// Named lists decode into a plain []byte temporary, which is assigned
		// to the destination afterwards.
		limit = -1
		setup = `
var bytes []byte`
		target = "&bytes"
		finish = fmt.Sprintf(`
%[1]s = bytes`, arg.Ref())
	case kind == vdl.List:
		// Unnamed lists decode directly into the destination.
		limit = -1
	}
	decode := fmt.Sprintf(`
if err := dec.DecodeBytes(%[1]d, %[2]s); err != nil {
return err
}`, limit, target)
	return setup + decode + finish
}
// checkCompat returns a generated guard that, when the decoder's stack depth
// is 1 or the decoder reports IsAny, checks that the type of varName is
// compatible with dec.Type() and otherwise returns an "incompatible <kind>"
// error.
func (g *genRead) checkCompat(kind vdl.Kind, varName string) string {
	vdlPkg, fmtPkg := g.Pkg("v.io/v23/vdl"), g.Pkg("fmt")
	return fmt.Sprintf(`
if (dec.StackDepth() == 1 || dec.IsAny()) && !%[1]sCompatible(%[1]sTypeOf(%[4]s), dec.Type()) {
return %[2]sErrorf("incompatible %[3]s %%T, from %%v", %[4]s, dec.Type())
}`, vdlPkg, fmtPkg, kind, varName)
}
// bodyArray returns generated code that reads the elements of the array type
// tt into arg, one entry at a time. The generated loop returns an error if the
// decoder yields fewer or more entries than the array holds, and returns
// dec.FinishValue() once exactly the right number have been read.
func (g *genRead) bodyArray(tt *vdl.Type, arg namedArg) string {
	s := g.checkCompat(tt.Kind(), arg.Ref())
	// The length check refers to the array through arg.Ref() instead of a
	// hard-coded "*x", so the generated code is tied to the argument being read
	// rather than to the name of the top-level receiver. For the top-level arg
	// (pointer named x) the output is unchanged.
	s += fmt.Sprintf(`
index := 0
for {
switch done, err := dec.NextEntry(); {
case err != nil:
return err
case done != (index >= len(%[2]s)):
return %[1]sErrorf("array len mismatch, got %%d, want %%T", index, %[2]s)
case done:
return dec.FinishValue()
}`, g.Pkg("fmt"), arg.Ref())
	// Each element is decoded in place at arg[index].
	elem := fmt.Sprintf(`%[1]s[index]`, arg.Name)
	s += g.body(tt.Elem(), typedArg(elem, tt.Elem()), false)
	return s + `
index++
}`
}
// bodyList returns generated code that reads the list type tt into arg. The
// destination is pre-sized from dec.LenHint() when that is positive and reset
// to nil otherwise; each entry is then decoded into a local elem and appended.
func (g *genRead) bodyList(tt *vdl.Type, arg namedArg) string {
	dest := arg.Ref()
	check := g.checkCompat(tt.Kind(), dest)
	loopHead := fmt.Sprintf(`
switch len := dec.LenHint(); {
case len > 0:
%[1]s = make(%[2]s, 0, len)
default:
%[1]s = nil
}
for {
switch done, err := dec.NextEntry(); {
case err != nil:
return err
case done:
return dec.FinishValue()
}
var elem %[3]s`, dest, typeGo(g.goData, tt), typeGo(g.goData, tt.Elem()))
	readElem := g.body(tt.Elem(), typedArg("elem", tt.Elem()), false)
	loopTail := fmt.Sprintf(`
%[1]s = append(%[1]s, elem)
}`, dest)
	return check + loopHead + readElem + loopTail
}
// bodySetMap returns generated code that reads the set or map type tt into
// arg. Entries are collected in a local tmpMap, which is assigned to the
// destination only once the decoder reports that all entries have been read.
func (g *genRead) bodySetMap(tt *vdl.Type, arg namedArg) string {
	// scoped returns the read code for a local variable followed by the brace
	// closing the generated block opened just before it; the block keeps the
	// temporaries of the key and elem reads from colliding.
	scoped := func(name string, t *vdl.Type) string {
		return g.body(t, typedArg(name, t), false) + `
}`
	}
	code := g.checkCompat(tt.Kind(), arg.Ref())
	code += fmt.Sprintf(`
var tmpMap %[2]s
if len := dec.LenHint(); len > 0 {
tmpMap = make(%[2]s, len)
}
for {
switch done, err := dec.NextEntry(); {
case err != nil:
return err
case done:
%[1]s = tmpMap
return dec.FinishValue()
}
var key %[3]s
{`, arg.Ref(), typeGo(g.goData, tt), typeGo(g.goData, tt.Key()))
	code += scoped("key", tt.Key())
	// Sets store the empty struct; maps store the decoded elem.
	value := "struct{}{}"
	if tt.Kind() == vdl.Map {
		value = "elem"
		code += fmt.Sprintf(`
var elem %[1]s
{`, typeGo(g.goData, tt.Elem()))
		code += scoped("elem", tt.Elem())
	}
	code += fmt.Sprintf(`
if tmpMap == nil {
tmpMap = make(%[1]s)
}
tmpMap[key] = %[2]s
}`, typeGo(g.goData, tt), value)
	return code
}
// bodyStruct returns generated code that reads the struct type tt into arg.
// The generated loop asks the decoder for the next field name, reads the
// matching field, skips values for names it doesn't know, and returns
// dec.FinishValue() when the name is empty.
func (g *genRead) bodyStruct(tt *vdl.Type, arg namedArg) string {
	out := g.checkCompat(tt.Kind(), arg.Ref())
	out += `
for {
f, err := dec.NextField()
if err != nil {
return err
}
switch f {
case "":
return dec.FinishValue()`
	for i, n := 0, tt.NumField(); i < n; i++ {
		fld := tt.Field(i)
		label := fmt.Sprintf(`
case %[1]q:`, fld.Name)
		out += label + g.body(fld.Type, arg.Field(fld), false)
	}
	out += `
default:
if err := dec.SkipValue(); err != nil {
return err
}
}
}`
	return out
}
// bodyUnion returns generated code that reads the union type tt into arg. The
// generated code reads one field name, decodes that field's value into a
// local of the matching union field struct type and assigns it to the
// destination. It returns an error if the field is missing or unknown, or if
// another field follows the first.
func (g *genRead) bodyUnion(tt *vdl.Type, arg namedArg) string {
	out := g.checkCompat(tt.Kind(), arg.Ptr())
	out += `
f, err := dec.NextField()
if err != nil {
return err
}
switch f {`
	for i, n := 0, tt.NumField(); i < n; i++ {
		// TODO(toddw): Change to using pointers to the union field structs, to
		// resolve https://v.io/i/455
		fld := tt.Field(i)
		decl := fmt.Sprintf(`
case %[1]q:
var field %[2]s%[1]s`, fld.Name, typeGoWire(g.goData, tt))
		read := g.body(fld.Type, typedArg("field.Value", fld.Type), false)
		assign := fmt.Sprintf(`
%[1]s = field`, arg.Ref())
		out += decl + read + assign
	}
	out += fmt.Sprintf(`
case "":
return %[1]sErrorf("missing field in union %%T, from %%v", %[2]s, dec.Type())
default:
return %[1]sErrorf("field %%q not in union %%T, from %%v", f, %[2]s, dec.Type())
}
switch f, err := dec.NextField(); {
case err != nil:
return err
case f != "":
return %[1]sErrorf("extra field %%q in union %%T, from %%v", f, %[2]s, dec.Type())
}`, g.Pkg("fmt"), arg.Ptr())
	return out
}
// bodyOptional returns generated code that reads the optional type tt into
// arg. A nil input sets arg to nil and finishes the value; otherwise a new
// element is allocated and read with the code for tt's element type.
func (g *genRead) bodyOptional(tt *vdl.Type, arg namedArg) string {
	// NOTE: arg.IsPtr is always true here, since tt is optional.
	nilCheck := `
if dec.IsNil() {` + g.checkCompat(tt.Kind(), arg.Name)
	branches := fmt.Sprintf(`
%[1]s = nil
if err := dec.FinishValue(); err != nil {
return err
}
} else {
%[1]s = new(%[2]s)
dec.IgnoreNextStartValue()`, arg.Name, typeGo(g.goData, tt.Elem()))
	readElem := g.body(tt.Elem(), arg, false)
	return nilCheck + branches + readElem + `
}`
}
// bodyAny returns generated code that reads a value of any type into arg by
// calling VDLRead on it. When shouldUseVdlValueForAny reports true for the
// package, arg is first replaced with a newly allocated vdl.Value.
func (g *genRead) bodyAny(tt *vdl.Type, arg namedArg) string {
	if !shouldUseVdlValueForAny(g.Package) {
		return g.bodyCallVDLRead(tt, arg)
	}
	// Hack to deal with top-level any. Generated code initializes any values
	// to vdl.ZeroValue(AnyType), so calling vdl.Value.VDLRead on one always
	// sees a value whose type is top-level any. Assigning a new (invalid)
	// vdl.Value first makes VDLRead fill it with the exact type read from the
	// decoder. The initialization itself can't be changed to create a new
	// vdl.Value, since struct fields that aren't read keep their initialized
	// values, and an invalid vdl.Value causes problems with the encoder.
	//
	// TODO(toddw): Remove this hack.
	reset := fmt.Sprintf(`
%[1]s = new(%[2]sValue)`, arg.Ref(), g.Pkg("v.io/v23/vdl"))
	return reset + g.bodyCallVDLRead(tt, arg)
}
// namedArg describes a variable in the generated code by name, together with
// whether that variable is a pointer. Its methods render the variable in the
// pointer or value form needed at each use.
type namedArg struct {
	Name  string // variable name
	IsPtr bool   // is the variable a pointer type
}

// IsValid reports whether arg has a non-empty name.
func (arg namedArg) IsValid() bool {
	return len(arg.Name) > 0
}

// Ptr returns an expression for a pointer to the variable: the name itself if
// it is already a pointer, otherwise the name prefixed with "&".
func (arg namedArg) Ptr() string {
	switch {
	case arg.IsPtr:
		return arg.Name
	default:
		return "&" + arg.Name
	}
}

// Ref returns an expression for the variable's value: the name prefixed with
// "*" if it is a pointer, otherwise the name itself.
func (arg namedArg) Ref() string {
	prefix := ""
	if arg.IsPtr {
		prefix = "*"
	}
	return prefix + arg.Name
}

// SafeRef is like Ref, but parenthesizes the dereference so that it can be
// followed by a selector or index expression.
func (arg namedArg) SafeRef() string {
	if !arg.IsPtr {
		return arg.Name
	}
	return "(*" + arg.Name + ")"
}
// Field returns the namedArg for the given field selected on arg, named
// "<arg.Name>.<field.Name>", with pointer-ness determined by the field's type.
func (arg namedArg) Field(field vdl.Field) namedArg {
	selector := arg.Name + "." + field.Name
	return typedArg(selector, field.Type)
}
// Index returns the namedArg for arg indexed by the expression index, where tt
// is the type of the indexed element. The index is applied to arg.SafeRef(),
// so a pointer arg is dereferenced before indexing.
func (arg namedArg) Index(index string, tt *vdl.Type) namedArg {
	indexed := arg.SafeRef() + "[" + index + "]"
	return typedArg(indexed, tt)
}
// typedArg returns a namedArg called name that is marked as a pointer exactly
// when tt is an optional type.
func typedArg(name string, tt *vdl.Type) namedArg {
	isOptional := tt.Kind() == vdl.Optional
	return namedArg{Name: name, IsPtr: isOptional}
}
// bitlen returns the width in bits of the numeric kind k. It panics if k is
// not one of the byte, integer or float kinds listed below.
func bitlen(k vdl.Kind) int {
	var bits int
	switch k {
	case vdl.Byte, vdl.Int8:
		bits = 8
	case vdl.Uint16, vdl.Int16:
		bits = 16
	case vdl.Uint32, vdl.Int32, vdl.Float32:
		bits = 32
	case vdl.Uint64, vdl.Int64, vdl.Float64:
		bits = 64
	default:
		panic(fmt.Errorf("bitlen kind %v not handled", k))
	}
	return bits
}