| // Copyright 2016 The Vanadium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package golang |
| |
| import ( |
| "fmt" |
| "strings" |
| |
| "v.io/v23/vdl" |
| "v.io/x/ref/lib/vdl/compile" |
| ) |
| |
| // defineWrite returns the VDLWrite method for the def type. |
| func defineWrite(data *goData, def *compile.TypeDef) string { |
| g := genWrite{goData: data} |
| return g.Gen(def) |
| } |
| |
| type genWrite struct { |
| *goData |
| // anonWriters holds the anonymous types that we need to generate __VDLWrite |
| // functions for. We only generate the function if we're the first one to add |
| // it to goData; otherwise someone else has already generated it. |
| anonWriters []*vdl.Type |
| } |
| |
| func (g *genWrite) Gen(def *compile.TypeDef) string { |
| var s string |
| if def.Type.Kind() == vdl.Union { |
| s += g.genUnionDef(def) |
| } else { |
| s += g.genDef(def) |
| } |
| s += g.genAnonDef() |
| return s |
| } |
| |
| func (g *genWrite) genDef(def *compile.TypeDef) string { |
| body := g.body(def.Type, namedArg{"x", false}, false, true) |
| return fmt.Sprintf(` |
| func (x %[1]s) VDLWrite(enc %[2]sEncoder) error {%[3]s |
| } |
| `, def.Name, g.Pkg("v.io/v23/vdl"), body) |
| } |
| |
| // genUnionDef is a special-case, since we need to generate methods for each |
| // concrete union struct. |
| func (g *genWrite) genUnionDef(def *compile.TypeDef) string { |
| var s string |
| for ix := 0; ix < def.Type.NumField(); ix++ { |
| field := def.Type.Field(ix) |
| body := g.bodyUnion(field, namedArg{"x", false}) |
| s += fmt.Sprintf(` |
| func (x %[1]s%[2]s) VDLWrite(enc %[3]sEncoder) error { |
| if err := enc.StartValue(%[4]s); err != nil { |
| return err |
| }%[5]s |
| return enc.FinishValue() |
| } |
| `, def.Name, field.Name, g.Pkg("v.io/v23/vdl"), g.TypeOf(def.Type), body) |
| } |
| return s |
| } |
| |
| func (g *genWrite) genAnonDef() string { |
| var s string |
| // Generate the __VDLWrite functions for anonymous types. Creating the |
| // function for one type may cause us to need more, e.g. [][]Certificate. So |
| // we just keep looping until there are no new functions to generate. There's |
| // no danger of infinite looping, since cyclic anonymous types are disallowed |
| // in the VDL type system. |
| for len(g.anonWriters) > 0 { |
| anons := g.anonWriters |
| g.anonWriters = nil |
| for _, anon := range anons { |
| body := g.body(anon, namedArg{"x", false}, false, true) |
| s += fmt.Sprintf(` |
| func %[1]s(enc %[2]sEncoder, x %[3]s) error {%[4]s |
| } |
| `, g.anonWriterName(anon), g.Pkg("v.io/v23/vdl"), typeGo(g.goData, anon), body) |
| } |
| } |
| return s |
| } |
| |
| func (g *genWrite) body(tt *vdl.Type, arg namedArg, skipNilCheck, topLevel bool) string { |
| kind := tt.Kind() |
| sta := fmt.Sprintf(` |
| if err := enc.StartValue(%[1]s); err != nil { |
| return err |
| }`, g.TypeOf(tt)) |
| fin := ` |
| if err := enc.FinishValue(); err != nil { |
| return err |
| }` |
| retnil := "" |
| if topLevel { |
| fin = ` |
| return enc.FinishValue()` |
| retnil = ` |
| return nil` |
| } |
| // Handle special cases. The ordering of the cases is very important. |
| switch { |
| case tt == vdl.ErrorType: |
| // Error types call verror.VDLWrite directly, similar to named types, but |
| // even more special-cased. Appears before optional, since ErrorType is |
| // optional. |
| return g.bodyError(arg) |
| case kind == vdl.Optional: |
| // Optional types need special nil handling. Appears before native types, |
| // to allow native types to be optional. |
| return g.bodyOptional(tt, arg, skipNilCheck) |
| case !topLevel && isNativeType(g.Env, tt): |
| // Non-top-level native types need an initial native conversion, while |
| // top-level native types use the regular logic to create VDLWrite for the |
| // wire type. Appears as early as possible, so that all subsequent cases |
| // have nativity handled correctly. |
| return g.bodyNative(tt, arg, skipNilCheck) |
| case tt.IsBytes(): |
| // Bytes call the fast Encoder.WriteValueBytes method. Appears before named |
| // types to avoid an extra VDLWrite method call, and appears before |
| // anonymous lists to avoid slow byte-at-a-time encoding. |
| return g.bodyFastpath(tt, arg, false) + retnil |
| case !topLevel && tt.Name() != "" && !g.hasFastpath(tt, false): |
| // Non-top-level named types call the VDLWrite method defined on the arg. |
| // The top-level type is always named, and needs a real body generated. |
| // We let fastpath types drop through, to avoid the extra method call. |
| return g.bodyCallVDLWrite(tt, arg, skipNilCheck) |
| case !topLevel && (kind == vdl.List || kind == vdl.Set || kind == vdl.Map): |
| // Non-top-level anonymous types call the unexported __VDLWrite* functions |
| // generated in g.Gen, after the main VDLWrite method has been generated. |
| // Top-level anonymous types use the regular logic, to generate the actual |
| // body of the __VDLWrite* functions. |
| return g.bodyAnon(tt, arg) |
| } |
| // Handle each kind of type. |
| if g.hasFastpath(tt, false) { |
| // Don't perform native conversions, since they were already performed above. |
| // All scalar types have a fastpath. |
| return g.bodyFastpath(tt, arg, false) + retnil |
| } |
| switch kind { |
| case vdl.Array: |
| return sta + g.bodyArray(tt, arg) + fin |
| case vdl.List: |
| return sta + g.bodyList(tt, arg) + fin |
| case vdl.Set: |
| return sta + g.bodySet(tt, arg) + fin |
| case vdl.Map: |
| return sta + g.bodyMap(tt, arg) + fin |
| case vdl.Struct: |
| return sta + g.bodyStruct(tt, arg) + fin |
| case vdl.Any: |
| return g.bodyAny(arg, skipNilCheck) |
| default: |
| panic(fmt.Errorf("VDLWrite unhandled type %s", tt)) |
| } |
| } |
| |
| func (g *genWrite) bodyError(arg namedArg) string { |
| return fmt.Sprintf(` |
| if err := %[1]sVDLWrite(enc, %[2]s); err != nil { |
| return err |
| }`, g.Pkg("v.io/v23/verror"), arg.Name) |
| } |
| |
| func (g *genWrite) bodyNative(tt *vdl.Type, arg namedArg, skipNilCheck bool) string { |
| s := fmt.Sprintf(` |
| var wire %[1]s |
| if err := %[1]sFromNative(&wire, %[2]s); err != nil { |
| return err |
| }`, typeGoWire(g.goData, tt), arg.Ref()) |
| return s + g.bodyCallVDLWrite(tt, typedArg("wire", tt), skipNilCheck) |
| } |
| |
| func (g *genWrite) bodyCallVDLWrite(tt *vdl.Type, arg namedArg, skipNilCheck bool) string { |
| s := fmt.Sprintf(` |
| if err := %[1]s.VDLWrite(enc); err != nil { |
| return err |
| }`, arg.Name) |
| // Handle cases where a nil arg would cause the VDLWrite call to panic. Here |
| // are the potential cases: |
| // Optional: Never happens; optional types already handled. |
| // TypeObject: The vdl.Type.VDLWrite method handles nil. |
| // List, Set, Map: VDLWrite uses len(arg) and "range arg", which handle nil. |
| // Union: Needs handling below. |
| // Any: Needs handling below. |
| if k := tt.Kind(); !skipNilCheck && (k == vdl.Union || k == vdl.Any) { |
| s = fmt.Sprintf(` |
| switch { |
| case %[1]s == nil: |
| // Write the zero value of the %[2]s type. |
| if err := %[3]sZeroValue(%[4]s).VDLWrite(enc); err != nil { |
| return err |
| } |
| default:%[5]s |
| }`, arg.Ref(), k.String(), g.Pkg("v.io/v23/vdl"), g.TypeOf(tt), s) |
| } |
| return s |
| } |
| |
| func (g *genWrite) anonWriterName(tt *vdl.Type) string { |
| return fmt.Sprintf("__VDLWriteAnon_%s_%d", tt.Kind(), g.goData.anonWriters[tt]) |
| } |
| |
| func (g *genWrite) bodyAnon(tt *vdl.Type, arg namedArg) string { |
| id := g.goData.anonWriters[tt] |
| if id == 0 { |
| // This is the first time we've encountered this type, add it. |
| id = len(g.goData.anonWriters) + 1 |
| g.goData.anonWriters[tt] = id |
| g.anonWriters = append(g.anonWriters, tt) |
| } |
| return fmt.Sprintf(` |
| if err := %[1]s(enc, %[2]s); err != nil { |
| return err |
| }`, g.anonWriterName(tt), arg.Ref()) |
| } |
| |
| func (g *genWrite) hasFastpath(tt *vdl.Type, nativeConv bool) bool { |
| method, _, _ := g.fastpathInfo(tt, namedArg{}, nativeConv) |
| return method != "" |
| } |
| |
| func (g *genWrite) fastpathInfo(tt *vdl.Type, arg namedArg, nativeConv bool) (method string, params []string, init string) { |
| kind := tt.Kind() |
| p1, p2 := "", "" |
| // When fastpathInfo is called in order to produce NextEntry* or NextField* |
| // methods, we must perform the native conversion if tt is a native type. |
| if nativeConv && isNativeType(g.Env, tt) { |
| init = fmt.Sprintf(` |
| var wire %[1]s |
| if err := %[1]sFromNative(&wire, %[2]s); err != nil { |
| return err |
| }`, typeGoWire(g.goData, tt), arg.Ref()) |
| arg = typedArg("wire", tt) |
| } |
| // Handle bytes fastpath. Go doesn't allow type conversions from []MyByte to |
| // []byte, but the reflect package does let us perform this conversion. |
| if tt.IsBytes() { |
| method, p1 = "Bytes", g.TypeOf(tt) |
| switch { |
| case tt.Elem() != vdl.ByteType: |
| slice := arg.Ref() |
| if kind == vdl.Array { |
| slice = arg.Name + "[:]" |
| } |
| p2 = fmt.Sprintf(`%sValueOf(%s).Bytes()`, g.Pkg("reflect"), slice) |
| case kind == vdl.Array: |
| p2 = arg.Name + "[:]" |
| default: |
| p2 = g.cast(arg.Ref(), tt, vdl.ListType(vdl.ByteType)) |
| } |
| } |
| // Handle scalar fastpaths. |
| switch kind { |
| case vdl.Bool: |
| method, p1, p2 = "Bool", g.TypeOf(tt), g.cast(arg.Ref(), tt, vdl.BoolType) |
| case vdl.String: |
| method, p1, p2 = "String", g.TypeOf(tt), g.cast(arg.Ref(), tt, vdl.StringType) |
| case vdl.Enum: |
| method, p1, p2 = "String", g.TypeOf(tt), arg.Name+".String()" |
| case vdl.Byte, vdl.Uint16, vdl.Uint32, vdl.Uint64: |
| method, p1, p2 = "Uint", g.TypeOf(tt), g.cast(arg.Ref(), tt, vdl.Uint64Type) |
| case vdl.Int8, vdl.Int16, vdl.Int32, vdl.Int64: |
| method, p1, p2 = "Int", g.TypeOf(tt), g.cast(arg.Ref(), tt, vdl.Int64Type) |
| case vdl.Float32, vdl.Float64: |
| method, p1, p2 = "Float", g.TypeOf(tt), g.cast(arg.Ref(), tt, vdl.Float64Type) |
| case vdl.TypeObject: |
| method, p2 = "TypeObject", arg.Ref() |
| } |
| if method == "" { |
| return "", nil, "" |
| } |
| if p1 != "" { |
| params = append(params, p1) |
| } |
| params = append(params, p2) |
| return |
| } |
| |
| func (g *genWrite) cast(value string, tt, exact *vdl.Type) string { |
| if tt != exact { |
| // The types don't have an exact match, so we need a conversion. This |
| // occurs for all named types, as well as numeric types where the bitlen |
| // isn't exactly the same. E.g. |
| // |
| // type Foo uint16 |
| // |
| // x := Foo(123) |
| // enc.WriteValueUint(tt, uint64(x)) |
| return typeGoWire(g.goData, exact) + "(" + value + ")" |
| } |
| return value |
| } |
| |
| func (g *genWrite) bodyFastpath(tt *vdl.Type, arg namedArg, nativeConv bool) string { |
| method, params, init := g.fastpathInfo(tt, arg, nativeConv) |
| return fmt.Sprintf(`%[1]s |
| if err := enc.WriteValue%[2]s(%[3]s); err != nil { |
| return err |
| }`, init, method, strings.Join(params, ", ")) |
| } |
| |
| const ( |
| encNextEntry = ` |
| if err := enc.NextEntry(false); err != nil { |
| return err |
| }` |
| encNextEntryDone = ` |
| if err := enc.NextEntry(true); err != nil { |
| return err |
| }` |
| encNextFieldDone = ` |
| if err := enc.NextField(""); err != nil { |
| return err |
| }` |
| ) |
| |
| func (g *genWrite) bodyArray(tt *vdl.Type, arg namedArg) string { |
| elemArg := typedArg("elem", tt.Elem()) |
| s := fmt.Sprintf(` |
| for _, elem := range %[1]s {`, arg.Ref()) |
| method, params, init := g.fastpathInfo(tt.Elem(), elemArg, true) |
| if method != "" { |
| s += fmt.Sprintf(`%[1]s |
| if err := enc.NextEntryValue%[2]s(%[3]s); err != nil { |
| return err |
| }`, init, method, strings.Join(params, ", ")) |
| } else { |
| s += encNextEntry |
| s += g.body(tt.Elem(), elemArg, false, false) |
| } |
| s += ` |
| }` + encNextEntryDone |
| return s |
| } |
| |
| func (g *genWrite) bodyList(tt *vdl.Type, arg namedArg) string { |
| elemArg := typedArg("elem", tt.Elem()) |
| s := fmt.Sprintf(` |
| if err := enc.SetLenHint(len(%[1]s)); err != nil { |
| return err |
| } |
| for _, elem := range %[1]s {`, arg.Ref()) |
| method, params, init := g.fastpathInfo(tt.Elem(), elemArg, true) |
| if method != "" { |
| s += fmt.Sprintf(`%[1]s |
| if err := enc.NextEntryValue%[2]s(%[3]s); err != nil { |
| return err |
| }`, init, method, strings.Join(params, ", ")) |
| } else { |
| s += encNextEntry |
| s += g.body(tt.Elem(), elemArg, false, false) |
| } |
| s += ` |
| }` + encNextEntryDone |
| return s |
| } |
| |
| func (g *genWrite) bodySet(tt *vdl.Type, arg namedArg) string { |
| keyArg := typedArg("key", tt.Key()) |
| s := fmt.Sprintf(` |
| if err := enc.SetLenHint(len(%[1]s)); err != nil { |
| return err |
| } |
| for key := range %[1]s {`, arg.Ref()) |
| method, params, init := g.fastpathInfo(tt.Key(), keyArg, true) |
| if method != "" { |
| s += fmt.Sprintf(`%[1]s |
| if err := enc.NextEntryValue%[2]s(%[3]s); err != nil { |
| return err |
| }`, init, method, strings.Join(params, ", ")) |
| } else { |
| s += encNextEntry |
| s += g.body(tt.Key(), keyArg, false, false) |
| } |
| s += ` |
| }` + encNextEntryDone |
| return s |
| } |
| |
| func (g *genWrite) bodyMap(tt *vdl.Type, arg namedArg) string { |
| keyArg, elemArg := typedArg("key", tt.Key()), typedArg("elem", tt.Elem()) |
| s := fmt.Sprintf(` |
| if err := enc.SetLenHint(len(%[1]s)); err != nil { |
| return err |
| } |
| for key, elem := range %[1]s {`, arg.Ref()) |
| method, params, init := g.fastpathInfo(tt.Key(), keyArg, true) |
| if method != "" { |
| s += fmt.Sprintf(`%[1]s |
| if err := enc.NextEntryValue%[2]s(%[3]s); err != nil { |
| return err |
| }`, init, method, strings.Join(params, ", ")) |
| } else { |
| s += encNextEntry |
| s += g.body(tt.Key(), keyArg, false, false) |
| } |
| s += g.body(tt.Elem(), elemArg, false, false) |
| s += ` |
| }` + encNextEntryDone |
| return s |
| } |
| |
| func (g *genWrite) bodyStruct(tt *vdl.Type, arg namedArg) string { |
| var s string |
| for i := 0; i < tt.NumField(); i++ { |
| field := tt.Field(i) |
| fieldArg := arg.Field(field) |
| zero := genIsZero{g.goData} |
| expr := zero.Expr(ifNeZero, field.Type, fieldArg, field.Name) |
| s += fmt.Sprintf(` |
| if %[1]s {`, expr) |
| method, params, init := g.fastpathInfo(field.Type, fieldArg, true) |
| if method != "" { |
| params = append([]string{`"` + field.Name + `"`}, params...) |
| s += fmt.Sprintf(`%[1]s |
| if err := enc.NextFieldValue%[2]s(%[3]s); err != nil { |
| return err |
| }`, init, method, strings.Join(params, ", ")) |
| } else { |
| // The second-to-last true parameter indicates that nil checks can be |
| // skipped, since we've already ensured the field isn't zero here. |
| fieldBody := g.body(field.Type, fieldArg, true, false) |
| s += fmt.Sprintf(` |
| if err := enc.NextField(%[1]q); err != nil { |
| return err |
| }%[2]s`, field.Name, fieldBody) |
| } |
| s += ` |
| }` |
| } |
| s += encNextFieldDone |
| return s |
| } |
| |
| func (g *genWrite) bodyUnion(field vdl.Field, arg namedArg) string { |
| var s string |
| fieldArg := typedArg(arg.Name+".Value", field.Type) |
| method, params, init := g.fastpathInfo(field.Type, fieldArg, true) |
| if method != "" { |
| params = append([]string{`"` + field.Name + `"`}, params...) |
| s += fmt.Sprintf(`%[1]s |
| if err := enc.NextFieldValue%[2]s(%[3]s); err != nil { |
| return err |
| }`, init, method, strings.Join(params, ", ")) |
| } else { |
| s += fmt.Sprintf(` |
| if err := enc.NextField(%[1]q); err != nil { |
| return err |
| }`, field.Name) |
| s += g.body(field.Type, fieldArg, false, false) |
| } |
| s += encNextFieldDone |
| return s |
| } |
| |
| func (g *genWrite) bodyOptional(tt *vdl.Type, arg namedArg, skipNilCheck bool) string { |
| s := ` |
| enc.SetNextStartValueIsOptional()` + g.body(tt.Elem(), arg, false, false) |
| if !skipNilCheck { |
| s = fmt.Sprintf(` |
| if %[1]s == nil { |
| if err := enc.NilValue(%[2]s); err != nil { |
| return err |
| } |
| } else {%[3]s |
| }`, arg.Name, g.TypeOf(tt), s) |
| } |
| return s |
| } |
| |
| func (g *genWrite) bodyAny(arg namedArg, skipNilCheck bool) string { |
| mode := goAnyRepMode(g.Package) |
| // Handle interface{} special-case. |
| if mode == goAnyRepInterface { |
| return fmt.Sprintf(` |
| if err := %[1]sWrite(enc, %[2]s); err != nil { |
| return err |
| }`, g.Pkg("v.io/v23/vdl"), arg.Ref()) |
| } |
| // Handle vdl.Value and vom.RawBytes representations. |
| s := fmt.Sprintf(` |
| if err := %[1]s.VDLWrite(enc); err != nil { |
| return err |
| }`, arg.Name) |
| if !skipNilCheck { |
| s = fmt.Sprintf(` |
| if %[1]s == nil { |
| if err := enc.NilValue(%[2]sAnyType); err != nil { |
| return err |
| } |
| } else {%[3]s |
| }`, arg.Ref(), g.Pkg("v.io/v23/vdl"), s) |
| } |
| return s |
| } |