wspr: Share the type stream for server messages.

MultiPart: 2/2
Change-Id: I971fe9521b30c4e42cb826da6e53ed2890154fd1
diff --git a/services/wspr/internal/app/app.go b/services/wspr/internal/app/app.go
index f26d773..2c89518 100644
--- a/services/wspr/internal/app/app.go
+++ b/services/wspr/internal/app/app.go
@@ -981,3 +981,9 @@
 func (c *Controller) TypeEncoder() *vom.TypeEncoder {
 	return c.typeEncoder
 }
+
+// TypeDecoder returns the controller's shared vom.TypeDecoder, the decoding
+// counterpart of TypeEncoder.
+func (c *Controller) TypeDecoder() *vom.TypeDecoder {
+	return c.typeDecoder
+}
diff --git a/services/wspr/internal/app/app_test.go b/services/wspr/internal/app/app_test.go
index b71c6ff..b4007ac 100644
--- a/services/wspr/internal/app/app_test.go
+++ b/services/wspr/internal/app/app_test.go
@@ -323,7 +323,6 @@
 	writer           *testwriter.Writer
 	mounttableServer rpc.Server
 	proxyShutdown    func()
-	typeStream       *typeWriter
 	typeEncoder      *vom.TypeEncoder
 }
 
@@ -341,6 +340,17 @@
 	return hex.EncodeToString(buf.Bytes()), nil
 }
 
+// typeEncoderWriter is the io.Writer behind the test's vom.TypeEncoder: it
+// forwards every chunk of the type stream to the controller as a hex string.
+type typeEncoderWriter struct {
+	c *Controller
+}
+
+func (t *typeEncoderWriter) Write(p []byte) (int, error) {
+	t.c.HandleTypeMessage(hex.EncodeToString(p))
+	return len(p), nil
+}
+
 func serveServer(ctx *context.T, writer lib.ClientWriter, setController func(*Controller)) (*runningTest, error) {
 	mounttableServer, endpoint, err := startMountTableServer(ctx)
 	if err != nil {
@@ -377,7 +387,7 @@
 	}
 
 	v23.GetNamespace(controller.Context()).SetRoots("/" + endpoint.String())
-	typeStream := &typeWriter{}
+	typeStream := &typeEncoderWriter{c: controller}
 	typeEncoder := vom.NewTypeEncoder(typeStream)
 	req, err := makeRequest(typeEncoder, RpcRequest{
 		Name:       "__controller",
@@ -387,21 +397,12 @@
 		Deadline:   vdltime.Deadline{},
 	}, "adder", 0, []RpcServerOption{})
 
-	for _, r := range typeStream.resps {
-		b, err := base64.StdEncoding.DecodeString(r.Message.(string))
-		if err != nil {
-			panic(err)
-		}
-
-		controller.HandleTypeMessage(hex.EncodeToString(b))
-	}
-	typeStream.resps = []lib.Response{}
 	controller.HandleVeyronRequest(ctx, 0, req, writer)
 
 	testWriter, _ := writer.(*testwriter.Writer)
 	return &runningTest{
 		controller, testWriter, mounttableServer, proxyShutdown,
-		typeStream, typeEncoder,
+		typeEncoder,
 	}, nil
 }
 
@@ -455,6 +456,8 @@
 	rt, err := serveServer(ctx, mock, func(controller *Controller) {
 		mock.controller = controller
 	})
+
+	mock.typeEncoder = rt.typeEncoder
 	defer rt.mounttableServer.Stop()
 	defer rt.proxyShutdown()
 	defer rt.controller.Cleanup()
diff --git a/services/wspr/internal/app/mock_jsServer_test.go b/services/wspr/internal/app/mock_jsServer_test.go
index 27891d2..196b119 100644
--- a/services/wspr/internal/app/mock_jsServer_test.go
+++ b/services/wspr/internal/app/mock_jsServer_test.go
@@ -48,6 +48,8 @@
 
 	typeReader  *lib.TypeReader
 	typeDecoder *vom.TypeDecoder
