TBR:
veyron/lib/flags: fix the build.
Change-Id: I187e8415c498a68595ab580a311be19e4a08810f
diff --git a/lib/flags/doc.go b/lib/flags/doc.go
new file mode 100644
index 0000000..8681ac4
--- /dev/null
+++ b/lib/flags/doc.go
@@ -0,0 +1,3 @@
+// Package flags provides implementations of the flag.Value interface
+// that are commonly used.
+package flags
diff --git a/lib/flags/listen.go b/lib/flags/listen.go
new file mode 100644
index 0000000..de2855f
--- /dev/null
+++ b/lib/flags/listen.go
@@ -0,0 +1,120 @@
+package flags
+
+import (
+ "fmt"
+ "net"
+ "strconv"
+)
+
+// TCPProtocolFlag implements flag.Value to provide validation of the
+// command line values passed to it: tcp, tcp4 or tcp6 being the
+// only allowed values.
+type TCPProtocolFlag struct{ Protocol string }
+
+// Implements flag.Value.Get
+func (t TCPProtocolFlag) Get() interface{} {
+ return t.Protocol
+}
+
+// Implements flag.Value.Set
+func (t *TCPProtocolFlag) Set(s string) error {
+ switch s {
+ case "tcp", "tcp4", "tcp6":
+ t.Protocol = s
+ return nil
+ default:
+ return fmt.Errorf("%q is not a tcp protocol", s)
+ }
+
+}
+
+// Implements flag.Value.String
+func (t TCPProtocolFlag) String() string {
+ return t.Protocol
+}
+
+// IPHostPortFlag implements flag.Value to provide validation of the
+// command line value it is set to. The allowed format is <host>:<port> in
+// ip4 and ip6 formats. The host may be specified as a hostname or as an IP
+// address (v4 or v6). If a hostname is used and it resolves to multiple IP
+// addresses then all of those addresses are stored in IPHostPort.
+type IPHostPortFlag struct {
+ Address string
+ Host string
+ IP []*net.IPAddr
+ Port string
+}
+
+// Implements flag.Value.Get
+func (ip IPHostPortFlag) Get() interface{} {
+ return ip.String()
+}
+
+// Implements flag.Value.Set
+func (ip *IPHostPortFlag) Set(s string) error {
+ ip.Address = s
+ host, port, err := net.SplitHostPort(s)
+ if err != nil {
+ // no port number in s.
+ host = s
+ ip.Port = "0"
+ } else {
+ // have a port in s.
+ if _, err := strconv.ParseUint(port, 10, 16); err != nil {
+ return fmt.Errorf("failed to parse port number from %s", s)
+ }
+ ip.Port = port
+ }
+ // if len(host) == 0 then we have no host, just a port.
+ if len(host) > 0 {
+ if addr := net.ParseIP(host); addr == nil {
+ // Could be a hostname.
+ addrs, err := net.LookupIP(host)
+ if err != nil {
+ return fmt.Errorf("%s is neither an IP address nor a host name:%s", host, err)
+ }
+ for _, a := range addrs {
+ ip.IP = append(ip.IP, &net.IPAddr{IP: a})
+ }
+ ip.Host = host
+ } else {
+ ip.IP = []*net.IPAddr{{IP: addr}}
+ }
+ return nil
+ }
+ return nil
+}
+
+// Implements flag.Value.String
+func (ip IPHostPortFlag) String() string {
+ host := ip.Host
+ if len(ip.Host) == 0 && ip.IP != nil && len(ip.IP) > 0 {
+ // We don't have a hostname, so there should be at most one IP address.
+ host = ip.IP[0].String()
+ }
+ return net.JoinHostPort(host, ip.Port)
+}
+
+// IPFlag implements flag.Value in order to provide validation of
+// IP addresses in the flag package.
+type IPFlag struct{ net.IP }
+
+// Implements flag.Value.Get
+func (ip IPFlag) Get() interface{} {
+ return ip.IP
+}
+
+// Implements flag.Value.Set
+func (ip *IPFlag) Set(s string) error {
+ t := net.ParseIP(s)
+ if t == nil {
+ return fmt.Errorf("failed to parse %s as an IP address", s)
+ }
+ ip.IP = t
+ return nil
+}
+
+// Implements flag.Value.String
+func (ip IPFlag) String() string {
+ return ip.IP.String()
+}
diff --git a/lib/flags/listen_test.go b/lib/flags/listen_test.go
new file mode 100644
index 0000000..572fd99
--- /dev/null
+++ b/lib/flags/listen_test.go
@@ -0,0 +1,91 @@
+package flags_test
+
+import (
+ "net"
+ "reflect"
+ "strings"
+ "testing"
+
+ "veyron/lib/flags"
+)
+
+func TestIPFlag(t *testing.T) {
+ ip := &flags.IPFlag{}
+ if err := ip.Set("172.16.1.22"); err != nil {
+ t.Errorf("unexpected error %s", err)
+ }
+ if got, want := ip.IP, net.ParseIP("172.16.1.22"); !got.Equal(want) {
+ t.Errorf("got %s, expected %s", got, want)
+ }
+ if err := ip.Set("172.16"); err == nil || err.Error() != "failed to parse 172.16 as an IP address" {
+ t.Errorf("expected error %v", err)
+ }
+}
+
+func TestTCPFlag(t *testing.T) {
+ tcp := &flags.TCPProtocolFlag{}
+ if err := tcp.Set("tcp6"); err != nil {
+ t.Errorf("unexpected error %s", err)
+ }
+ if got, want := tcp.Protocol, "tcp6"; got != want {
+ t.Errorf("got %s, expected %s", got, want)
+ }
+ if err := tcp.Set("foo"); err == nil || !strings.Contains(err.Error(), "not a tcp protocol") {
+ t.Errorf("expected error %v", err)
+ }
+}
+
+func TestIPHostPortFlag(t *testing.T) {
+ lh := []*net.IPAddr{{IP: net.ParseIP("127.0.0.1")}}
+ ip6 := []*net.IPAddr{{IP: net.ParseIP("FE80:0000:0000:0000:0202:B3FF:FE1E:8329")}}
+ cases := []struct {
+ input string
+ want flags.IPHostPortFlag
+ str string
+ }{
+ {"", flags.IPHostPortFlag{Port: "0"}, ":0"},
+ {":0", flags.IPHostPortFlag{Port: "0"}, ":0"},
+ {":22", flags.IPHostPortFlag{Port: "22"}, ":22"},
+ {"127.0.0.1", flags.IPHostPortFlag{IP: lh, Port: "0"}, "127.0.0.1:0"},
+ {"127.0.0.1:10", flags.IPHostPortFlag{IP: lh, Port: "10"}, "127.0.0.1:10"},
+ {"[]:0", flags.IPHostPortFlag{Port: "0"}, ":0"},
+ {"[FE80:0000:0000:0000:0202:B3FF:FE1E:8329]:100", flags.IPHostPortFlag{IP: ip6, Port: "100"}, "[fe80::202:b3ff:fe1e:8329]:100"},
+ }
+ for _, c := range cases {
+ got, want := &flags.IPHostPortFlag{}, &c.want
+ c.want.Address = c.input
+ if err := got.Set(c.input); err != nil || !reflect.DeepEqual(got, want) {
+ if err != nil {
+ t.Errorf("%q: unexpected error %s", c.input, err)
+ } else {
+ t.Errorf("%q: got %#v, want %#v", c.input, got, want)
+ }
+ }
+ if got.String() != c.str {
+ t.Errorf("%q: got %#v, want %#v", c.input, got.String(), c.str)
+ }
+ }
+
+ host := &flags.IPHostPortFlag{}
+ if err := host.Set("localhost:122"); err != nil {
+ t.Errorf("unexpected error: %s", err)
+ }
+ if len(host.IP) == 0 {
+ t.Errorf("localhost should have resolved to at least one address")
+ }
+ if got, want := host.Port, "122"; got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+ if got, want := host.String(), "localhost:122"; got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+
+ for _, s := range []string{
+ ":", ":59999999", "nohost.invalid", "nohost.invalid:"} {
+ f := &flags.IPHostPortFlag{}
+ if err := f.Set(s); err == nil {
+ t.Errorf("expected an error for %q", s)
+ }
+ }
+
+}