Merge "runtimes/google/ipc: Encode expected server blessings in the name instead of options.RemoteID"
diff --git a/lib/flags/flags.go b/lib/flags/flags.go
index b1313e1..b8a1473 100644
--- a/lib/flags/flags.go
+++ b/lib/flags/flags.go
@@ -24,9 +24,13 @@
// --veyron.tcp.address
// --veyron.proxy
Listen
+ // --veyron.acl (which may be repeated to supply multiple values)
+ ACL
)
const defaultNamespaceRoot = "/proxy.envyor.com:8101"
+const defaultACLName = "veyron"
+const defaultACLFile = "acl.json"
// Flags represents the set of flag groups created by a call to
// CreateAndRegister.
@@ -46,7 +50,7 @@
func (nsr *namespaceRootFlagVar) Set(v string) error {
if !nsr.isSet {
- // override the default value and
+ // override the default value
nsr.isSet = true
nsr.roots = []string{}
}
@@ -54,6 +58,30 @@
return nil
}
+// aclFlagVar holds the state behind the repeatable --veyron.acl flag; its
+// String and Set methods let it be registered with FlagSet.Var.
+type aclFlagVar struct {
+ // isSet records whether Set has been called at least once; the first
+ // call replaces whatever files held beforehand with a fresh map.
+ isSet bool
+ // files maps an ACL name to the file supplied for that name.
+ files map[string]string
+}
+
+// String renders the current name-to-file mapping using fmt's default
+// formatting.
+func (aclf *aclFlagVar) String() string {
+	return fmt.Sprint(aclf.files)
+}
+
+// Set records one 'name:file' value. The first call discards any
+// previously held entries (i.e. the defaults). Only the first ':' acts as
+// the separator, so the file part may itself contain ':' characters. A
+// value without any ':' is rejected with an error.
+func (aclf *aclFlagVar) Set(v string) error {
+	if !aclf.isSet {
+		// First explicit value: start from an empty mapping.
+		aclf.files = make(map[string]string)
+		aclf.isSet = true
+	}
+	sep := strings.Index(v, ":")
+	if sep < 0 {
+		return fmt.Errorf("%q is not in 'name:file' format", v)
+	}
+	aclf.files[v[:sep]] = v[sep+1:]
+	return nil
+}
+
// RuntimeFlags contains the values of the Runtime flag group.
type RuntimeFlags struct {
// NamespaceRoots may be initialized by NAMESPACE_ROOT* enivornment
@@ -68,6 +96,17 @@
namespaceRootsFlag namespaceRootFlagVar
}
+// ACLFlags contains the values of the ACL flag group.
+type ACLFlags struct {
+ // flag backs the --veyron.acl command line flag.
+ flag aclFlagVar
+}
+
+// ACLFile looks up name among the configured ACL entries and returns the
+// file registered for it, or the empty string when there is no entry for
+// name.
+func (af ACLFlags) ACLFile(name string) string {
+	if file, ok := af.flag.files[name]; ok {
+		return file
+	}
+	return ""
+}
+
// ListenFlags contains the values of the Listen flag group.
type ListenFlags struct {
ListenProtocol TCPProtocolFlag
@@ -85,11 +124,19 @@
} else {
f.namespaceRootsFlag.roots = roots
}
+
fs.Var(&f.namespaceRootsFlag, "veyron.namespace.root", "local namespace root; can be repeated to provided multiple roots")
fs.StringVar(&f.Credentials, "veyron.credentials", creds, "directory to use for storing security credentials")
return f
}
+// createAndRegisterACLFlags creates and registers the ACLFlags
+// group with the supplied flag.FlagSet.
+func createAndRegisterACLFlags(fs *flag.FlagSet) *ACLFlags {
+ f := &ACLFlags{}
+ // Seed the default entry; aclFlagVar.Set drops it the first time a
+ // --veyron.acl value is supplied.
+ f.flag.files = map[string]string{defaultACLName: defaultACLFile}
+ fs.Var(&f.flag, "veyron.acl", "specify an acl file as <name>:<aclfile>")
+ return f
+}
+
// createAndRegisterListenFlags creates and registers the ListenFlags
// group with the supplied flag.FlagSet.
func createAndRegisterListenFlags(fs *flag.FlagSet) *ListenFlags {
@@ -116,6 +163,8 @@
f.groups[Runtime] = createAndRegisterRuntimeFlags(fs)
case Listen:
f.groups[Listen] = createAndRegisterListenFlags(fs)
+ case ACL:
+ f.groups[ACL] = createAndRegisterACLFlags(fs)
}
}
return f
@@ -145,6 +194,17 @@
return ListenFlags{}
}
+// ACLFlags returns a copy of the ACL flag group held by f. When the ACL
+// group was not requested in the call to CreateAndRegister, the result is
+// a zero-valued ACLFlags. The HasGroup method can be used to test whether
+// any given group was configured.
+func (f *Flags) ACLFlags() ACLFlags {
+	group := f.groups[ACL]
+	if group == nil {
+		return ACLFlags{}
+	}
+	return *(group.(*ACLFlags))
+}
+
// HasGroup returns group if the supplied FlagGroup has been created
// for these Flags.
func (f *Flags) HasGroup(group FlagGroup) bool {
diff --git a/lib/flags/flags_test.go b/lib/flags/flags_test.go
index c802942..ee9c32a 100644
--- a/lib/flags/flags_test.go
+++ b/lib/flags/flags_test.go
@@ -14,11 +14,11 @@
func TestFlags(t *testing.T) {
fs := flag.NewFlagSet("test", flag.ContinueOnError)
if flags.CreateAndRegister(fs) != nil {
- t.Errorf("should have failed")
+ t.Fatalf("should have returned a nil value")
}
fl := flags.CreateAndRegister(fs, flags.Runtime)
if fl == nil {
- t.Fatalf("should have returned a non-nil value")
+ t.Errorf("should have failed")
}
creds := "creddir"
roots := []string{"ab:cd:ef"}
@@ -42,6 +42,26 @@
}
}
+// TestACLFlags checks that repeated --veyron.acl flags are parsed into
+// name/file pairs, that only the first ':' separates the name from the
+// file, and that a name with no entry yields the empty string.
+func TestACLFlags(t *testing.T) {
+	fs := flag.NewFlagSet("test", flag.ContinueOnError)
+	fl := flags.CreateAndRegister(fs, flags.Runtime, flags.ACL)
+	if fl == nil {
+		t.Fatalf("should have returned a non-nil value")
+	}
+	args := []string{"--veyron.acl=veyron:foo.json", "--veyron.acl=bar:bar.json", "--veyron.acl=baz:bar:baz.json"}
+	// A parse failure would make every lookup below meaningless, so stop
+	// right away instead of reporting a cascade of mismatches.
+	if err := fl.Parse(args); err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+	aclf := fl.ACLFlags()
+	for _, tc := range []struct {
+		name, want string
+	}{
+		{"veyron", "foo.json"},
+		{"bar", "bar.json"},
+		{"wombat", ""},
+		{"baz", "bar:baz.json"},
+	} {
+		// %q (not %t, which is the boolean verb) so that strings,
+		// including empty ones, are printed legibly.
+		if got := aclf.ACLFile(tc.name); got != tc.want {
+			t.Errorf("ACLFile(%q): got %q, want %q", tc.name, got, tc.want)
+		}
+	}
+}
+
func TestFlagError(t *testing.T) {
fs := flag.NewFlagSet("test", flag.ContinueOnError)
fs.SetOutput(ioutil.Discard)
@@ -55,6 +75,15 @@
if got, want := len(fl.Args()), 1; got != want {
t.Errorf("got %d, want %d [args: %v]", got, want, fl.Args())
}
+
+ fs = flag.NewFlagSet("test", flag.ContinueOnError)
+ fs.SetOutput(ioutil.Discard)
+ fl = flags.CreateAndRegister(fs, flags.ACL)
+ args = []string{"--veyron.acl=noname"}
+ err = fl.Parse(args)
+ if err == nil {
+ t.Fatalf("expected this to fail!")
+ }
}
func TestFlagsGroups(t *testing.T) {
@@ -138,7 +167,7 @@
os.Setenv(rootEnvVar, "")
os.Setenv(rootEnvVar0, "")
- fl := flags.CreateAndRegister(flag.NewFlagSet("test", flag.ContinueOnError), flags.Runtime)
+ fl := flags.CreateAndRegister(flag.NewFlagSet("test", flag.ContinueOnError), flags.Runtime, flags.ACL)
if err := fl.Parse([]string{}); err != nil {
t.Fatalf("unexpected error: %s", err)
}
@@ -146,4 +175,8 @@
if got, want := rtf.NamespaceRoots, []string{"/proxy.envyor.com:8101"}; !reflect.DeepEqual(got, want) {
t.Errorf("got %q, want %q", got, want)
}
+ aclf := fl.ACLFlags()
+ if got, want := aclf.ACLFile("veyron"), "acl.json"; got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
}
diff --git a/lib/glob/glob.go b/lib/glob/glob.go
index 353179d..94453f9 100644
--- a/lib/glob/glob.go
+++ b/lib/glob/glob.go
@@ -25,8 +25,9 @@
// Glob represents a slash separated path glob expression.
type Glob struct {
- elems []string
- recursive bool
+ elems []string
+ recursive bool
+ restricted bool
}
// Parse returns a new Glob.
@@ -39,9 +40,15 @@
if pattern != "" {
g.elems = strings.Split(pattern, "/")
}
- if last := len(g.elems) - 1; last >= 0 && g.elems[last] == "..." {
- g.elems = g.elems[:last]
- g.recursive = true
+ if last := len(g.elems) - 1; last >= 0 {
+ if g.elems[last] == "..." {
+ g.elems = g.elems[:last]
+ g.recursive = true
+ } else if g.elems[last] == "***" {
+ g.elems = g.elems[:last]
+ g.recursive = true
+ g.restricted = true
+ }
}
// The only error we can get from the filepath library is badpattern.
@@ -67,12 +74,18 @@
return !g.recursive && len(g.elems) == 0
}
+// Restricted reports whether recursion is restricted, which Parse sets
+// for patterns whose last element is "***" instead of "...". It is up to
+// the caller to decide what a restricted recursion means.
+func (g *Glob) Restricted() bool {
+ return g.restricted
+}
+
// Split returns the suffix of g starting at the path element corresponding to start.
func (g *Glob) Split(start int) *Glob {
if start >= len(g.elems) {
- return &Glob{elems: nil, recursive: g.recursive}
+ return &Glob{elems: nil, recursive: g.recursive, restricted: g.restricted}
}
- return &Glob{elems: g.elems[start:], recursive: g.recursive}
+ return &Glob{elems: g.elems[start:], recursive: g.recursive, restricted: g.restricted}
}
// MatchInitialSegment tries to match segment against the initial element of g.
@@ -169,7 +182,11 @@
func (g *Glob) String() string {
e := g.elems
if g.recursive {
- e = append(e, "...")
+ if g.restricted {
+ e = append(e, "***")
+ } else {
+ e = append(e, "...")
+ }
}
return filepath.Join(e...)
}
diff --git a/lib/modules/core/wspr.go b/lib/modules/core/wspr.go
index 13f3b7b..7665530 100644
--- a/lib/modules/core/wspr.go
+++ b/lib/modules/core/wspr.go
@@ -47,7 +47,7 @@
}
args = fl.Args()
- proxy := wspr.NewWSPR(*port, initListenSpec(fl), *identd)
+ proxy := wspr.NewWSPR(*port, initListenSpec(fl), *identd, nil)
defer proxy.Shutdown()
addr := proxy.Listen()
diff --git a/runtimes/google/lib/publisher/publisher.go b/runtimes/google/lib/publisher/publisher.go
index f7e474a..63fad95 100644
--- a/runtimes/google/lib/publisher/publisher.go
+++ b/runtimes/google/lib/publisher/publisher.go
@@ -190,11 +190,11 @@
ctx context.T
ns naming.Namespace
period time.Duration
- deadline time.Time // deadline for the next sync call
- names []string // names that have been added
- servers map[string]bool // servers that have been added
- servesMT map[string]bool // true if server is a mount table server
- mounts map[mountKey]*mountStatus // map each (name,server) to its status
+ deadline time.Time // deadline for the next sync call
+ names []string // names that have been added
+ servers map[string]bool // servers that have been added, true
+ // if server is a mount table server
+ mounts map[mountKey]*mountStatus // map each (name,server) to its status
}
type mountKey struct {
@@ -216,7 +216,6 @@
period: period,
deadline: time.Now().Add(period),
servers: make(map[string]bool),
- servesMT: make(map[string]bool),
mounts: make(map[mountKey]*mountStatus),
}
}
@@ -234,18 +233,17 @@
}
}
ps.names = append(ps.names, name)
- for server, _ := range ps.servers {
+ for server, servesMT := range ps.servers {
status := new(mountStatus)
ps.mounts[mountKey{name, server}] = status
- ps.mount(name, server, status, ps.servesMT[server])
+ ps.mount(name, server, status, servesMT)
}
}
func (ps *pubState) addServer(server string, servesMT bool) {
// Each non-dup server that is added causes new mounts to be created for all
// existing names.
- if !ps.servers[server] {
- ps.servers[server] = true
+ if _, exists := ps.servers[server]; !exists {
ps.servers[server] = servesMT
for _, name := range ps.names {
status := new(mountStatus)
@@ -288,7 +286,7 @@
// Desired state is "unmounted", failed at previous attempt. Retry.
ps.unmount(key.name, key.server, status)
} else {
- ps.mount(key.name, key.server, status, ps.servesMT[key.server])
+ ps.mount(key.name, key.server, status, ps.servers[key.server])
}
}
}
diff --git a/runtimes/google/naming/namespace/all_test.go b/runtimes/google/naming/namespace/all_test.go
index 941694c..b803577 100644
--- a/runtimes/google/naming/namespace/all_test.go
+++ b/runtimes/google/naming/namespace/all_test.go
@@ -358,7 +358,7 @@
{"mt2", mts[mt4MP].name},
{"//mt2", mts[mt5MP].name},
} {
- if err := ns.Mount(r.NewContext(), mp.name, mp.server, ttl); err != nil {
+ if err := ns.Mount(r.NewContext(), mp.name, mp.server, ttl, naming.ServesMountTableOpt(true)); err != nil {
boom(t, "Failed to Mount %s: %s", mp.name, err)
}
}
@@ -507,14 +507,21 @@
tests := []struct {
pattern string
expectedCalls int
+ expected []string
}{
- {"mt4/foo/bar/glob", 0},
- {"mt4/foo/bar/glob/...", 1},
- {"mt4/foo/bar/*", 0},
+ {"mt4/foo/bar/glob", 0, []string{"mt4/foo/bar/glob"}},
+ {"mt4/foo/bar/glob/...", 1, []string{"mt4/foo/bar/glob"}},
+ {"mt4/foo/bar/glob/*", 1, nil},
+ {"mt4/foo/bar/***", 0, []string{"mt4/foo/bar", "mt4/foo/bar/glob"}},
+ {"mt4/foo/bar/...", 1, []string{"mt4/foo/bar", "mt4/foo/bar/glob"}},
+ {"mt4/foo/bar/*", 0, []string{"mt4/foo/bar/glob"}},
+ {"mt4/***/bar/***", 0, []string{"mt4/foo/bar", "mt4/foo/bar/glob"}},
+ {"mt4/*/bar/***", 0, []string{"mt4/foo/bar", "mt4/foo/bar/glob"}},
}
+ // Test allowing the tests to descend into leaves.
for _, test := range tests {
out := doGlob(t, r, ns, test.pattern, 0)
- compare(t, "Glob", test.pattern, []string{"mt4/foo/bar/glob"}, out)
+ compare(t, "Glob", test.pattern, out, test.expected)
if calls := globServer.GetAndResetCount(); calls != test.expectedCalls {
boom(t, "Wrong number of Glob calls to terminal server got: %d want: %d.", calls, test.expectedCalls)
}
@@ -539,17 +546,17 @@
defer c3.server.Stop()
m := "c1/c2"
- if err := ns.Mount(r.NewContext(), m, c1.name, ttl); err != nil {
+ if err := ns.Mount(r.NewContext(), m, c1.name, ttl, naming.ServesMountTableOpt(true)); err != nil {
boom(t, "Failed to Mount %s: %s", "c1/c2", err)
}
m = "c1/c2/c3"
- if err := ns.Mount(r.NewContext(), m, c3.name, ttl); err != nil {
+ if err := ns.Mount(r.NewContext(), m, c3.name, ttl, naming.ServesMountTableOpt(true)); err != nil {
boom(t, "Failed to Mount %s: %s", m, err)
}
m = "c1/c3/c4"
- if err := ns.Mount(r.NewContext(), m, c1.name, ttl); err != nil {
+ if err := ns.Mount(r.NewContext(), m, c1.name, ttl, naming.ServesMountTableOpt(true)); err != nil {
boom(t, "Failed to Mount %s: %s", m, err)
}
diff --git a/runtimes/google/naming/namespace/cache.go b/runtimes/google/naming/namespace/cache.go
index 13ee576..be7b469 100644
--- a/runtimes/google/naming/namespace/cache.go
+++ b/runtimes/google/naming/namespace/cache.go
@@ -96,7 +96,7 @@
for _, s := range entry.Servers {
ce.Servers = append(ce.Servers, s)
}
- ce.MT = entry.MT
+ ce.SetServesMountTable(entry.ServesMountTable())
// All keys must be terminal.
prefix = naming.MakeTerminal(prefix)
c.Lock()
diff --git a/runtimes/google/naming/namespace/cache_test.go b/runtimes/google/naming/namespace/cache_test.go
index d707509..0f4cffd 100644
--- a/runtimes/google/naming/namespace/cache_test.go
+++ b/runtimes/google/naming/namespace/cache_test.go
@@ -124,7 +124,7 @@
t.Errorf("%s should have caused something to flush", toflush)
}
name := preload[2].name
- if _, ok := c.entries[name]; !ok {
+ if _, err := c.lookup(name); err != nil {
t.Errorf("%s should not have been flushed", name)
}
if len(c.entries) != 2 {
diff --git a/runtimes/google/naming/namespace/glob.go b/runtimes/google/naming/namespace/glob.go
index ff11b80..4083bb3 100644
--- a/runtimes/google/naming/namespace/glob.go
+++ b/runtimes/google/naming/namespace/glob.go
@@ -9,6 +9,7 @@
"veyron.io/veyron/veyron2/context"
"veyron.io/veyron/veyron2/naming"
+ "veyron.io/veyron/veyron2/options"
"veyron.io/veyron/veyron2/services/mounttable/types"
verror "veyron.io/veyron/veyron2/verror2"
"veyron.io/veyron/veyron2/vlog"
@@ -40,17 +41,26 @@
// query on since we know the server will not supply a new address for the
// current name.
if pattern.Finished() {
+ if !server.ServesMountTable() {
+ return nil
+ }
+ // TODO(p): soon to be unnecessary.
_, n := naming.SplitAddressName(s.Server)
if strings.HasPrefix(n, "//") {
return nil
}
}
+ // If this is restricted recursive and not a mount table, don't
+ // descend into it.
+ if pattern.Restricted() && !server.ServesMountTable() && pattern.Len() == 0 {
+ return nil
+ }
+
// Don't further resolve s.Server.
- s := naming.MakeTerminal(s.Server)
callCtx, _ := ctx.WithTimeout(callTimeout)
client := ns.rt.Client()
- call, err := client.StartCall(callCtx, s, "Glob", []interface{}{pstr})
+ call, err := client.StartCall(callCtx, s.Server, "Glob", []interface{}{pstr}, options.NoResolve(true))
if err != nil {
lastErr = err
continue // try another instance
@@ -79,6 +89,7 @@
},
depth: qe.depth,
}
+ x.me.SetServesMountTable(e.MT)
// x.depth is the number of severs we've walked through since we've gone
// recursive (i.e. with pattern length of 0).
if pattern.Len() == 0 {
@@ -102,29 +113,25 @@
// Glob implements naming.MountTable.Glob.
func (ns *namespace) Glob(ctx context.T, pattern string) (chan naming.MountEntry, error) {
defer vlog.LogCall()()
- root, globPattern := naming.SplitAddressName(pattern)
- g, err := glob.Parse(globPattern)
+ e, patternWasRooted := ns.rootMountEntry(pattern)
+ if len(e.Servers) == 0 {
+ return nil, verror.Make(naming.ErrNoMountTable, ctx)
+ }
+
+ // If pattern was already rooted, make sure we tack that root
+ // onto all returned names. Otherwise, just return the relative
+ // name.
+ var prefix string
+ if patternWasRooted {
+ prefix = e.Servers[0].Server
+ }
+ g, err := glob.Parse(e.Name)
if err != nil {
return nil, err
}
-
- // Add constant components of pattern to the servers' addresses and
- // to the prefix.
- //var prefixElements []string
- //prefixElements, g = g.SplitFixedPrefix()
- //prefix := strings.Join(prefixElements, "/")
- prefix := ""
- if len(root) != 0 {
- prefix = naming.JoinAddressName(root, prefix)
- }
-
- // Start a thread to get the results and return the reply channel to the caller.
- servers := ns.rootName(prefix)
- if len(servers) == 0 {
- return nil, verror.Make(naming.ErrNoMountTable, ctx)
- }
+ e.Name = ""
reply := make(chan naming.MountEntry, 100)
- go ns.globLoop(ctx, servers, prefix, g, reply)
+ go ns.globLoop(ctx, e, prefix, g, reply)
return reply, nil
}
@@ -137,7 +144,7 @@
return strings.Count(name, "/") - strings.Count(name, "//") + 1
}
-func (ns *namespace) globLoop(ctx context.T, servers []string, prefix string, pattern *glob.Glob, reply chan naming.MountEntry) {
+func (ns *namespace) globLoop(ctx context.T, e *naming.MountEntry, prefix string, pattern *glob.Glob, reply chan naming.MountEntry) {
defer close(reply)
// As we encounter new mount tables while traversing the Glob, we add them to the list 'l'. The loop below
@@ -145,7 +152,7 @@
// server. globAtServer will send on 'reply' any terminal entries that match the glob and add any new mount
// tables to be traversed to the list 'l'.
l := list.New()
- l.PushBack(&queuedEntry{me: &naming.MountEntry{Name: "", Servers: convertStringsToServers(servers)}})
+ l.PushBack(&queuedEntry{me: e})
atRoot := true
// Perform a breadth first search of the name graph.
@@ -168,8 +175,7 @@
reply <- x
}
- // 2. The current name fullfills the pattern and further servers did not respond
- // with "". That is, we want to prefer foo/ over foo.
+ // 2. The current name fulfills the pattern.
if suffix.Len() == 0 && !atRoot {
x := *e.me
x.Name = naming.Join(prefix, x.Name)
diff --git a/runtimes/google/naming/namespace/namespace.go b/runtimes/google/naming/namespace/namespace.go
index 2353607..1b966d7 100644
--- a/runtimes/google/naming/namespace/namespace.go
+++ b/runtimes/google/naming/namespace/namespace.go
@@ -109,29 +109,26 @@
}
// rootMountEntry 'roots' a name creating a mount entry for the name.
-func (ns *namespace) rootMountEntry(name string) *naming.MountEntry {
+func (ns *namespace) rootMountEntry(name string) (*naming.MountEntry, bool) {
e := new(naming.MountEntry)
expiration := time.Now().Add(time.Hour) // plenty of time for a call
address, suffix := naming.SplitAddressName(name)
if len(address) == 0 {
- e.MT = true
+ e.SetServesMountTable(true)
e.Name = name
ns.RLock()
defer ns.RUnlock()
for _, r := range ns.roots {
e.Servers = append(e.Servers, naming.MountedServer{Server: r, Expires: expiration})
}
- return e
+ return e, false
}
// TODO(p): right now I assume any address handed to me to be resolved is a mount table.
// Eventually we should do something like the following:
- // if ep, err := ns.rt.NewEndpoint(address); err == nil && ep.ServesMountTable() {
- // e.MT = true
- // }
- e.MT = true
+ e.SetServesMountTable(true)
e.Name = suffix
- e.Servers = append(e.Servers, naming.MountedServer{Server: address, Expires: expiration})
- return e
+ e.Servers = append(e.Servers, naming.MountedServer{Server: naming.JoinAddressName(address, ""), Expires: expiration})
+ return e, true
}
// notAnMT returns true if the error indicates this isn't a mounttable server.
diff --git a/runtimes/google/naming/namespace/resolve.go b/runtimes/google/naming/namespace/resolve.go
index 7333bf9..5943b6a 100644
--- a/runtimes/google/naming/namespace/resolve.go
+++ b/runtimes/google/naming/namespace/resolve.go
@@ -66,17 +66,10 @@
return true
}
-func makeTerminal(names []string) (ret []string) {
- for _, name := range names {
- ret = append(ret, naming.MakeTerminal(name))
- }
- return
-}
-
// ResolveX implements veyron2/naming.Namespace.
func (ns *namespace) ResolveX(ctx context.T, name string) (*naming.MountEntry, error) {
defer vlog.LogCall()()
- e := ns.rootMountEntry(name)
+ e, _ := ns.rootMountEntry(name)
if vlog.V(2) {
_, file, line, _ := runtime.Caller(1)
vlog.Infof("ResolveX(%s) called from %s:%d", name, file, line)
@@ -88,7 +81,7 @@
// Iterate walking through mount table servers.
for remaining := ns.maxResolveDepth; remaining > 0; remaining-- {
vlog.VI(2).Infof("ResolveX(%s) loop %v", name, *e)
- if !e.MT || terminal(e) {
+ if !e.ServesMountTable() || terminal(e) {
vlog.VI(1).Infof("ResolveX(%s) -> %v", name, *e)
return e, nil
}
@@ -126,7 +119,7 @@
// ResolveToMountTableX implements veyron2/naming.Namespace.
func (ns *namespace) ResolveToMountTableX(ctx context.T, name string) (*naming.MountEntry, error) {
defer vlog.LogCall()()
- e := ns.rootMountEntry(name)
+ e, _ := ns.rootMountEntry(name)
if vlog.V(2) {
_, file, line, _ := runtime.Caller(1)
vlog.Infof("ResolveToMountTableX(%s) called from %s:%d", name, file, line)
@@ -141,7 +134,7 @@
var err error
curr := e
// If the next name to resolve doesn't point to a mount table, we're done.
- if !e.MT || terminal(e) {
+ if !e.ServesMountTable() || terminal(e) {
vlog.VI(1).Infof("ResolveToMountTableX(%s) -> %v", name, last)
return last, nil
}
@@ -196,9 +189,8 @@
func unresolveAgainstServer(ctx context.T, client ipc.Client, names []string) ([]string, error) {
finalErr := errors.New("no servers to unresolve")
for _, name := range names {
- name = naming.MakeTerminal(name)
callCtx, _ := ctx.WithTimeout(callTimeout)
- call, err := client.StartCall(callCtx, name, "UnresolveStep", nil)
+ call, err := client.StartCall(callCtx, name, "UnresolveStep", nil, options.NoResolve(true))
if err != nil {
finalErr = err
vlog.VI(2).Infof("StartCall %q.UnresolveStep() failed: %s", name, err)
diff --git a/runtimes/google/naming/namespace/stub.go b/runtimes/google/naming/namespace/stub.go
index 5e1699c..2f62a39 100644
--- a/runtimes/google/naming/namespace/stub.go
+++ b/runtimes/google/naming/namespace/stub.go
@@ -34,5 +34,7 @@
}
func convertMountEntry(e *types.MountEntry) *naming.MountEntry {
- return &naming.MountEntry{Name: e.Name, MT: e.MT, Servers: convertServers(e.Servers)}
+ v := &naming.MountEntry{Name: e.Name, Servers: convertServers(e.Servers)}
+ v.SetServesMountTable(e.MT)
+ return v
}
diff --git a/services/mgmt/binary/binaryd/main.go b/services/mgmt/binary/binaryd/main.go
index 3a74c7c..28bb8ad 100644
--- a/services/mgmt/binary/binaryd/main.go
+++ b/services/mgmt/binary/binaryd/main.go
@@ -3,6 +3,7 @@
import (
"flag"
"io/ioutil"
+ "net/http"
"os"
"path/filepath"
@@ -59,6 +60,21 @@
}
}
vlog.Infof("Binary repository rooted at %v", *root)
+
+ state, err := impl.NewState(*root, defaultDepth)
+ if err != nil {
+ vlog.Errorf("NewState(%v, %v) failed: %v", *root, defaultDepth, err)
+ return
+ }
+
+ // TODO(caprita): Flagify port.
+ go func() {
+ if err := http.ListenAndServe(":8080", http.FileServer(impl.NewHTTPRoot(state))); err != nil {
+ vlog.Errorf("ListenAndServe() failed: %v", err)
+ os.Exit(1)
+ }
+ }()
+
runtime := rt.Init()
defer runtime.Cleanup()
server, err := runtime.NewServer()
@@ -68,17 +84,12 @@
}
defer server.Stop()
auth := vflag.NewAuthorizerOrDie()
- dispatcher, err := impl.NewDispatcher(*root, defaultDepth, auth)
- if err != nil {
- vlog.Errorf("NewDispatcher(%v, %v, %v) failed: %v", *root, defaultDepth, auth, err)
- return
- }
endpoint, err := server.Listen(roaming.ListenSpec)
if err != nil {
vlog.Errorf("Listen(%s) failed: %v", roaming.ListenSpec, err)
return
}
- if err := server.Serve(*name, dispatcher); err != nil {
+ if err := server.Serve(*name, impl.NewDispatcher(state, auth)); err != nil {
vlog.Errorf("Serve(%v) failed: %v", *name, err)
return
}
diff --git a/services/mgmt/binary/binaryd/test.sh b/services/mgmt/binary/binaryd/test.sh
index dae104b..e3547e9 100755
--- a/services/mgmt/binary/binaryd/test.sh
+++ b/services/mgmt/binary/binaryd/test.sh
@@ -27,7 +27,8 @@
|| shell_test::fail "line ${LINENO} failed to start binaryd"
# Create a binary file.
- local -r BINARY="${REPO}/test-binary"
+ local -r BINARY_SUFFIX="test-binary"
+ local -r BINARY="${REPO}/${BINARY_SUFFIX}"
local -r BINARY_FILE=$(shell::tmp_file)
dd if=/dev/urandom of="${BINARY_FILE}" bs=1000000 count=16 \
|| shell_test::fail "line ${LINENO}: faile to create a random binary file"
@@ -35,9 +36,15 @@
# Download the binary file.
local -r BINARY_FILE2=$(shell::tmp_file)
- "${BINARY_BIN}" download "${BINARY}" "${BINARY_FILE2}" || shell_test::fail "line ${LINENO}: 'download' failed"
+ "${BINARY_BIN}" download "${BINARY}" "${BINARY_FILE2}" || shell_test::fail "line ${LINENO}: 'RPC download' failed"
if [[ $(cmp "${BINARY_FILE}" "${BINARY_FILE2}" &> /dev/null) ]]; then
- shell_test::fail "mismatching binary files"
+ shell_test::fail "mismatching binary file downloaded via RPC"
+ fi
+
+ # Download the binary file over HTTP and make sure it matches what was
+ # uploaded. The exit status of cmp is tested directly: wrapping it in
+ # [[ $(...) ]] with its output sent to /dev/null always tests an empty
+ # string, so a mismatch would never be reported.
+ local -r BINARY_FILE3=$(shell::tmp_file)
+ curl -f -o "${BINARY_FILE3}" http://localhost:8080/"${BINARY_SUFFIX}" || shell_test::fail "line ${LINENO}: 'HTTP download' failed"
+ if ! cmp "${BINARY_FILE}" "${BINARY_FILE3}" &> /dev/null; then
+   shell_test::fail "mismatching binary file downloaded via HTTP"
+ fi
# Remove the binary file.
diff --git a/services/mgmt/binary/impl/dispatcher.go b/services/mgmt/binary/impl/dispatcher.go
index 7f54fb7..465b0a6 100644
--- a/services/mgmt/binary/impl/dispatcher.go
+++ b/services/mgmt/binary/impl/dispatcher.go
@@ -24,8 +24,11 @@
state *state
}
-// newDispatcher is the dispatcher factory.
-func NewDispatcher(root string, depth int, authorizer security.Authorizer) (*dispatcher, error) {
+// TODO(caprita): Move this together with state into a new file, state.go.
+
+// NewState creates a new state object for the binary service. This
+// should be passed into both NewDispatcher and NewHTTPRoot.
+func NewState(root string, depth int) (*state, error) {
if min, max := 0, md5.Size-1; min > depth || depth > max {
return nil, fmt.Errorf("Unexpected depth, expected a value between %v and %v, got %v", min, max, depth)
}
@@ -40,15 +43,20 @@
if expected, got := Version, strings.TrimSpace(string(output)); expected != got {
return nil, fmt.Errorf("Unexpected version: expected %v, got %v", expected, got)
}
- return &dispatcher{
- auth: authorizer,
- state: &state{
- depth: depth,
- root: root,
- },
+ return &state{
+ depth: depth,
+ root: root,
}, nil
}
+// NewDispatcher builds the dispatcher for the binary service from the
+// supplied state and authorizer.
+func NewDispatcher(state *state, authorizer security.Authorizer) ipc.Dispatcher {
+	d := new(dispatcher)
+	d.auth = authorizer
+	d.state = state
+	return d
+}
+
// DISPATCHER INTERFACE IMPLEMENTATION
func (d *dispatcher) Lookup(suffix, method string) (interface{}, security.Authorizer, error) {
diff --git a/services/mgmt/binary/impl/http.go b/services/mgmt/binary/impl/http.go
new file mode 100644
index 0000000..2c189b3
--- /dev/null
+++ b/services/mgmt/binary/impl/http.go
@@ -0,0 +1,49 @@
+package impl
+
+import (
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "veyron.io/veyron/veyron2/vlog"
+
+ "veyron.io/veyron/veyron/services/mgmt/binary/impl/merge_file"
+)
+
+// NewHTTPRoot wraps state in an http.FileSystem through which the content
+// held by the binary service can be served.
+func NewHTTPRoot(state *state) http.FileSystem {
+	return &httpRoot{state: state}
+}
+
+// httpRoot is the http.FileSystem implementation returned by NewHTTPRoot.
+type httpRoot struct {
+ // state is the binary service state handed to NewHTTPRoot.
+ state *state
+}
+
+// TODO(caprita): Tie this in with DownloadURL, to control which binaries
+// are downloadable via url.
+
+// Open implements http.FileSystem. It uses the merge file implementation
+// to wrap the content parts into one logical file.
+//
+// If a part fails its checksum check or cannot be opened, the part files
+// opened so far are closed before the error is returned, so that a failed
+// request does not leak file descriptors.
+func (r httpRoot) Open(name string) (http.File, error) {
+	name = strings.TrimPrefix(name, "/")
+	vlog.Infof("HTTP handler opening %s", name)
+	parts, err := getParts(dir(name, r.state))
+	if err != nil {
+		return nil, err
+	}
+	partFiles := make([]*os.File, 0, len(parts))
+	// closeParts releases everything opened so far; it is best-effort
+	// cleanup on an error path, so Close errors are deliberately ignored.
+	closeParts := func() {
+		for _, f := range partFiles {
+			f.Close()
+		}
+	}
+	for _, part := range parts {
+		if err := checksumExists(part); err != nil {
+			closeParts()
+			return nil, err
+		}
+		dataPath := filepath.Join(part, data)
+		f, err := os.Open(dataPath)
+		if err != nil {
+			vlog.Errorf("Open(%v) failed: %v", dataPath, err)
+			closeParts()
+			return nil, errOperationFailed
+		}
+		partFiles = append(partFiles, f)
+	}
+	// NOTE(review): whether NewMergeFile closes partFiles when it returns
+	// an error is not visible from here -- confirm.
+	return merge_file.NewMergeFile(name, partFiles)
+}
diff --git a/services/mgmt/binary/impl/http_test.go b/services/mgmt/binary/impl/http_test.go
new file mode 100644
index 0000000..5b31d0f
--- /dev/null
+++ b/services/mgmt/binary/impl/http_test.go
@@ -0,0 +1,75 @@
+package impl
+
+import (
+ "bytes"
+ "crypto/md5"
+ "encoding/hex"
+ "io/ioutil"
+ "net/http"
+ "testing"
+
+ "veyron.io/veyron/veyron2/rt"
+
+ "veyron.io/veyron/veyron/lib/testutil"
+)
+
+// TestHTTP checks that HTTP download works: a multi-part binary uploaded
+// through the RPC interface must come back over HTTP as the exact
+// concatenation of its parts.
+func TestHTTP(t *testing.T) {
+	// TODO(caprita): This is based on TestMultiPart (impl_test.go). Share
+	// the code where possible.
+	for length := 2; length < 5; length++ {
+		binary, url, cleanup := startServer(t, 2)
+		defer cleanup()
+		// Create <length> chunks of up to 4MB of random bytes.
+		data := make([][]byte, length)
+		for i := 0; i < length; i++ {
+			// Random size, but at least 1 (avoid empty parts).
+			size := testutil.Rand.Intn(1000*bufferLength) + 1
+			data[i] = testutil.RandomBytes(size)
+		}
+		if err := binary.Create(rt.R().NewContext(), int32(length)); err != nil {
+			t.Fatalf("Create() failed: %v", err)
+		}
+		for i := 0; i < length; i++ {
+			if streamErr, err := invokeUpload(t, binary, data[i], int32(i)); streamErr != nil || err != nil {
+				t.Fatalf("Upload of part %d failed: %v, %v", i, streamErr, err)
+			}
+		}
+		parts, err := binary.Stat(rt.R().NewContext())
+		if err != nil {
+			t.Fatalf("Stat() failed: %v", err)
+		}
+		response, err := http.Get(url)
+		if err != nil {
+			t.Fatal(err)
+		}
+		downloaded, err := ioutil.ReadAll(response.Body)
+		// Close right away rather than deferring: a defer inside the loop
+		// would keep every response open until the test returns.
+		response.Body.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+		// An error page must not be mistaken for binary content.
+		if got, want := response.StatusCode, http.StatusOK; got != want {
+			t.Fatalf("Get(%v): unexpected status code: got %v, want %v", url, got, want)
+		}
+		from, to := 0, 0
+		for i := 0; i < length; i++ {
+			hpart := md5.New()
+			to += len(data[i])
+			if ld := len(downloaded); to > ld {
+				t.Fatalf("Download falls short: len(downloaded):%d, need:%d (i:%d, length:%d)", ld, to, i, length)
+			}
+			output := downloaded[from:to]
+			from = to
+			if bytes.Compare(output, data[i]) != 0 {
+				t.Fatalf("Unexpected output: expected %v, got %v", data[i], output)
+			}
+			hpart.Write(data[i])
+			checksum := hex.EncodeToString(hpart.Sum(nil))
+			if expected, got := checksum, parts[i].Checksum; expected != got {
+				t.Fatalf("Unexpected checksum: expected %v, got %v", expected, got)
+			}
+			if expected, got := len(data[i]), int(parts[i].Size); expected != got {
+				t.Fatalf("Unexpected size: expected %v, got %v", expected, got)
+			}
+		}
+		// The download must not carry anything beyond the uploaded parts.
+		if ld := len(downloaded); ld != to {
+			t.Fatalf("Download too long: len(downloaded):%d, expected:%d (length:%d)", ld, to, length)
+		}
+		if err := binary.Delete(rt.R().NewContext()); err != nil {
+			t.Fatalf("Delete() failed: %v", err)
+		}
+	}
+}
diff --git a/services/mgmt/binary/impl/impl_test.go b/services/mgmt/binary/impl/impl_test.go
index ba13d5d..8212e1e 100644
--- a/services/mgmt/binary/impl/impl_test.go
+++ b/services/mgmt/binary/impl/impl_test.go
@@ -4,7 +4,10 @@
"bytes"
"crypto/md5"
"encoding/hex"
+ "fmt"
"io/ioutil"
+ "net"
+ "net/http"
"os"
"path/filepath"
"testing"
@@ -92,7 +95,7 @@
}
// startServer starts the binary repository server.
-func startServer(t *testing.T, depth int) (repository.Binary, func()) {
+func startServer(t *testing.T, depth int) (repository.Binary, string, func()) {
// Setup the root of the binary repository.
root, err := ioutil.TempDir("", veyronPrefix)
if err != nil {
@@ -107,10 +110,20 @@
if err != nil {
t.Fatalf("NewServer() failed: %v", err)
}
- dispatcher, err := NewDispatcher(root, depth, nil)
+ state, err := NewState(root, depth)
if err != nil {
- t.Fatalf("NewDispatcher(%v, %v, %v) failed: %v", root, depth, nil, err)
+ t.Fatalf("NewState(%v, %v) failed: %v", root, depth, err)
}
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ go func() {
+ if err := http.Serve(listener, http.FileServer(NewHTTPRoot(state))); err != nil {
+ vlog.Fatalf("Serve() failed: %v", err)
+ }
+ }()
+ dispatcher := NewDispatcher(state, nil)
endpoint, err := server.Listen(profiles.LocalListenSpec)
if err != nil {
t.Fatalf("Listen(%s) failed: %v", profiles.LocalListenSpec, err)
@@ -124,17 +137,17 @@
if err != nil {
t.Fatalf("BindBinary(%v) failed: %v", name, err)
}
- return binary, func() {
+ return binary, fmt.Sprintf("http://%s/test", listener.Addr()), func() {
// Shutdown the binary repository server.
if err := server.Stop(); err != nil {
t.Fatalf("Stop() failed: %v", err)
}
- if err := os.Remove(path); err != nil {
+ if err := os.RemoveAll(path); err != nil {
t.Fatalf("Remove(%v) failed: %v", path, err)
}
// Check that any directories and files that were created to
// represent the binary objects have been garbage collected.
- if err := os.Remove(root); err != nil {
+ if err := os.RemoveAll(root); err != nil {
t.Fatalf("Remove(%v) failed: %v", root, err)
}
}
@@ -145,7 +158,7 @@
// hierarchy that stores binary objects in the local file system.
func TestHierarchy(t *testing.T) {
for i := 0; i < md5.Size; i++ {
- binary, cleanup := startServer(t, i)
+ binary, _, cleanup := startServer(t, i)
defer cleanup()
// Create up to 4MB of random bytes.
size := testutil.Rand.Intn(1000 * bufferLength)
@@ -188,7 +201,7 @@
// consists of.
func TestMultiPart(t *testing.T) {
for length := 2; length < 5; length++ {
- binary, cleanup := startServer(t, 2)
+ binary, _, cleanup := startServer(t, 2)
defer cleanup()
// Create <length> chunks of up to 4MB of random bytes.
data := make([][]byte, length)
@@ -209,7 +222,6 @@
if err != nil {
t.Fatalf("Stat() failed: %v", err)
}
- h := md5.New()
for i := 0; i < length; i++ {
hpart := md5.New()
output, streamErr, err := invokeDownload(t, binary, int32(i))
@@ -219,7 +231,6 @@
if bytes.Compare(output, data[i]) != 0 {
t.Fatalf("Unexpected output: expected %v, got %v", data[i], output)
}
- h.Write(data[i])
hpart.Write(data[i])
checksum := hex.EncodeToString(hpart.Sum(nil))
if expected, got := checksum, parts[i].Checksum; expected != got {
@@ -240,7 +251,7 @@
// of.
func TestResumption(t *testing.T) {
for length := 2; length < 5; length++ {
- binary, cleanup := startServer(t, 2)
+ binary, _, cleanup := startServer(t, 2)
defer cleanup()
// Create <length> chunks of up to 4MB of random bytes.
data := make([][]byte, length)
@@ -282,9 +293,9 @@
// TestErrors checks that the binary interface correctly reports errors.
func TestErrors(t *testing.T) {
- binary, cleanup := startServer(t, 2)
+ binary, _, cleanup := startServer(t, 2)
defer cleanup()
- length := 2
+ const length = 2
data := make([][]byte, length)
for i := 0; i < length; i++ {
size := testutil.Rand.Intn(1000 * bufferLength)
@@ -320,6 +331,20 @@
if _, streamErr, err := invokeDownload(t, binary, 0); streamErr != nil || err != nil {
t.Fatalf("Download() failed: %v", err)
}
+ // Upload/Download on a part number that's outside the range set forth in
+ // Create should fail.
+ for _, part := range []int32{-1, length} {
+ if _, err := invokeUpload(t, binary, []byte("dummy"), part); err == nil {
+ t.Fatalf("Upload() did not fail when it should have")
+ } else if want := verror.BadArg; !verror.Is(err, want) {
+ t.Fatalf("Unexpected error: %v, expected error id %v", err, want)
+ }
+ if _, _, err := invokeDownload(t, binary, part); err == nil {
+ t.Fatalf("Download() did not fail when it should have")
+ } else if want := verror.BadArg; !verror.Is(err, want) {
+ t.Fatalf("Unexpected error: %v, expected error id %v", err, want)
+ }
+ }
if err := binary.Delete(rt.R().NewContext()); err != nil {
t.Fatalf("Delete() failed: %v", err)
}
diff --git a/services/mgmt/binary/impl/invoker.go b/services/mgmt/binary/impl/invoker.go
index 423f44f..893abcf 100644
--- a/services/mgmt/binary/impl/invoker.go
+++ b/services/mgmt/binary/impl/invoker.go
@@ -4,11 +4,9 @@
// MD5 hash of the suffix and generates the following path in the
// local filesystem: /<root>/<dir_1>/.../<dir_n>/<hash>. The root and
// the directory depth are parameters of the implementation. The
-// contents of the directory include the checksum and data for the
-// object and each of its individual parts:
+// contents of the directory include the checksum and data for each of
+// the individual parts of the binary:
//
-// checksum
-// data
// <part_1>/checksum
// <part_1>/data
// ...
@@ -86,6 +84,7 @@
errNotFound = verror.NoExistf("binary not found")
errInProgress = verror.Internalf("identical upload already in progress")
errInvalidParts = verror.BadArgf("invalid number of binary parts")
+ errInvalidPart = verror.BadArgf("invalid binary part number")
errOperationFailed = verror.Internalf("operation failed")
)
@@ -96,10 +95,8 @@
Size: binary.MissingSize,
}
-// newInvoker is the invoker factory.
-func newInvoker(state *state, suffix string) *invoker {
- // Generate the local filesystem path for the object identified by
- // the object name suffix.
+// dir generates the local filesystem path for the binary identified by suffix.
+func dir(suffix string, state *state) string {
h := md5.New()
h.Write([]byte(suffix))
hash := hex.EncodeToString(h.Sum(nil))
@@ -107,9 +104,13 @@
for j := 0; j < state.depth; j++ {
dir = filepath.Join(dir, hash[j*2:(j+1)*2])
}
- path := filepath.Join(state.root, dir, hash)
+ return filepath.Join(state.root, dir, hash)
+}
+
+// newInvoker is the invoker factory.
+func newInvoker(state *state, suffix string) *invoker {
return &invoker{
- path: path,
+ path: dir(suffix, state),
state: state,
suffix: suffix,
}
@@ -119,18 +120,26 @@
const bufferLength = 4096
-// checksumExists checks whether the given path contains a
-// checksum. The implementation uses the existence of checksum to
-// determine whether the binary (part) identified by the given path
+// checksumExists checks whether the given part path is valid and
+// contains a checksum. The implementation uses the existence of
+// the path dir to determine whether the part is valid, and the
+// existence of checksum to determine whether the binary part
// exists.
-func (i *invoker) checksumExists(path string) error {
+func checksumExists(path string) error {
+ switch _, err := os.Stat(path); {
+ case os.IsNotExist(err):
+ return errInvalidPart
+ case err != nil:
+ vlog.Errorf("Stat(%v) failed: %v", path, err)
+ return errOperationFailed
+ }
checksumFile := filepath.Join(path, checksum)
_, err := os.Stat(checksumFile)
switch {
case os.IsNotExist(err):
return errNotFound
case err != nil:
- vlog.Errorf("Stat(%v) failed: %v", path, err)
+ vlog.Errorf("Stat(%v) failed: %v", checksumFile, err)
return errOperationFailed
default:
return nil
@@ -139,30 +148,38 @@
// generatePartPath generates a path for the given binary part.
func (i *invoker) generatePartPath(part int) string {
- return filepath.Join(i.path, fmt.Sprintf("%d", part))
+ return generatePartPath(i.path, part)
+}
+
+func generatePartPath(dir string, part int) string {
+ return filepath.Join(dir, fmt.Sprintf("%d", part))
}
// getParts returns a collection of paths to the parts of the binary.
-func (i *invoker) getParts() ([]string, error) {
- infos, err := ioutil.ReadDir(i.path)
+func getParts(path string) ([]string, error) {
+ infos, err := ioutil.ReadDir(path)
if err != nil {
- vlog.Errorf("ReadDir(%v) failed: %v", i.path, err)
+ vlog.Errorf("ReadDir(%v) failed: %v", path, err)
return []string{}, errOperationFailed
}
- n := 0
result := make([]string, len(infos))
for _, info := range infos {
if info.IsDir() {
- idx, err := strconv.Atoi(info.Name())
+ partName := info.Name()
+ idx, err := strconv.Atoi(partName)
if err != nil {
- vlog.Errorf("Atoi(%v) failed: %v", info.Name(), err)
+ vlog.Errorf("Atoi(%v) failed: %v", partName, err)
return []string{}, errOperationFailed
}
- result[idx] = filepath.Join(i.path, info.Name())
- n++
+ if idx < 0 || idx >= len(infos) || result[idx] != "" {
+ return []string{}, errOperationFailed
+ }
+ result[idx] = filepath.Join(path, partName)
+ } else {
+ // The only entries should correspond to the part dirs.
+ return []string{}, errOperationFailed
}
}
- result = result[:n]
return result, nil
}
@@ -183,7 +200,7 @@
return errOperationFailed
}
for j := 0; j < int(nparts); j++ {
- partPath, partPerm := filepath.Join(tmpDir, fmt.Sprintf("%d", j)), os.FileMode(0700)
+ partPath, partPerm := generatePartPath(tmpDir, j), os.FileMode(0700)
if err := os.MkdirAll(partPath, partPerm); err != nil {
vlog.Errorf("MkdirAll(%v, %v) failed: %v", partPath, partPerm, err)
if err := os.RemoveAll(tmpDir); err != nil {
@@ -250,13 +267,13 @@
func (i *invoker) Download(context ipc.ServerContext, part int32, stream repository.BinaryServiceDownloadStream) error {
vlog.Infof("%v.Download(%v)", i.suffix, part)
path := i.generatePartPath(int(part))
- if err := i.checksumExists(path); err != nil {
+ if err := checksumExists(path); err != nil {
return err
}
dataPath := filepath.Join(path, data)
file, err := os.Open(dataPath)
if err != nil {
- vlog.Errorf("Open(%v) failed: %v", path, err)
+ vlog.Errorf("Open(%v) failed: %v", dataPath, err)
return errOperationFailed
}
defer file.Close()
@@ -288,7 +305,7 @@
func (i *invoker) Stat(ipc.ServerContext) ([]binary.PartInfo, error) {
vlog.Infof("%v.Stat()", i.suffix)
result := make([]binary.PartInfo, 0)
- parts, err := i.getParts()
+ parts, err := getParts(i.path)
if err != nil {
return []binary.PartInfo{}, err
}
@@ -321,7 +338,7 @@
func (i *invoker) Upload(context ipc.ServerContext, part int32, stream repository.BinaryServiceUploadStream) error {
vlog.Infof("%v.Upload(%v)", i.suffix, part)
path, suffix := i.generatePartPath(int(part)), ""
- err := i.checksumExists(path)
+ err := checksumExists(path)
switch err {
case nil:
return errExists
diff --git a/services/mgmt/binary/impl/merge_file/merge_file.go b/services/mgmt/binary/impl/merge_file/merge_file.go
new file mode 100644
index 0000000..06dd1a2
--- /dev/null
+++ b/services/mgmt/binary/impl/merge_file/merge_file.go
@@ -0,0 +1,184 @@
+// Package merge_file provides an implementation of http.File that merges
+// several files into one logical file.
+package merge_file
+
+// TODO(caprita): rename this package to multipart, and the constructor to
+// NewFile. Usage: f, err := multipart.NewFile(...).
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "time"
+)
+
+var internalErr = fmt.Errorf("internal error")
+
+// NewMergeFile wraps the given parts, in slice order, into one logical
+// http.File called name. The size of each part is sampled once, here,
+// and never refreshed: if a part is modified afterwards, Read and Seek
+// on the merge file will misbehave. Empty parts are rejected.
+func NewMergeFile(name string, parts []*os.File) (http.File, error) {
+	merged := &mergeFile{name: name, parts: make([]filePart, len(parts))}
+	for idx, part := range parts {
+		info, err := part.Stat()
+		if err != nil {
+			return nil, err
+		}
+		n := info.Size()
+		// TODO(caprita): we can relax this restriction later.
+		if n == 0 {
+			return nil, fmt.Errorf("Part is empty")
+		}
+		merged.parts[idx] = filePart{file: part, size: n}
+	}
+	return merged, nil
+}
+
+type filePart struct {
+	file *os.File // handle to the underlying part
+	size int64    // size in bytes, as reported by Stat in NewMergeFile
+}
+
+type mergeFile struct {
+	name       string     // name reported by Stat
+	parts      []filePart // the parts, in read order
+	activePart int        // index into parts of the part being read; Read returns io.EOF once it reaches len(parts)
+	partOffset int64      // read position within parts[activePart]
+}
+
+// currPos returns the current read offset within the merge file as a whole.
+func (m *mergeFile) currPos() (pos int64) {
+	for idx := 0; idx < m.activePart; idx++ {
+		pos += m.parts[idx].size
+	}
+	return pos + m.partOffset
+}
+
+func (m *mergeFile) totalSize() (sum int64) {
+	for idx := range m.parts {
+		sum += m.parts[idx].size
+	}
+	return sum // combined size of all parts, as recorded in m.parts
+}
+
+// Readdir is not supported on a merge file; it always returns an error.
+func (*mergeFile) Readdir(int) ([]os.FileInfo, error) {
+	return nil, fmt.Errorf("Not implemented")
+}
+
+type fileInfo struct {
+	name    string      // value returned by Name
+	size    int64       // value returned by Size
+	mode    os.FileMode // value returned by Mode
+	modTime time.Time   // value returned by ModTime
+}
+
+// Name returns the stored name (mergeFile.Stat fills in the merge file's name).
+func (f *fileInfo) Name() string {
+	return f.name
+}
+
+// Size returns the stored size (mergeFile.Stat fills in the sum of all parts).
+func (f *fileInfo) Size() int64 {
+	return f.size
+}
+
+// Mode returns the stored mode; mergeFile.Stat currently hardcodes it to 0700.
+func (f *fileInfo) Mode() os.FileMode {
+	return f.mode
+}
+
+// ModTime returns the stored time; mergeFile.Stat sets it to the time of the Stat call.
+func (f *fileInfo) ModTime() time.Time {
+	return f.modTime
+}
+
+// IsDir always returns false.
+func (f *fileInfo) IsDir() bool {
+	return false
+}
+
+// Sys always returns nil.
+func (f *fileInfo) Sys() interface{} {
+	return nil
+}
+
+// Stat synthesizes a FileInfo for the merge file: the name given at
+// creation, the combined size of all parts, a fixed mode of 0700 and the
+// time of this call as the modification time. It never fails.
+func (m *mergeFile) Stat() (os.FileInfo, error) {
+	info := &fileInfo{name: m.name, mode: 0700}
+	info.size = m.totalSize()
+	info.modTime = time.Now()
+	return info, nil
+}
+
+// Close closes every part, carrying on after a failure, and reports the
+// error from the last part that failed to close (nil if all closed cleanly).
+func (m *mergeFile) Close() (closeErr error) {
+	for idx := range m.parts {
+		if err := m.parts[idx].file.Close(); err != nil {
+			closeErr = err
+		}
+	}
+	return closeErr
+}
+
+// Read reads from the active part, switching to the next part (rewound to
+// its start) once the recorded size is consumed; io.EOF after the last part.
+func (m *mergeFile) Read(buf []byte) (int, error) {
+	if m.activePart >= len(m.parts) {
+		return 0, io.EOF
+	}
+	p := m.parts[m.activePart]
+	n, err := p.file.Read(buf)
+	m.partOffset += int64(n)
+	if m.partOffset > p.size {
+		// Likely, the file has changed.
+		return 0, internalErr
+	}
+	if m.partOffset == p.size {
+		m.activePart, m.partOffset = m.activePart+1, 0 // zero at EOF too, so currPos stays right
+		if m.activePart < len(m.parts) {
+			if _, err := m.parts[m.activePart].file.Seek(0, 0); err != nil {
+				return 0, err
+			}
+		}
+	}
+	return n, err
+}
+
+// Seek moves the read position to a global offset and returns it. whence 0
+// is relative to the start, 1 to the current position, and 2 counts offset
+// back from the end. Targets outside [0, total size] are rejected.
+func (m *mergeFile) Seek(offset int64, whence int) (int64, error) {
+	var target, c int64 // c accumulates the sizes of the parts that precede target
+	switch whence {
+	case 0:
+		target = offset
+	case 1:
+		target = m.currPos() + offset
+	case 2:
+		target = m.totalSize() - offset // NOTE: subtracts, unlike the io.Seeker convention.
+	default:
+		return 0, fmt.Errorf("invalid whence: %d", whence)
+	}
+	if target < 0 || target > m.totalSize() {
+		return 0, fmt.Errorf("invalid offset")
+	}
+	for i, p := range m.parts {
+		if c+p.size <= target {
+			c += p.size
+			continue
+		}
+		if _, err := p.file.Seek(target-c, 0); err != nil {
+			return 0, err
+		}
+		m.activePart, m.partOffset = i, target-c
+		return target, nil
+	}
+	m.activePart, m.partOffset = len(m.parts), 0 // target == total size: park at EOF
+	return target, nil
+}
diff --git a/services/mgmt/binary/impl/merge_file/merge_file_test.go b/services/mgmt/binary/impl/merge_file/merge_file_test.go
new file mode 100644
index 0000000..1d903a1
--- /dev/null
+++ b/services/mgmt/binary/impl/merge_file/merge_file_test.go
@@ -0,0 +1,147 @@
+package merge_file_test
+
+import (
+ "io"
+ "io/ioutil"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+
+ "veyron.io/veyron/veyron/services/mgmt/binary/impl/merge_file"
+)
+
+// read pulls up to thisMuch bytes out of m and returns them as a string.
+// It keeps calling Read until the request is satisfied or io.EOF is hit,
+// and fails the test on any other error.
+func read(t *testing.T, m http.File, thisMuch int) string {
+	chunk := make([]byte, thisMuch)
+	filled := 0
+	for {
+		n, err := m.Read(chunk[filled:])
+		filled += n
+		// A full buffer wins over whatever error came with the last bytes.
+		if filled == thisMuch || err == io.EOF {
+			return string(chunk[:filled])
+		}
+		if err != nil {
+			t.Fatalf("Read failed: %v", err)
+		}
+	}
+}
+
+// TestMergeFile verifies the http.File operations (Stat, Read, Seek, Close) on the merge file.
+func TestMergeFile(t *testing.T) {
+	contents := []string{"v", "is", "for", "vanadium"} // 1+2+3+8 = 14 bytes in total
+	files := make([]*os.File, len(contents))
+	d, err := ioutil.TempDir("", "merge_files")
+	if err != nil {
+		t.Fatalf("TempDir() failed: %v", err)
+	}
+	defer os.RemoveAll(d)
+	contentsSize := 0
+	for i, c := range contents {
+		contentsSize += len(c)
+		fPath := filepath.Join(d, strconv.Itoa(i))
+		if err := ioutil.WriteFile(fPath, []byte(c), 0600); err != nil {
+			t.Fatalf("WriteFile(%v) failed: %v", fPath, err)
+		}
+		var err error
+		if files[i], err = os.Open(fPath); err != nil {
+			t.Fatalf("Open(%v) failed: %v", fPath, err)
+		}
+	}
+	m, err := merge_file.NewMergeFile("bunnies", files)
+	if err != nil {
+		t.Fatalf("newMergeFile failed: %v", err)
+	}
+	defer func() {
+		if err := m.Close(); err != nil {
+			t.Fatalf("Close failed: %v", err)
+		}
+	}()
+	stat, err := m.Stat()
+	if err != nil {
+		t.Fatalf("Stat failed: %v", err)
+	}
+	if want, got := "bunnies", stat.Name(); want != got {
+		t.Fatalf("Name returned %s, expected %s", got, want)
+	}
+	if want, got := int64(contentsSize), stat.Size(); want != got {
+		t.Fatalf("Size returned %d, expected %d", got, want)
+	}
+	if want, got := strings.Join(contents, ""), read(t, m, 1024); want != got { // one pass over all parts
+		t.Fatalf("Read %v, wanted %v instead", got, want)
+	}
+	if want, got := "", read(t, m, 1024); want != got { // already at EOF
+		t.Fatalf("Read %v, wanted %v instead", got, want)
+	}
+	if pos, err := m.Seek(0, 0); err != nil { // rewind and read everything again
+		t.Fatalf("Seek failed: %v", err)
+	} else if want, got := int64(0), pos; want != got {
+		t.Fatalf("Pos is %d, wanted %d", got, want)
+	}
+	if want, got := strings.Join(contents, ""), read(t, m, 1024); want != got {
+		t.Fatalf("Read %v, wanted %v instead", got, want)
+	}
+	if pos, err := m.Seek(0, 0); err != nil {
+		t.Fatalf("Seek failed: %v", err)
+	} else if want, got := int64(0), pos; want != got {
+		t.Fatalf("Pos is %d, wanted %d", got, want)
+	}
+	for _, c := range contents { // reads sized to land exactly on each part boundary
+		if want, got := c, read(t, m, len(c)); want != got {
+			t.Fatalf("Read %v, wanted %v instead", got, want)
+		}
+	}
+	if want, got := "", read(t, m, 1024); want != got {
+		t.Fatalf("Read %v, wanted %v instead", got, want)
+	}
+	if pos, err := m.Seek(1, 0); err != nil {
+		t.Fatalf("Seek failed: %v", err)
+	} else if want, got := int64(1), pos; want != got {
+		t.Fatalf("Pos is %d, wanted %d", got, want)
+	}
+	if want, got := "isfo", read(t, m, 4); want != got { // spans parts 1 and 2; position is now 5
+		t.Fatalf("Read %v, wanted %v instead", got, want)
+	}
+	if pos, err := m.Seek(2, 1); err != nil { // relative to the current position: 5+2 = 7
+		t.Fatalf("Seek failed: %v", err)
+	} else if want, got := int64(7), pos; want != got {
+		t.Fatalf("Pos is %d, wanted %d", got, want)
+	}
+	if want, got := "anadi", read(t, m, 5); want != got {
+		t.Fatalf("Read %v, wanted %v instead", got, want)
+	}
+	if _, err := m.Seek(100, 1); err == nil { // target lies past the end
+		t.Fatalf("Seek expected to fail")
+	}
+	if want, got := "u", read(t, m, 1); want != got { // the failed Seek left the position at 12
+		t.Fatalf("Read %v, wanted %v instead", got, want)
+	}
+	if pos, err := m.Seek(8, 2); err != nil { // whence 2 counts back from the end: 14-8 = 6
+		t.Fatalf("Seek failed: %v", err)
+	} else if want, got := int64(6), pos; want != got {
+		t.Fatalf("Pos is %d, wanted %d", got, want)
+	}
+	if want, got := "vanad", read(t, m, 5); want != got {
+		t.Fatalf("Read %v, wanted %v instead", got, want)
+	}
+	if _, err := m.Seek(100, 2); err == nil { // target lies before the start
+		t.Fatalf("Seek expected to fail")
+	}
+	if pos, err := m.Seek(9, 2); err != nil { // 14-9 = 5
+		t.Fatalf("Seek failed: %v", err)
+	} else if want, got := int64(5), pos; want != got {
+		t.Fatalf("Pos is %d, wanted %d", got, want)
+	}
+	if want, got := "rvana", read(t, m, 5); want != got {
+		t.Fatalf("Read %v, wanted %v instead", got, want)
+	}
+
+	// TODO(caprita): Add some auto-generated test cases where we seek/read
+	// using various combinations of indices. These can be exhaustive or
+	// randomized, the idea is to get better coverage.
+}
diff --git a/services/mgmt/lib/binary/impl_test.go b/services/mgmt/lib/binary/impl_test.go
index ae29582..8ec796f 100644
--- a/services/mgmt/lib/binary/impl_test.go
+++ b/services/mgmt/lib/binary/impl_test.go
@@ -41,10 +41,11 @@
t.Fatalf("NewServer() failed: %v", err)
}
depth := 2
- dispatcher, err := impl.NewDispatcher(root, depth, nil)
+ state, err := impl.NewState(root, depth)
if err != nil {
- t.Fatalf("NewDispatcher(%v, %v, %v) failed: %v", root, depth, nil, err)
+ t.Fatalf("NewState(%v, %v) failed: %v", root, depth, err)
}
+ dispatcher := impl.NewDispatcher(state, nil)
endpoint, err := server.Listen(profiles.LocalListenSpec)
if err != nil {
t.Fatalf("Listen(%s) failed: %v", profiles.LocalListenSpec, err)
diff --git a/services/mounttable/lib/mounttable.go b/services/mounttable/lib/mounttable.go
index c6e659f..f13dd72 100644
--- a/services/mounttable/lib/mounttable.go
+++ b/services/mounttable/lib/mounttable.go
@@ -275,7 +275,7 @@
// Make sure the server name is reasonable.
epString, _ := naming.SplitAddressName(server)
- ep, err := rt.R().NewEndpoint(epString)
+ _, err := rt.R().NewEndpoint(epString)
if err != nil {
return fmt.Errorf("malformed address %q for mounted server %q", epString, server)
}
@@ -290,7 +290,9 @@
if hasReplaceFlag(flags) {
n.mount = nil
}
- wantMT := hasMTFlag(flags) || ep.ServesMountTable()
+	// TODO(p): Once endpoints actually carry the ServesMountTable bit,
+	// OR this value with ep.ServesMountTable().
+ wantMT := hasMTFlag(flags)
if n.mount == nil {
n.mount = &mount{servers: NewServerList(), mt: wantMT}
} else {
@@ -384,7 +386,11 @@
n.removeUseless()
return
}
- sender.Send(types.MountEntry{Name: name, Servers: m.servers.copyToSlice()})
+ sender.Send(
+ types.MountEntry{
+ Name: name, Servers: m.servers.copyToSlice(),
+ MT: n.mount.mt,
+ })
return
}