+
+	typeEncoder *vom.TypeEncoder
 }
 
 func (m *mockJSServer) Send(responseType lib.ResponseType, msg interface{}) error {
@@ -80,7 +82,7 @@
 	return fmt.Errorf("Unknown message type: %d", responseType)
 }
 
-func internalErr(args interface{}) string {
+func internalErr(args interface{}, typeEncoder *vom.TypeEncoder) string {
 	err := verror.E{
 		ID:        verror.ID("v.io/v23/verror.Internal"),
 		Action:    verror.ActionCode(0),
@@ -89,7 +91,7 @@
 
 	return lib.HexVomEncodeOrDie(server.LookupReply{
 		Err: err,
-	}, nil)
+	}, typeEncoder)
 }
 
 func (m *mockJSServer) Error(err error) {
@@ -124,19 +126,19 @@
 
 	msg, err := normalize(v)
 	if err != nil {
-		m.controller.HandleLookupResponse(m.flowCount, internalErr(err))
+		m.controller.HandleLookupResponse(m.flowCount, internalErr(err, m.typeEncoder))
 		return nil
 	}
 	expected := map[string]interface{}{"serverId": 0.0, "suffix": "adder"}
 	if !reflect.DeepEqual(msg, expected) {
-		m.controller.HandleLookupResponse(m.flowCount, internalErr(fmt.Sprintf("got: %v, want: %v", msg, expected)))
+		m.controller.HandleLookupResponse(m.flowCount, internalErr(fmt.Sprintf("got: %v, want: %v", msg, expected), m.typeEncoder))
 		return nil
 	}
 	lookupReply := lib.HexVomEncodeOrDie(server.LookupReply{
 		Handle:        0,
 		Signature:     m.serviceSignature,
 		HasAuthorizer: m.hasAuthorizer,
-	}, nil)
+	}, m.typeEncoder)
 	m.controller.HandleLookupResponse(m.flowCount, lookupReply)
 	return nil
 }
@@ -152,56 +154,56 @@
 
 	m.hasCalledAuth = true
 	if !m.hasAuthorizer {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr("unexpected auth request"))
+		m.controller.HandleAuthResponse(m.flowCount, internalErr("unexpected auth request", m.typeEncoder))
 		return nil
 	}
 
 	var msg server.AuthRequest
-	if err := lib.HexVomDecode(v.(string), &msg, nil); err != nil {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("error decoding %v:", err)))
+	if err := lib.HexVomDecode(v.(string), &msg, m.typeDecoder); err != nil {
+		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("error decoding: %v", err), m.typeEncoder))
 		return nil
 	}
 
 	if msg.Handle != 0 {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected handled: %v", msg.Handle)))
+		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected handle: %v", msg.Handle), m.typeEncoder))
 		return nil
 	}
 
 	call := msg.Call
 	if field, got, want := "Method", call.Method, lib.LowercaseFirstCharacter(m.method); got != want {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
 		return nil
 	}
 
 	if field, got, want := "Suffix", call.Suffix, "adder"; got != want {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
 		return nil
 	}
 
 	// We expect localBlessings and remoteBlessings to be a non-zero id
 	if call.LocalBlessings == 0 {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad local blessing: %v", call.LocalBlessings)))
+		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad local blessing: %v", call.LocalBlessings), m.typeEncoder))
 		return nil
 	}
 	if call.RemoteBlessings == 0 {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad remote blessing: %v", call.RemoteBlessings)))
+		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad remote blessing: %v", call.RemoteBlessings), m.typeEncoder))
 		return nil
 	}
 
 	// We expect endpoints to be set
 	if !validateEndpoint(call.LocalEndpoint) {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad endpoint:%v", call.LocalEndpoint)))
+		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad endpoint:%v", call.LocalEndpoint), m.typeEncoder))
 		return nil
 	}
 
 	if !validateEndpoint(call.RemoteEndpoint) {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad endpoint:%v", call.RemoteEndpoint)))
+		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad endpoint:%v", call.RemoteEndpoint), m.typeEncoder))
 		return nil
 	}
 
 	authReply := lib.HexVomEncodeOrDie(server.AuthReply{
 		Err: m.authError,
-	}, nil)
+	}, m.typeEncoder)
 
 	m.controller.HandleAuthResponse(m.flowCount, authReply)
 	return nil
