wspr: Share the type stream for server messages.
MultiPart: 2/2
Change-Id: I971fe9521b30c4e42cb826da6e53ed2890154fd1
diff --git a/services/wspr/internal/app/app.go b/services/wspr/internal/app/app.go
index f26d773..2c89518 100644
--- a/services/wspr/internal/app/app.go
+++ b/services/wspr/internal/app/app.go
@@ -981,3 +981,8 @@
func (c *Controller) TypeEncoder() *vom.TypeEncoder {
return c.typeEncoder
}
+
+// TypeDecoder returns the controller's vom type decoder (see TypeEncoder).
+func (c *Controller) TypeDecoder() *vom.TypeDecoder {
+ return c.typeDecoder
+}
diff --git a/services/wspr/internal/app/app_test.go b/services/wspr/internal/app/app_test.go
index b71c6ff..b4007ac 100644
--- a/services/wspr/internal/app/app_test.go
+++ b/services/wspr/internal/app/app_test.go
@@ -323,7 +323,6 @@
writer *testwriter.Writer
mounttableServer rpc.Server
proxyShutdown func()
- typeStream *typeWriter
typeEncoder *vom.TypeEncoder
}
@@ -341,6 +340,15 @@
return hex.EncodeToString(buf.Bytes()), nil
}
+type typeEncoderWriter struct { // io.Writer that forwards encoded type messages to a Controller
+ c *Controller // receives every chunk written, via HandleTypeMessage
+}
+
+func (t *typeEncoderWriter) Write(p []byte) (int, error) {
+ t.c.HandleTypeMessage(hex.EncodeToString(p)) // the controller takes type messages hex-encoded
+ return len(p), nil // never fails; the whole chunk is always consumed
+}
+
func serveServer(ctx *context.T, writer lib.ClientWriter, setController func(*Controller)) (*runningTest, error) {
mounttableServer, endpoint, err := startMountTableServer(ctx)
if err != nil {
@@ -377,7 +385,7 @@
}
v23.GetNamespace(controller.Context()).SetRoots("/" + endpoint.String())
- typeStream := &typeWriter{}
+ typeStream := &typeEncoderWriter{c: controller} // type messages reach the controller as soon as they are encoded
typeEncoder := vom.NewTypeEncoder(typeStream)
req, err := makeRequest(typeEncoder, RpcRequest{
Name: "__controller",
@@ -387,21 +395,12 @@
Deadline: vdltime.Deadline{},
}, "adder", 0, []RpcServerOption{})
- for _, r := range typeStream.resps {
- b, err := base64.StdEncoding.DecodeString(r.Message.(string))
- if err != nil {
- panic(err)
- }
-
- controller.HandleTypeMessage(hex.EncodeToString(b))
- }
- typeStream.resps = []lib.Response{}
controller.HandleVeyronRequest(ctx, 0, req, writer)
testWriter, _ := writer.(*testwriter.Writer)
return &runningTest{
controller, testWriter, mounttableServer, proxyShutdown,
- typeStream, typeEncoder,
+ typeEncoder,
}, nil
}
@@ -455,6 +454,8 @@
rt, err := serveServer(ctx, mock, func(controller *Controller) {
mock.controller = controller
})
+
+ mock.typeEncoder = rt.typeEncoder // mock replies reuse the type stream the request was encoded with
defer rt.mounttableServer.Stop()
defer rt.proxyShutdown()
defer rt.controller.Cleanup()
diff --git a/services/wspr/internal/app/mock_jsServer_test.go b/services/wspr/internal/app/mock_jsServer_test.go
index 27891d2..196b119 100644
--- a/services/wspr/internal/app/mock_jsServer_test.go
+++ b/services/wspr/internal/app/mock_jsServer_test.go
@@ -48,6 +48,8 @@
typeReader *lib.TypeReader
typeDecoder *vom.TypeDecoder
+
+ typeEncoder *vom.TypeEncoder // encodes the replies this mock sends back to the controller
}
func (m *mockJSServer) Send(responseType lib.ResponseType, msg interface{}) error {
@@ -80,7 +82,7 @@
return fmt.Errorf("Unknown message type: %d", responseType)
}
-func internalErr(args interface{}) string {
+func internalErr(args interface{}, typeEncoder *vom.TypeEncoder) string {
err := verror.E{
ID: verror.ID("v.io/v23/verror.Internal"),
Action: verror.ActionCode(0),
@@ -89,7 +91,7 @@
return lib.HexVomEncodeOrDie(server.LookupReply{
Err: err,
- }, nil)
+ }, typeEncoder)
}
func (m *mockJSServer) Error(err error) {
@@ -124,19 +126,19 @@
msg, err := normalize(v)
if err != nil {
- m.controller.HandleLookupResponse(m.flowCount, internalErr(err))
+ m.controller.HandleLookupResponse(m.flowCount, internalErr(err, m.typeEncoder))
return nil
}
expected := map[string]interface{}{"serverId": 0.0, "suffix": "adder"}
if !reflect.DeepEqual(msg, expected) {
- m.controller.HandleLookupResponse(m.flowCount, internalErr(fmt.Sprintf("got: %v, want: %v", msg, expected)))
+ m.controller.HandleLookupResponse(m.flowCount, internalErr(fmt.Sprintf("got: %v, want: %v", msg, expected), m.typeEncoder))
return nil
}
lookupReply := lib.HexVomEncodeOrDie(server.LookupReply{
Handle: 0,
Signature: m.serviceSignature,
HasAuthorizer: m.hasAuthorizer,
- }, nil)
+ }, m.typeEncoder)
m.controller.HandleLookupResponse(m.flowCount, lookupReply)
return nil
}
@@ -152,56 +154,56 @@
m.hasCalledAuth = true
if !m.hasAuthorizer {
- m.controller.HandleAuthResponse(m.flowCount, internalErr("unexpected auth request"))
+ m.controller.HandleAuthResponse(m.flowCount, internalErr("unexpected auth request", m.typeEncoder))
return nil
}
var msg server.AuthRequest
- if err := lib.HexVomDecode(v.(string), &msg, nil); err != nil {
- m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("error decoding %v:", err)))
+ if err := lib.HexVomDecode(v.(string), &msg, m.typeDecoder); err != nil {
+ m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("error decoding: %v", err), m.typeEncoder))
return nil
}
if msg.Handle != 0 {
- m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected handled: %v", msg.Handle)))
+ m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected handle: %v", msg.Handle), m.typeEncoder))
return nil
}
call := msg.Call
if field, got, want := "Method", call.Method, lib.LowercaseFirstCharacter(m.method); got != want {
- m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+ m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
return nil
}
if field, got, want := "Suffix", call.Suffix, "adder"; got != want {
- m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+ m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
return nil
}
// We expect localBlessings and remoteBlessings to be a non-zero id
if call.LocalBlessings == 0 {
- m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad local blessing: %v", call.LocalBlessings)))
+ m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad local blessing: %v", call.LocalBlessings), m.typeEncoder))
return nil
}
if call.RemoteBlessings == 0 {
- m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad remote blessing: %v", call.RemoteBlessings)))
+ m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad remote blessing: %v", call.RemoteBlessings), m.typeEncoder))
return nil
}
// We expect endpoints to be set
if !validateEndpoint(call.LocalEndpoint) {
- m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad endpoint:%v", call.LocalEndpoint)))
+ m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad endpoint:%v", call.LocalEndpoint), m.typeEncoder))
return nil
}
if !validateEndpoint(call.RemoteEndpoint) {
- m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad endpoint:%v", call.RemoteEndpoint)))
+ m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad endpoint:%v", call.RemoteEndpoint), m.typeEncoder))
return nil
}
authReply := lib.HexVomEncodeOrDie(server.AuthReply{
Err: m.authError,
- }, nil)
+ }, m.typeEncoder)
m.controller.HandleAuthResponse(m.flowCount, authReply)
return nil
@@ -213,23 +215,23 @@
}()
if m.hasCalledAuth != m.hasAuthorizer {
- m.controller.HandleServerResponse(m.flowCount, internalErr("authorizer hasn't been called yet"))
+ m.controller.HandleServerResponse(m.flowCount, internalErr("authorizer hasn't been called yet", m.typeEncoder))
return nil
}
var msg server.ServerRpcRequest
- if err := lib.HexVomDecode(v.(string), &msg, nil); err != nil {
- m.controller.HandleServerResponse(m.flowCount, internalErr(err))
+ if err := lib.HexVomDecode(v.(string), &msg, m.typeDecoder); err != nil {
+ m.controller.HandleServerResponse(m.flowCount, internalErr(err, m.typeEncoder))
return nil
}
if field, got, want := "Method", msg.Method, lib.LowercaseFirstCharacter(m.method); got != want {
- m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+ m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
return nil
}
if field, got, want := "Handle", msg.Handle, int32(0); got != want {
- m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+ m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
return nil
}
@@ -240,23 +242,23 @@
}
}
if field, got, want := "Args", vals, m.inArgs; !reflectutil.DeepEqual(got, want, &reflectutil.DeepEqualOpts{SliceEqNilEmpty: true}) {
- m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+ m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
return nil
}
call := msg.Call.SecurityCall
if field, got, want := "Suffix", call.Suffix, "adder"; got != want {
- m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want)))
+ m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("unexpected value for %s: got %v, want %v", field, got, want), m.typeEncoder))
return nil
}
// We expect localBlessings and remoteBlessings to be a non-zero id
if call.LocalBlessings == 0 {
- m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad local blessing: %v", call.LocalBlessings)))
+ m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("bad local blessing: %v", call.LocalBlessings), m.typeEncoder))
return nil
}
if call.RemoteBlessings == 0 {
- m.controller.HandleAuthResponse(m.flowCount, internalErr(fmt.Sprintf("bad remote blessing: %v", call.RemoteBlessings)))
+ m.controller.HandleServerResponse(m.flowCount, internalErr(fmt.Sprintf("bad remote blessing: %v", call.RemoteBlessings), m.typeEncoder))
return nil
}
@@ -278,14 +280,10 @@
Results: make([]error, len(req.Cavs)),
}
- var b bytes.Buffer
- enc := vom.NewEncoder(&b)
- if err := enc.Encode(resp); err != nil {
- panic(err)
- }
+ res := lib.HexVomEncodeOrDie(resp, m.typeEncoder)
m.controllerReady.RLock()
- m.controller.HandleCaveatValidationResponse(m.flowCount, fmt.Sprintf("%x", b.Bytes()))
+ m.controller.HandleCaveatValidationResponse(m.flowCount, res)
m.controllerReady.RUnlock()
return nil
}
@@ -294,7 +292,7 @@
defer m.sender.Done()
m.controllerReady.RLock()
for _, v := range m.serverStream {
- m.controller.SendOnStream(m.rpcFlow, lib.HexVomEncodeOrDie(v, nil), m)
+ m.controller.SendOnStream(m.rpcFlow, lib.HexVomEncodeOrDie(v, m.typeEncoder), m)
}
m.controllerReady.RUnlock()
}
@@ -322,7 +320,7 @@
}
m.controllerReady.RLock()
- m.controller.HandleServerResponse(m.rpcFlow, lib.HexVomEncodeOrDie(reply, nil))
+ m.controller.HandleServerResponse(m.rpcFlow, lib.HexVomEncodeOrDie(reply, m.typeEncoder))
m.controllerReady.RUnlock()
return nil
}
diff --git a/services/wspr/internal/rpc/server/dispatcher.go b/services/wspr/internal/rpc/server/dispatcher.go
index cb1c457..1b49fea 100644
--- a/services/wspr/internal/rpc/server/dispatcher.go
+++ b/services/wspr/internal/rpc/server/dispatcher.go
@@ -41,19 +41,21 @@
invokerFactory invokerFactory
authFactory authFactory
outstandingLookups map[int32]chan LookupReply
+ vomHelper VomHelper // supplies the TypeDecoder used to decode lookup replies
closed bool
}
var _ rpc.Dispatcher = (*dispatcher)(nil)
// newDispatcher is a dispatcher factory.
-func newDispatcher(serverId uint32, flowFactory flowFactory, invokerFactory invokerFactory, authFactory authFactory) *dispatcher {
+func newDispatcher(serverId uint32, flowFactory flowFactory, invokerFactory invokerFactory, authFactory authFactory, vomHelper VomHelper) *dispatcher {
return &dispatcher{
serverId: serverId,
flowFactory: flowFactory,
invokerFactory: invokerFactory,
authFactory: authFactory,
outstandingLookups: make(map[int32]chan LookupReply),
+ vomHelper: vomHelper,
}
}
@@ -128,7 +130,7 @@
}
var lookupReply LookupReply
- if err := lib.HexVomDecode(data, &lookupReply, nil); err != nil {
+ if err := lib.HexVomDecode(data, &lookupReply, d.vomHelper.TypeDecoder()); err != nil {
err2 := verror.Convert(verror.ErrInternal, nil, err)
lookupReply = LookupReply{Err: err2}
vlog.Errorf("unmarshaling invoke request failed: %v, %s", err, data)
diff --git a/services/wspr/internal/rpc/server/dispatcher_test.go b/services/wspr/internal/rpc/server/dispatcher_test.go
index 0e6615b..b58bc80 100644
--- a/services/wspr/internal/rpc/server/dispatcher_test.go
+++ b/services/wspr/internal/rpc/server/dispatcher_test.go
@@ -15,10 +15,21 @@
"v.io/v23/vdl"
"v.io/v23/vdlroot/signature"
"v.io/v23/verror"
+ "v.io/v23/vom"
"v.io/x/ref/services/wspr/internal/lib"
"v.io/x/ref/services/wspr/internal/lib/testwriter"
)
+type mockVomHelper struct{} // VomHelper returning nil coders, i.e. the nil the dispatcher passed to HexVomDecode before VomHelper existed
+
+func (mockVomHelper) TypeEncoder() *vom.TypeEncoder {
+ return nil
+}
+
+func (mockVomHelper) TypeDecoder() *vom.TypeDecoder {
+ return nil
+}
+
type mockFlowFactory struct {
writer testwriter.Writer
}
@@ -80,7 +91,7 @@
func TestSuccessfulLookup(t *testing.T) {
flowFactory := &mockFlowFactory{}
- d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{})
+ d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{}, mockVomHelper{})
expectedSig := []signature.Interface{
{Name: "AName"},
}
@@ -129,7 +140,7 @@
func TestSuccessfulLookupWithAuthorizer(t *testing.T) {
flowFactory := &mockFlowFactory{}
- d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{})
+ d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{}, mockVomHelper{})
expectedSig := []signature.Interface{
{Name: "AName"},
}
@@ -178,7 +189,7 @@
func TestFailedLookup(t *testing.T) {
flowFactory := &mockFlowFactory{}
- d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{})
+ d := newDispatcher(0, flowFactory, mockInvokerFactory{}, mockAuthorizerFactory{}, mockVomHelper{})
go func() {
if err := flowFactory.writer.WaitForMessage(1); err != nil {
t.Errorf("failed to get dispatch request %v", err)
diff --git a/services/wspr/internal/rpc/server/server.go b/services/wspr/internal/rpc/server/server.go
index f331764..12a6385 100644
--- a/services/wspr/internal/rpc/server/server.go
+++ b/services/wspr/internal/rpc/server/server.go
@@ -45,15 +45,22 @@
GetOrAddBlessingsHandle(blessings security.Blessings) principal.BlessingsHandle
}
+// VomHelper provides the vom type encoder and decoder used for server
+// messages: outgoing messages are encoded with TypeEncoder, and lookup, auth,
+// RPC and caveat-validation replies are decoded with TypeDecoder.
+type VomHelper interface {
+ TypeEncoder() *vom.TypeEncoder
+
+ TypeDecoder() *vom.TypeDecoder
+}
type ServerHelper interface {
FlowHandler
HandleStore
+ VomHelper
SendLogMessage(level lib.LogLevel, msg string) error
BlessingsCache() *principal.BlessingsCache
- TypeEncoder() *vom.TypeEncoder
-
Context() *context.T
}
@@ -180,7 +187,7 @@
Args: vdlValArgs,
Call: rpcCall,
}
- vomMessage, err := lib.HexVomEncode(message, nil)
+ vomMessage, err := lib.HexVomEncode(message, s.helper.TypeEncoder())
if err != nil {
return errHandler(err)
}
@@ -289,7 +296,7 @@
Args: []*vdl.Value{vdl.ValueOf(pattern)},
Call: rpcCall,
}
- vomMessage, err := lib.HexVomEncode(message, nil)
+ vomMessage, err := lib.HexVomEncode(message, s.helper.TypeEncoder())
if err != nil {
return errHandler(err)
}
@@ -578,7 +585,7 @@
}
vlog.VI(0).Infof("Sending out auth request for %v, %v", flow.ID, message)
- vomMessage, err := lib.HexVomEncode(message, nil)
+ vomMessage, err := lib.HexVomEncode(message, s.helper.TypeEncoder())
if err != nil {
replyChan <- verror.Convert(verror.ErrInternal, nil, err)
} else if err := flow.Writer.Send(lib.ResponseAuthRequest, vomMessage); err != nil {
@@ -630,7 +637,7 @@
defer s.serverStateLock.Unlock()
if s.dispatcher == nil {
- s.dispatcher = newDispatcher(s.id, s, s, s)
+ s.dispatcher = newDispatcher(s.id, s, s, s, s.helper)
}
if !s.isListening {
@@ -668,7 +675,7 @@
// Decode the result and send it through the channel
var reply lib.ServerRpcReply
- if err := lib.HexVomDecode(data, &reply, nil); err != nil {
+ if err := lib.HexVomDecode(data, &reply, s.helper.TypeDecoder()); err != nil {
reply.Err = err
}
@@ -710,7 +717,7 @@
}
// Decode the result and send it through the channel
var reply AuthReply
- if err := lib.HexVomDecode(data, &reply, nil); err != nil {
+ if err := lib.HexVomDecode(data, &reply, s.helper.TypeDecoder()); err != nil {
err = verror.Convert(verror.ErrInternal, nil, err)
reply = AuthReply{Err: err}
}
@@ -740,7 +747,7 @@
}
var reply CaveatValidationResponse
- if err := lib.HexVomDecode(data, &reply, nil); err != nil {
+ if err := lib.HexVomDecode(data, &reply, s.helper.TypeDecoder()); err != nil {
vlog.Errorf("failed to decode validation response %q: error %v", data, err)
ch <- []error{}
return