@@ -213,23 +215,23 @@
 	}()
 
 	if m.hasCalledAuth != m.hasAuthorizer {
-		m.controller.HandleServerResponse(m.flowCount, internalErr("authorizer hasn't been called yet"))
+		m.controller.HandleServerResponse(m.flowCount, internalErr("authorizer hasn't been called yet", m.typeEncoder))
 		return nil
 	}
 
 	var msg server.ServerRpcRequest
-	if err := lib.HexVomDecode(v.(string), &msg, nil); err != nil {
-		m.controller.HandleServerResponse(m.flowCount, internalErr(err))
+	if err := lib.HexVomDecode(v.(string), &msg, m.typeDecoder); err != nil {
+		m.controller.HandleServerResponse(m.flowCount, internalErr(err, m.typeEncoder))
 		return nil
 	}
 
 	if field, got, want := "Method", msg.Method, lib.LowercaseFirstCharacter(m.method); got != want {
-		m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+		m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
 		return nil
 	}
 
 	if field, got, want := "Handle", msg.Handle, int32(0); got != want {
-		m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+		m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
 		return nil
 	}
 
@@ -240,23 +242,23 @@
 		}
 	}
 	if field, got, want := "Args", vals, m.inArgs; !reflectutil.DeepEqual(got, want, &reflectutil.DeepEqualOpts{SliceEqNilEmpty: true}) {
-		m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+		m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
 		return nil
 	}
 
 	call := msg.Call.SecurityCall
 	if field, got, want := "Suffix", call.Suffix, "adder"; got != want {
-		m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+		m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
 		return nil
 	}
 
 	// We expect localBlessings and remoteBlessings to be a non-zero id
 	if call.LocalBlessings == 0 {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad local blessing: %v", call.LocalBlessings)))
+		m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("bad local blessing: %v", call.LocalBlessings), m.typeEncoder))
 		return nil
 	}
 	if call.RemoteBlessings == 0 {
-		m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad remote blessing: %v", call.RemoteBlessings)))
+		m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("bad remote blessing: %v", call.RemoteBlessings), m.typeEncoder))
 		return nil
 	}
 
@@ -278,14 +280,10 @@
 		Results: make([]error, len(req.Cavs)),
 	}
 
-	var b bytes.Buffer
-	enc := vom.NewEncoder(&b)
-	if err := enc.Encode(resp); err != nil {
-		panic(err)
-	}
+	res := lib.HexVomEncodeOrDie(resp, m.typeEncoder)
 
 	m.controllerReady.RLock()
-	m.controller.HandleCaveatValidationResponse(m.flowCount, fmt.Sprintf("%x", b.Bytes()))
+	m.controller.HandleCaveatValidationResponse(m.flowCount, res)
 	m.controllerReady.RUnlock()
 	return nil
 }
@@ -294,7 +292,7 @@
 	defer m.sender.Done()
 	m.controllerReady.RLock()
 	for _, v := range m.serverStream {
-		m.controller.SendOnStream(m.rpcFlow, lib.HexVomEncodeOrDie(v, nil), m)
+		m.controller.SendOnStream(m.rpcFlow, lib.HexVomEncodeOrDie(v, m.typeEncoder), m)
 	}
 	m.controllerReady.RUnlock()
 }
@@ -322,7 +320,7 @@
 	}
 
 	m.controllerReady.RLock()
-	m.controller.HandleServerResponse(m.rpcFlow, lib.HexVomEncodeOrDie(reply, nil))
+	m.controller.HandleServerResponse(m.rpcFlow, lib.HexVomEncodeOrDie(reply, m.typeEncoder))
 	m.controllerReady.RUnlock()
 	return nil
 }
diff --git a/services/wspr/internal/rpc/server/dispatcher.go b/services/wspr/internal/rpc/server/dispatcher.go
index cb1c457..1b49fea 100644
--- a/services/wspr/internal/rpc/server/dispatcher.go
+++ b/services/wspr/internal/rpc/server/dispatcher.go
@@ -41,19 +41,21 @@
 	invokerFactory     invokerFactory
 	authFactory        authFactory
 	outstandingLookups map[int32]chan LookupReply
+	vomHelper          VomHelper
 	closed             bool
 }
 
 var _ rpc.Dispatcher = (*dispatcher)(nil)
 
 // newDispatcher is a dispatcher factory.
-func newDispatcher(serverId uint32, flowFactory flowFactory, invokerFactory invokerFactory, authFactory authFactory) *dispatcher {
+func newDispatcher(serverId uint32, flowFactory flowFactory, invokerFactory invokerFactory, authFactory authFactory, vomHelper VomHelper) *dispatcher {
 	return &dispatcher{
 		serverId:           serverId,
 		flowFactory:        flowFactory,
 		invokerFactory:     invokerFactory,
 		authFactory:        authFactory,
 		outstandingLookups: make(map[int32]chan LookupReply),
+		vomHelper:          vomHelper,
 	}
 }
 
@@ -128,7 +130,7 @@
 	}
 
 	var lookupReply LookupReply
-	if err := lib.HexVomDecode(data, &lookupReply, nil); err != nil {
+	if err := lib.HexVomDecode(data, &lookupReply, d.vomHelper.TypeDecoder()); err != nil {
 		err2 := verror.Convert(verror.ErrInternal, nil, err)
 		lookupReply = LookupReply{Err: err2}
 		vlog.Errorf("unmarshaling invoke request failed: %v, %s", err, data)
diff --git a/services/wspr/internal/rpc/server/dispatcher_test.go b/services/wspr/internal/rpc/server/dispatcher_test.go
index 0e6615b..b58bc80 100644
--- a/services/wspr/internal/rpc/server/dispatcher_test.go
+++ b/services/wspr/internal/rpc/server/dispatcher_test.go
@@ -15,10 +15,21 @@
 	"v.io/v23/vdl"
 	"v.io/v23/vdlroot/signature"
 	"v.io/v23/verror"
+	"v.io/v23/vom"
 	"v.io/x/ref/services/wspr/internal/lib"
 	"v.io/x/ref/services/wspr/internal/lib/testwriter"
 )
 
+type mockVomHelper struct{}
+
+func (mockVomHelper) TypeEncoder() *vom.TypeEncoder {
+	return nil
+}
+
+func (mockVomHelper) TypeDecoder() *vom.TypeDecoder {
+	return nil
+}
+
 type mockFlowFactory struct {
 	writer testwriter.Writer
 }
@@ -80,7 +91,7 @@
 
 func TestSuccessfulLookup(t *testing.T) {
 	flowFactory := &mockFlowFactory{}
-	d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{})
+	d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{}, mockVomHelper{})
 	expectedSig := []signature.Interface{
 		{Name: "AName"},
 	}
@@ -129,7 +140,7 @@
 
 func TestSuccessfulLookupWithAuthorizer(t *testing.T) {
 	flowFactory := &mockFlowFactory{}
-	d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{})
+	d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{}, mockVomHelper{})
 	expectedSig := []signature.Interface{
 		{Name: "AName"},
 	}
@@ -178,7 +189,7 @@
 
 func TestFailedLookup(t *testing.T) {
 	flowFactory := &mockFlowFactory{}
-	d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{})
+	d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{}, mockVomHelper{})
 	go func() {
 		if err := flowFactory.writer.WaitForMessage(1); err != nil {
 			t.Errorf("failed to get dispatch request %v", err)
diff --git a/services/wspr/internal/rpc/server/server.go b/services/wspr/internal/rpc/server/server.go
index f331764..12a6385 100644
--- a/services/wspr/internal/rpc/server/server.go
+++ b/services/wspr/internal/rpc/server/server.go
@@ -45,15 +45,23 @@
 	GetOrAddBlessingsHandle(blessings security.Blessings) principal.BlessingsHandle
 }
 
+// VomHelper exposes the VOM type encoder and decoder shared with the client.
+// The server uses them to encode the messages it sends to the client and to
+// decode the client's replies.
+type VomHelper interface {
+	TypeEncoder() *vom.TypeEncoder
+
+	TypeDecoder() *vom.TypeDecoder
+}
+
 type ServerHelper interface {
 	FlowHandler
 	HandleStore
+	VomHelper
 
 	SendLogMessage(level lib.LogLevel, msg string) error
 	BlessingsCache() *principal.BlessingsCache
 
-	TypeEncoder() *vom.TypeEncoder
-
 	Context() *context.T
 }
 
@@ -180,7 +188,7 @@
 			Args:     vdlValArgs,
 			Call:     rpcCall,
 		}
-		vomMessage, err := lib.HexVomEncode(message, nil)
+		vomMessage, err := lib.HexVomEncode(message, s.helper.TypeEncoder())
 		if err != nil {
 			return errHandler(err)
 		}
@@ -289,7 +297,7 @@
 			Args:     []*vdl.Value{vdl.ValueOf(pattern)},
 			Call:     rpcCall,
 		}
-		vomMessage, err := lib.HexVomEncode(message, nil)
+		vomMessage, err := lib.HexVomEncode(message, s.helper.TypeEncoder())
 		if err != nil {
 			return errHandler(err)
 		}
@@ -578,7 +586,7 @@
 	}
 	vlog.VI(0).Infof("Sending out auth request for %v, %v", flow.ID, message)
 
-	vomMessage, err := lib.HexVomEncode(message, nil)
+	vomMessage, err := lib.HexVomEncode(message, s.helper.TypeEncoder())
 	if err != nil {
 		replyChan <- verror.Convert(verror.ErrInternal, nil, err)
 	} else if err := flow.Writer.Send(lib.ResponseAuthRequest, vomMessage); err != nil {
@@ -630,7 +638,7 @@
 	defer s.serverStateLock.Unlock()
 
 	if s.dispatcher == nil {
-		s.dispatcher = newDispatcher(s.id, s, s, s)
+		s.dispatcher = newDispatcher(s.id, s, s, s, s.helper)
 	}
 
 	if !s.isListening {
@@ -668,7 +676,7 @@
 
 	// Decode the result and send it through the channel
 	var reply lib.ServerRpcReply
-	if err := lib.HexVomDecode(data, &reply, nil); err != nil {
+	if err := lib.HexVomDecode(data, &reply, s.helper.TypeDecoder()); err != nil {
 		reply.Err = err
 	}
 
@@ -710,7 +718,7 @@
 	}
 	// Decode the result and send it through the channel
 	var reply AuthReply
-	if err := lib.HexVomDecode(data, &reply, nil); err != nil {
+	if err := lib.HexVomDecode(data, &reply, s.helper.TypeDecoder()); err != nil {
 		err = verror.Convert(verror.ErrInternal, nil, err)
 		reply = AuthReply{Err: err}
 	}
@@ -740,7 +748,7 @@
 	}
 
 	var reply CaveatValidationResponse
-	if err := lib.HexVomDecode(data, &reply, nil); err != nil {
+	if err := lib.HexVomDecode(data, &reply, s.helper.TypeDecoder()); err != nil {
 		vlog.Errorf("failed to decode validation response %q: error %v", data, err)
 		ch <- []error{}
 		return