wspr: Use a type flow to pass type messages between wspr and js.
This change doesn't make any new code paths re-use the type info across
requests. That will happen in subsequent changes.
MultiPart: 2/2
Change-Id: I5e540c2f66b421bca6e8e27317748c0160404058
diff --git a/services/wspr/internal/app/app.go b/services/wspr/internal/app/app.go
index 0aa6f44..565025b 100644
--- a/services/wspr/internal/app/app.go
+++ b/services/wspr/internal/app/app.go
@@ -7,6 +7,8 @@
package app
import (
+ "bytes"
+ "encoding/hex"
"fmt"
"io"
"reflect"
@@ -35,6 +37,8 @@
const (
// pkgPath is the prefix os errors in this package.
pkgPath = "v.io/x/ref/services/wspr/internal/app"
+
+ typeFlow = 0
)
// Errors
@@ -101,13 +105,10 @@
// an outgoing rpc call.
reservedServices map[string]rpc.Invoker
- encoderLock sync.Mutex
- clientEncoder *vom.Encoder
- clientWriter *lib.ProxyWriter
+ typeEncoder *vom.TypeEncoder
- decoderLock sync.Mutex
- clientDecoder *vom.Decoder
- clientReader *lib.ProxyReader
+ typeDecoder *vom.TypeDecoder
+ typeReader *lib.TypeReader
}
var _ ControllerServerMethods = (*Controller)(nil)
@@ -219,13 +220,13 @@
OutArgs: results,
TraceResponse: vtrace.GetResponse(ctx),
}
- c.encoderLock.Lock()
- defer c.encoderLock.Unlock()
- if err := c.clientEncoder.Encode(response); err != nil {
+ var buf bytes.Buffer
+ encoder := vom.NewEncoderWithTypeEncoder(&buf, c.typeEncoder)
+ if err := encoder.Encode(response); err != nil {
w.Error(err)
return
}
- encoded := c.clientWriter.ConsumeBuffer()
+ encoded := hex.EncodeToString(buf.Bytes())
if err := w.Send(lib.ResponseFinal, encoded); err != nil {
w.Error(verror.Convert(marshallingError, ctx, err))
}
@@ -363,6 +364,7 @@
server.Stop()
}
+ c.typeReader.Close()
c.cancel()
}
@@ -371,10 +373,10 @@
c.outstandingRequests = make(map[int32]*outstandingRequest)
c.flowMap = make(map[int32]interface{})
c.servers = make(map[uint32]*server.Server)
- c.clientReader = lib.NewProxyReader()
- c.clientDecoder = vom.NewDecoder(c.clientReader)
- c.clientWriter = lib.NewProxyWriter()
- c.clientEncoder = vom.NewEncoder(c.clientWriter)
+ c.typeReader = lib.NewTypeReader()
+ c.typeDecoder = vom.NewTypeDecoder(c.typeReader)
+ c.typeEncoder = vom.NewTypeEncoder(lib.NewTypeWriter(c.writerCreator(typeFlow)))
+ c.lastGeneratedId += 2
}
// SendOnStream writes data on id's stream. The actual network write will be
@@ -484,42 +486,40 @@
func (l *localCall) RemoteEndpoint() naming.Endpoint { return nil }
func (l *localCall) Security() security.Call { return l }
-func (c *Controller) handleInternalCall(ctx *context.T, invoker rpc.Invoker, msg *RpcRequest, w lib.ClientWriter, span vtrace.Span) {
+func (c *Controller) handleInternalCall(ctx *context.T, invoker rpc.Invoker, msg *RpcRequest, w lib.ClientWriter, span vtrace.Span, decoder *vom.Decoder) {
argptrs, tags, err := invoker.Prepare(msg.Method, int(msg.NumInArgs))
if err != nil {
w.Error(verror.Convert(verror.ErrInternal, ctx, err))
return
}
for _, argptr := range argptrs {
- if err := c.clientDecoder.Decode(argptr); err != nil {
+ if err := decoder.Decode(argptr); err != nil {
w.Error(verror.Convert(verror.ErrInternal, ctx, err))
return
}
}
- go func() {
- results, err := invoker.Invoke(ctx, &localCall{ctx, msg, tags, w}, msg.Method, argptrs)
+ results, err := invoker.Invoke(ctx, &localCall{ctx, msg, tags, w}, msg.Method, argptrs)
+ if err != nil {
+ w.Error(verror.Convert(verror.ErrInternal, ctx, err))
+ return
+ }
+ if msg.IsStreaming {
+ if err := w.Send(lib.ResponseStreamClose, nil); err != nil {
+ w.Error(verror.New(marshallingError, ctx, "ResponseStreamClose"))
+ }
+ }
+
+ // Convert results from []interface{} to []*vdl.Value.
+ vresults := make([]*vdl.Value, len(results))
+ for i, res := range results {
+ vv, err := vdl.ValueFromReflect(reflect.ValueOf(res))
if err != nil {
w.Error(verror.Convert(verror.ErrInternal, ctx, err))
return
}
- if msg.IsStreaming {
- if err := w.Send(lib.ResponseStreamClose, nil); err != nil {
- w.Error(verror.New(marshallingError, ctx, "ResponseStreamClose"))
- }
- }
-
- // Convert results from []interface{} to []*vdl.Value.
- vresults := make([]*vdl.Value, len(results))
- for i, res := range results {
- vv, err := vdl.ValueFromReflect(reflect.ValueOf(res))
- if err != nil {
- w.Error(verror.Convert(verror.ErrInternal, ctx, err))
- return
- }
- vresults[i] = vv
- }
- c.sendRPCResponse(ctx, w, span, vresults)
- }()
+ vresults[i] = vv
+ }
+ c.sendRPCResponse(ctx, w, span, vresults)
}
// HandleCaveatValidationResponse handles the response to caveat validation
@@ -537,15 +537,18 @@
// HandleVeyronRequest starts a vanadium rpc and returns before the rpc has been completed.
func (c *Controller) HandleVeyronRequest(ctx *context.T, id int32, data string, w lib.ClientWriter) {
- c.decoderLock.Lock()
- defer c.decoderLock.Unlock()
- err := c.clientReader.ReplaceBuffer(data)
+ binBytes, err := hex.DecodeString(data)
+ if err != nil {
+ w.Error(verror.Convert(verror.ErrInternal, ctx, fmt.Errorf("Error decoding hex string %q: %v", data, err)))
+ return
+ }
+ decoder := vom.NewDecoderWithTypeDecoder(bytes.NewReader(binBytes), c.typeDecoder)
if err != nil {
w.Error(verror.Convert(verror.ErrInternal, ctx, fmt.Errorf("Error decoding hex string %q: %v", data, err)))
return
}
var msg RpcRequest
- if err := c.clientDecoder.Decode(&msg); err != nil {
+ if err := decoder.Decode(&msg); err != nil {
w.Error(verror.Convert(verror.ErrInternal, ctx, err))
return
}
@@ -568,14 +571,14 @@
// If this message is for an internal service, do a short-circuit dispatch here.
if invoker, ok := c.reservedServices[msg.Name]; ok {
- c.handleInternalCall(ctx, invoker, &msg, w, span)
+ go c.handleInternalCall(ctx, invoker, &msg, w, span, decoder)
return
}
inArgs := make([]interface{}, msg.NumInArgs)
for i := range inArgs {
var v *vdl.Value
- if err := c.clientDecoder.Decode(&v); err != nil {
+ if err := decoder.Decode(&v); err != nil {
w.Error(err)
return
}
@@ -935,6 +938,10 @@
granterStr.Send(data)
}
+func (c *Controller) HandleTypeMessage(data string) {
+ c.typeReader.Add(data)
+}
+
func (c *Controller) BlessingsDebugString(_ *context.T, _ rpc.ServerCall, handle principal.BlessingsHandle) (string, error) {
var inputBlessings security.Blessings
if inputBlessings = c.GetBlessings(handle); inputBlessings.IsZero() {
diff --git a/services/wspr/internal/app/app_test.go b/services/wspr/internal/app/app_test.go
index 4550e58..d7e7333 100644
--- a/services/wspr/internal/app/app_test.go
+++ b/services/wspr/internal/app/app_test.go
@@ -6,6 +6,7 @@
import (
"bytes"
+ "encoding/base64"
"encoding/hex"
"fmt"
"reflect"
@@ -129,6 +130,11 @@
return startAnyServer(ctx, true, mt)
}
+func createWriterCreator(w lib.ClientWriter) func(id int32) lib.ClientWriter {
+ return func(int32) lib.ClientWriter {
+ return w
+ }
+}
func TestGetGoServerSignature(t *testing.T) {
ctx, shutdown := test.V23Init()
defer shutdown()
@@ -143,7 +149,8 @@
spec := v23.GetListenSpec(ctx)
spec.Proxy = "mockVeyronProxyEP"
- controller, err := NewController(ctx, nil, &spec, nil, newBlessedPrincipal(ctx))
+ writer := &testwriter.Writer{}
+ controller, err := NewController(ctx, createWriterCreator(writer), &spec, nil, newBlessedPrincipal(ctx))
if err != nil {
t.Fatalf("Failed to create controller: %v", err)
@@ -164,12 +171,13 @@
}
type goServerTestCase struct {
- method string
- inArgs []interface{}
- numOutArgs int32
- streamingInputs []interface{}
- expectedStream []lib.Response
- expectedError error
+ expectedTypeStream []lib.Response
+ method string
+ inArgs []interface{}
+ numOutArgs int32
+ streamingInputs []interface{}
+ expectedStream []lib.Response
+ expectedError error
}
func runGoServerTestCase(t *testing.T, testCase goServerTestCase) {
@@ -186,14 +194,14 @@
spec := v23.GetListenSpec(ctx)
spec.Proxy = "mockVeyronProxyEP"
- controller, err := NewController(ctx, nil, &spec, nil, newBlessedPrincipal(ctx))
+ writer := testwriter.Writer{}
+ controller, err := NewController(ctx, createWriterCreator(&writer), &spec, nil, newBlessedPrincipal(ctx))
if err != nil {
t.Errorf("unable to create controller: %v", err)
t.Fail()
return
}
- writer := testwriter.Writer{}
var stream *outstandingStream
if len(testCase.streamingInputs) > 0 {
stream = newStream(nil)
@@ -217,26 +225,48 @@
}
controller.sendVeyronRequest(ctx, 0, &request, testCase.inArgs, &writer, stream, vtrace.GetSpan(ctx))
- if err := testwriter.CheckResponses(&writer, testCase.expectedStream, testCase.expectedError); err != nil {
+ if err := testwriter.CheckResponses(&writer, testCase.expectedStream, testCase.expectedTypeStream, testCase.expectedError); err != nil {
t.Error(err)
}
}
-func makeRPCResponse(outArgs ...*vdl.Value) string {
- return lib.HexVomEncodeOrDie(RpcResponse{
+type typeWriter struct {
+ resps []lib.Response
+}
+
+func (t *typeWriter) Write(p []byte) (int, error) {
+ t.resps = append(t.resps, lib.Response{
+ Type: lib.ResponseTypeMessage,
+ Message: base64.StdEncoding.EncodeToString(p),
+ })
+ return len(p), nil
+}
+
+func makeRPCResponse(outArgs ...*vdl.Value) (string, []lib.Response) {
+ writer := typeWriter{}
+ typeEncoder := vom.NewTypeEncoder(&writer)
+ var buf bytes.Buffer
+ encoder := vom.NewEncoderWithTypeEncoder(&buf, typeEncoder)
+ var output = RpcResponse{
OutArgs: outArgs,
TraceResponse: vtrace.Response{},
- })
+ }
+ if err := encoder.Encode(output); err != nil {
+ panic(err)
+ }
+ return hex.EncodeToString(buf.Bytes()), writer.resps
}
func TestCallingGoServer(t *testing.T) {
+ resp, typeMessages := makeRPCResponse(vdl.Int32Value(5))
runGoServerTestCase(t, goServerTestCase{
- method: "Add",
- inArgs: []interface{}{2, 3},
- numOutArgs: 1,
+ expectedTypeStream: typeMessages,
+ method: "Add",
+ inArgs: []interface{}{2, 3},
+ numOutArgs: 1,
expectedStream: []lib.Response{
lib.Response{
- Message: makeRPCResponse(vdl.Int32Value(5)),
+ Message: resp,
Type: lib.ResponseFinal,
},
},
@@ -253,10 +283,12 @@
}
func TestCallingGoWithStreaming(t *testing.T) {
+ resp, typeMessages := makeRPCResponse(vdl.Int32Value(10))
runGoServerTestCase(t, goServerTestCase{
- method: "StreamingAdd",
- streamingInputs: []interface{}{1, 2, 3, 4},
- numOutArgs: 1,
+ expectedTypeStream: typeMessages,
+ method: "StreamingAdd",
+ streamingInputs: []interface{}{1, 2, 3, 4},
+ numOutArgs: 1,
expectedStream: []lib.Response{
lib.Response{
Message: lib.HexVomEncodeOrDie(int32(1)),
@@ -279,7 +311,7 @@
Type: lib.ResponseStreamClose,
},
lib.Response{
- Message: makeRPCResponse(vdl.Int32Value(10)),
+ Message: resp,
Type: lib.ResponseFinal,
},
},
@@ -291,11 +323,13 @@
writer *testwriter.Writer
mounttableServer rpc.Server
proxyShutdown func()
+ typeStream *typeWriter
+ typeEncoder *vom.TypeEncoder
}
-func makeRequest(rpc RpcRequest, args ...interface{}) (string, error) {
+func makeRequest(typeEncoder *vom.TypeEncoder, rpc RpcRequest, args ...interface{}) (string, error) {
var buf bytes.Buffer
- encoder := vom.NewEncoder(&buf)
+ encoder := vom.NewEncoderWithTypeEncoder(&buf, typeEncoder)
if err := encoder.Encode(rpc); err != nil {
return "", err
}
@@ -343,19 +377,31 @@
}
v23.GetNamespace(controller.Context()).SetRoots("/" + endpoint.String())
-
- req, err := makeRequest(RpcRequest{
+ typeStream := &typeWriter{}
+ typeEncoder := vom.NewTypeEncoder(typeStream)
+ req, err := makeRequest(typeEncoder, RpcRequest{
Name: "__controller",
Method: "Serve",
NumInArgs: 3,
NumOutArgs: 1,
Deadline: vdltime.Deadline{},
}, "adder", 0, []RpcServerOption{})
+
+ for _, r := range typeStream.resps {
+ b, err := base64.StdEncoding.DecodeString(r.Message.(string))
+ if err != nil {
+ panic(err)
+ }
+
+ controller.HandleTypeMessage(hex.EncodeToString(b))
+ }
+ typeStream.resps = []lib.Response{}
controller.HandleVeyronRequest(ctx, 0, req, writer)
testWriter, _ := writer.(*testwriter.Writer)
return &runningTest{
controller, testWriter, mounttableServer, proxyShutdown,
+ typeStream, typeEncoder,
}, nil
}
@@ -402,7 +448,10 @@
finalResponse: testCase.finalResponse,
finalError: testCase.err,
controllerReady: sync.RWMutex{},
+ flowCount: 2,
+ typeReader: lib.NewTypeReader(),
}
+ mock.typeDecoder = vom.NewTypeDecoder(mock.typeReader)
rt, err := serveServer(ctx, mock, func(controller *Controller) {
mock.controller = controller
})
diff --git a/services/wspr/internal/app/messaging.go b/services/wspr/internal/app/messaging.go
index 315fa04..aa9c05a 100644
--- a/services/wspr/internal/app/messaging.go
+++ b/services/wspr/internal/app/messaging.go
@@ -81,6 +81,8 @@
// A response to a granter request.
GranterResponseMessage = 22
+
+ TypeMessage = 23
)
type Message struct {
@@ -123,28 +125,29 @@
go c.HandleCaveatValidationResponse(msg.Id, msg.Data)
case GranterResponseMessage:
go c.HandleGranterResponse(msg.Id, msg.Data)
-
+ case TypeMessage:
+ // These messages need to be handled in order so they are done in line.
+ c.HandleTypeMessage(msg.Data)
default:
w.Error(verror.New(errUnknownMessageType, ctx, msg.Type))
}
}
// ConstructOutgoingMessage constructs a message to send to javascript in a consistent format.
-// TODO(bprosnitz) Don't double-encode
func ConstructOutgoingMessage(messageId int32, messageType lib.ResponseType, data interface{}) (string, error) {
var buf bytes.Buffer
+ if _, err := buf.Write(lib.BinaryEncodeUint(uint64(messageId))); err != nil {
+ return "", err
+ }
+ if _, err := buf.Write(lib.BinaryEncodeUint(uint64(messageType))); err != nil {
+ return "", err
+ }
enc := vom.NewEncoder(&buf)
- if err := enc.Encode(lib.Response{Type: messageType, Message: data}); err != nil {
+ if err := enc.Encode(data); err != nil {
return "", err
}
- var buf2 bytes.Buffer
- enc2 := vom.NewEncoder(&buf2)
- if err := enc2.Encode(Message{Id: messageId, Data: fmt.Sprintf("%x", buf.Bytes())}); err != nil {
- return "", err
- }
-
- return fmt.Sprintf("%x", buf2.Bytes()), nil
+ return fmt.Sprintf("%x", buf.Bytes()), nil
}
// FormatAsVerror formats an error as a verror.
diff --git a/services/wspr/internal/app/mock_jsServer_test.go b/services/wspr/internal/app/mock_jsServer_test.go
index 3c2fca4..b0f89c8 100644
--- a/services/wspr/internal/app/mock_jsServer_test.go
+++ b/services/wspr/internal/app/mock_jsServer_test.go
@@ -6,6 +6,7 @@
import (
"bytes"
+ "encoding/hex"
"encoding/json"
"fmt"
"reflect"
@@ -44,6 +45,9 @@
// at the same time.
flowCount int32
rpcFlow int32
+
+ typeReader *lib.TypeReader
+ typeDecoder *vom.TypeDecoder
}
func (m *mockJSServer) Send(responseType lib.ResponseType, msg interface{}) error {
@@ -69,6 +73,9 @@
case lib.ResponseLog, lib.ResponseBlessingsCacheMessage:
m.flowCount += 2
return nil
+ case lib.ResponseTypeMessage:
+ m.handleTypeMessage(msg)
+ return nil
}
return fmt.Errorf("Unknown message type: %d", responseType)
}
@@ -105,6 +112,9 @@
return r.(map[string]interface{}), nil
}
+func (m *mockJSServer) handleTypeMessage(v interface{}) {
+ m.typeReader.Add(hex.EncodeToString(v.([]byte)))
+}
func (m *mockJSServer) handleDispatcherLookup(v interface{}) error {
defer func() {
m.flowCount += 2
diff --git a/services/wspr/internal/browspr/browspr_test.go b/services/wspr/internal/browspr/browspr_test.go
index 03c108f..956fe54 100644
--- a/services/wspr/internal/browspr/browspr_test.go
+++ b/services/wspr/internal/browspr/browspr_test.go
@@ -78,6 +78,26 @@
return s, endpoints[0], nil
}
+func parseBrowsperResponse(data string, t *testing.T) (uint64, uint64, []byte) {
+ receivedBytes, err := hex.DecodeString(data)
+ if err != nil {
+ t.Fatalf("Failed to hex decode outgoing message: %v", err)
+ }
+
+ id, bytesRead, err := lib.BinaryDecodeUint(receivedBytes)
+ if err != nil {
+ t.Fatalf("Failed to read mesage id: %v", err)
+ }
+
+ receivedBytes = receivedBytes[bytesRead:]
+ messageType, bytesRead, err := lib.BinaryDecodeUint(receivedBytes)
+ if err != nil {
+ t.Fatalf("Failed to read message type: %v", err)
+ }
+ receivedBytes = receivedBytes[bytesRead:]
+ return id, messageType, receivedBytes
+}
+
func TestBrowspr(t *testing.T) {
ctx, shutdown := test.V23Init()
defer shutdown()
@@ -154,12 +174,26 @@
receivedResponse := make(chan bool, 1)
var receivedInstanceId int32
var receivedType string
- var receivedMsg string
+ var messageId uint64
+ var messageType uint64
+ var receivedBytes []byte
+ typeReader := lib.NewTypeReader()
var postMessageHandler = func(instanceId int32, ty, msg string) {
+ id, mType, bin := parseBrowsperResponse(msg, t)
+ if mType == lib.ResponseTypeMessage {
+ var decodedBytes []byte
+ if err := vom.Decode(bin, &decodedBytes); err != nil {
+ t.Fatalf("Failed to decode type bytes: %v", err)
+ }
+ typeReader.Add(hex.EncodeToString(decodedBytes))
+ return
+ }
receivedInstanceId = instanceId
receivedType = ty
- receivedMsg = msg
+ messageType = mType
+ messageId = id
+ receivedBytes = bin
receivedResponse <- true
}
@@ -194,14 +228,17 @@
Deadline: vdltime.Deadline{},
}
+ var typeBuf bytes.Buffer
+ typeEncoder := vom.NewTypeEncoder(&typeBuf)
var buf bytes.Buffer
- encoder := vom.NewEncoder(&buf)
+ encoder := vom.NewEncoderWithTypeEncoder(&buf, typeEncoder)
if err := encoder.Encode(rpc); err != nil {
t.Fatalf("Failed to vom encode rpc message: %v", err)
}
if err := encoder.Encode("InputValue"); err != nil {
t.Fatalf("Failed to vom encode rpc message: %v", err)
}
+
vomRPC := hex.EncodeToString(buf.Bytes())
msg := app.Message{
@@ -210,6 +247,11 @@
Type: app.VeyronRequestMessage,
}
+ typeMessage := app.Message{
+ Id: 0,
+ Data: hex.EncodeToString(typeBuf.Bytes()),
+ Type: app.TypeMessage,
+ }
createInstanceMessage := CreateInstanceMessage{
InstanceId: msgInstanceId,
Origin: msgOrigin,
@@ -218,6 +260,11 @@
}
_, err = browspr.HandleCreateInstanceRpc(vdl.ValueOf(createInstanceMessage))
+ err = browspr.HandleMessage(msgInstanceId, msgOrigin, typeMessage)
+ if err != nil {
+ t.Fatalf("Error while handling type message: %v", err)
+ }
+
err = browspr.HandleMessage(msgInstanceId, msgOrigin, msg)
if err != nil {
t.Fatalf("Error while handling message: %v", err)
@@ -232,32 +279,28 @@
t.Errorf("Received unexpected response type. Expected: %q, but got %q", "browsprMsg", receivedType)
}
- var outMsg app.Message
- if err := lib.HexVomDecode(receivedMsg, &outMsg); err != nil {
- t.Fatalf("Failed to unmarshall outgoing message: %v", err)
- }
- if outMsg.Id != int32(1) {
- t.Errorf("Id was %v, expected %v", outMsg.Id, int32(1))
- }
- if outMsg.Type != app.VeyronRequestMessage {
- t.Errorf("Message type was %v, expected %v", outMsg.Type, app.MessageType(0))
+ if messageId != 1 {
+ t.Errorf("Id was %v, expected %v", messageId, int32(1))
}
- var responseMsg lib.Response
- if err := lib.HexVomDecode(outMsg.Data, &responseMsg); err != nil {
+ if lib.ResponseType(messageType) != lib.ResponseFinal {
+ t.Errorf("Message type was %v, expected %v", messageType, lib.ResponseFinal)
+ }
+
+ var data string
+ if err := vom.NewDecoder(bytes.NewBuffer(receivedBytes)).Decode(&data); err != nil {
t.Fatalf("Failed to unmarshall outgoing response: %v", err)
}
- if responseMsg.Type != lib.ResponseFinal {
- t.Errorf("Data was %q, expected %q", outMsg.Data, `["[InputValue]"]`)
- }
- var outArg string
- var ok bool
- if outArg, ok = responseMsg.Message.(string); !ok {
- t.Errorf("Got unexpected response message body of type %T, expected type string", responseMsg.Message)
- }
+
var result app.RpcResponse
- if err := lib.HexVomDecode(outArg, &result); err != nil {
- t.Errorf("Failed to vom decode args from %v: %v", outArg, err)
+ dataBytes, err := hex.DecodeString(data)
+ if err != nil {
+ t.Errorf("Failed to hex decode from %v: %v", data, err)
+ }
+
+ decoder := vom.NewDecoderWithTypeDecoder(bytes.NewBuffer(dataBytes), vom.NewTypeDecoder(typeReader))
+ if err := decoder.Decode(&result); err != nil {
+ t.Errorf("Failed to vom decode args from %v: %v", data, err)
}
if got, want := result.OutArgs[0], vdl.StringValue("[InputValue]"); !vdl.EqualValue(got, want) {
t.Errorf("Result got %v, want %v", got, want)
diff --git a/services/wspr/internal/lib/binary_util.go b/services/wspr/internal/lib/binary_util.go
new file mode 100644
index 0000000..fb4afab
--- /dev/null
+++ b/services/wspr/internal/lib/binary_util.go
@@ -0,0 +1,78 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lib
+
+import (
+ "v.io/v23/verror"
+)
+
+const pkgPath = "v.io/x/ref/services/wspr/internal/lib"
+
+const uint64Size = 8
+
+var (
+ errInvalid = verror.Register(pkgPath+".errInvalid", verror.NoRetry, "{1:}{2:} wspr: invalid encoding{:_}")
+ errEOF = verror.Register(pkgPath+".errEOF", verror.NoRetry, "{1:}{2:} wspr: eof{:_}")
+ errUintOverflow = verror.Register(pkgPath+".errUintOverflow", verror.NoRetry, "{1:}{2:} wspr: scalar larger than 8 bytes{:_}")
+)
+
+// This code has been copied from the vom package and should be kept up to date
+// with it.
+
+// Unsigned integers are the basis for all other primitive values. This is a
+// two-state encoding. If the number is less than 128 (0 through 0x7f), its
+// value is written directly. Otherwise the value is written in big-endian byte
+// order preceded by the negated byte length.
+func BinaryEncodeUint(v uint64) []byte {
+ switch {
+ case v <= 0x7f:
+ return []byte{byte(v)}
+ case v <= 0xff:
+ return []byte{0xff, byte(v)}
+ case v <= 0xffff:
+ return []byte{0xfe, byte(v >> 8), byte(v)}
+ case v <= 0xffffff:
+ return []byte{0xfd, byte(v >> 16), byte(v >> 8), byte(v)}
+ case v <= 0xffffffff:
+ return []byte{0xfc, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
+ case v <= 0xffffffffff:
+ return []byte{0xfb, byte(v >> 32), byte(v >> 24),
+ byte(v >> 16), byte(v >> 8), byte(v)}
+ case v <= 0xffffffffffff:
+ return []byte{0xfa, byte(v >> 40), byte(v >> 32), byte(v >> 24),
+ byte(v >> 16), byte(v >> 8), byte(v)}
+ case v <= 0xffffffffffffff:
+ return []byte{0xf9, byte(v >> 48), byte(v >> 40), byte(v >> 32),
+ byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
+ default:
+ return []byte{0xf9, byte(v >> 56), byte(v >> 48), byte(v >> 40),
+ byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
+ }
+}
+
+func BinaryDecodeUint(input []byte) (v uint64, byteLen int, err error) {
+ if len(input) == 0 {
+ return 0, 0, verror.New(errEOF, nil)
+ }
+ firstByte := input[0]
+ if firstByte <= 0x7f {
+ return uint64(firstByte), 1, nil
+ }
+
+ if firstByte <= 0xdf {
+ return 0, 0, verror.New(errInvalid, nil)
+ }
+ byteLen = int(-int8(firstByte))
+ if byteLen < 1 || byteLen > uint64Size {
+ return 0, 0, verror.New(errUintOverflow, nil)
+ }
+ if len(input) < byteLen {
+ return 0, 0, verror.New(errEOF, nil)
+ }
+ for i := 1; i < byteLen; i++ {
+ v = v<<8 | uint64(input[i])
+ }
+ return
+}
diff --git a/services/wspr/internal/lib/hex_vom.go b/services/wspr/internal/lib/hex_vom.go
index ded745dd..24663e9 100644
--- a/services/wspr/internal/lib/hex_vom.go
+++ b/services/wspr/internal/lib/hex_vom.go
@@ -37,39 +37,3 @@
decoder := vom.NewDecoder(bytes.NewReader(binbytes))
return decoder.Decode(v)
}
-
-// ProxyReader implements io.Reader but allows changing the underlying buffer.
-// This is useful for merging discrete messages that are part of the same flow.
-type ProxyReader struct {
- bytes.Buffer
-}
-
-func NewProxyReader() *ProxyReader {
- return &ProxyReader{}
-}
-
-func (p *ProxyReader) ReplaceBuffer(data string) error {
- binbytes, err := hex.DecodeString(data)
- if err != nil {
- return err
- }
- p.Reset()
- p.Write(binbytes)
- return nil
-}
-
-// ProxyWriter implements io.Writer but allows changing the underlying buffer.
-// This is useful for merging discrete messages that are part of the same flow.
-type ProxyWriter struct {
- bytes.Buffer
-}
-
-func NewProxyWriter() *ProxyWriter {
- return &ProxyWriter{}
-}
-
-func (p *ProxyWriter) ConsumeBuffer() string {
- s := hex.EncodeToString(p.Buffer.Bytes())
- p.Reset()
- return s
-}
diff --git a/services/wspr/internal/lib/hex_vom_test.go b/services/wspr/internal/lib/hex_vom_test.go
new file mode 100644
index 0000000..728f976
--- /dev/null
+++ b/services/wspr/internal/lib/hex_vom_test.go
@@ -0,0 +1,63 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lib
+
+import (
+ "encoding/hex"
+ "reflect"
+ "testing"
+ "time"
+)
+
+func TestReadBeforeData(t *testing.T) {
+ reader := NewTypeReader()
+ input := []byte{0, 2, 3, 4, 5}
+ data := make([]byte, 5)
+ go func() {
+ <-time.After(100 * time.Millisecond)
+ reader.Add(hex.EncodeToString(input))
+ }()
+ n, err := reader.Read(data)
+ if n != len(data) {
+ t.Errorf("Read the wrong number of bytes, wanted:%d, got:%d", len(data), n)
+ }
+
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ if !reflect.DeepEqual(input, data) {
+ t.Errorf("wrong data, want:%x, got:%x", input, data)
+ }
+}
+
+func TestReadWithPartialData(t *testing.T) {
+ reader := NewTypeReader()
+ input := []byte{0, 2, 3, 4, 5}
+ data := make([]byte, 5)
+ reader.Add(hex.EncodeToString(input[:2]))
+ go func() {
+ <-time.After(300 * time.Millisecond)
+ reader.Add(hex.EncodeToString(input[2:]))
+ }()
+ totalRead := 0
+ for {
+ n, err := reader.Read(data[totalRead:])
+ totalRead += n
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ break
+ }
+ if totalRead == 5 {
+ break
+ }
+ }
+ if totalRead != len(data) {
+ t.Errorf("Read the wrong number of bytes, wanted:%d, got:%d", len(data), totalRead)
+ }
+
+ if !reflect.DeepEqual(input, data) {
+ t.Errorf("wrong data, want:%x, got:%x", input, data)
+ }
+}
diff --git a/services/wspr/internal/lib/testwriter/writer.go b/services/wspr/internal/lib/testwriter/writer.go
index 01c78d4..2c9da53 100644
--- a/services/wspr/internal/lib/testwriter/writer.go
+++ b/services/wspr/internal/lib/testwriter/writer.go
@@ -22,8 +22,9 @@
type Writer struct {
sync.Mutex
- Stream []lib.Response // TODO Why not use channel?
- err error
+ TypeStream []lib.Response
+ Stream []lib.Response // TODO Why not use channel?
+ err error
// If this channel is set then a message will be sent
// to this channel after recieving a call to FinishMessage()
notifier chan bool
@@ -45,7 +46,11 @@
w.Lock()
defer w.Unlock()
- w.Stream = append(w.Stream, r)
+ if responseType == lib.ResponseTypeMessage {
+ w.TypeStream = append(w.TypeStream, r)
+ } else {
+ w.Stream = append(w.Stream, r)
+ }
if w.notifier != nil {
w.notifier <- true
}
@@ -101,10 +106,13 @@
return nil
}
-func CheckResponses(w *Writer, wantStream []lib.Response, wantErr error) error {
+func CheckResponses(w *Writer, wantStream []lib.Response, wantTypes []lib.Response, wantErr error) error {
if got, want := w.Stream, wantStream; !reflect.DeepEqual(got, want) {
return fmt.Errorf("streams don't match: got %#v, want %#v", got, want)
}
+ if got, want := w.TypeStream, wantTypes; !reflect.DeepEqual(got, want) {
+ return fmt.Errorf("streams don't match: got %#v, want %#v", got, want)
+ }
if got, want := w.err, wantErr; verror.ErrorID(got) != verror.ErrorID(want) {
return fmt.Errorf("unexpected error, got: %#v, expected: %#v", got, want)
}
diff --git a/services/wspr/internal/lib/type_reader.go b/services/wspr/internal/lib/type_reader.go
new file mode 100644
index 0000000..9b98161
--- /dev/null
+++ b/services/wspr/internal/lib/type_reader.go
@@ -0,0 +1,61 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lib
+
+import (
+ "bytes"
+ "encoding/hex"
+ "io"
+ "sync"
+)
+
+// TypeReader implements io.Reader but allows changing the underlying buffer.
+// This is useful for merging discrete messages that are part of the same flow.
+type TypeReader struct {
+ buf bytes.Buffer
+ mu sync.Mutex
+ isClosed bool
+ cond *sync.Cond
+}
+
+func NewTypeReader() *TypeReader {
+ reader := &TypeReader{}
+ reader.cond = sync.NewCond(&reader.mu)
+ return reader
+}
+
+func (r *TypeReader) Add(data string) error {
+ binBytes, err := hex.DecodeString(data)
+ if err != nil {
+ return err
+ }
+ r.mu.Lock()
+ _, err = r.buf.Write(binBytes)
+ r.mu.Unlock()
+ r.cond.Signal()
+ return err
+}
+
+func (r *TypeReader) Read(p []byte) (int, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ for {
+ if r.buf.Len() > 0 {
+ return r.buf.Read(p)
+ }
+ if r.isClosed {
+ return 0, io.EOF
+ }
+ r.cond.Wait()
+ }
+
+}
+
+func (r *TypeReader) Close() {
+ r.mu.Lock()
+ r.isClosed = true
+ r.mu.Unlock()
+ r.cond.Broadcast()
+}
diff --git a/services/wspr/internal/lib/type_writer.go b/services/wspr/internal/lib/type_writer.go
new file mode 100644
index 0000000..3536db2
--- /dev/null
+++ b/services/wspr/internal/lib/type_writer.go
@@ -0,0 +1,23 @@
+// Copyright 2015 The Vanadium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lib
+
+// TypeWriter implements io.Writer but allows changing the underlying buffer.
+// This is useful for merging discrete messages that are part of the same flow.
+type TypeWriter struct {
+ writer ClientWriter
+}
+
+func NewTypeWriter(w ClientWriter) *TypeWriter {
+ return &TypeWriter{writer: w}
+}
+
+func (w *TypeWriter) Write(p []byte) (int, error) {
+ err := w.writer.Send(ResponseTypeMessage, p)
+ if err != nil {
+ return 0, err
+ }
+ return len(p), nil
+}
diff --git a/services/wspr/internal/lib/writer.go b/services/wspr/internal/lib/writer.go
index da3f117..1c90372 100644
--- a/services/wspr/internal/lib/writer.go
+++ b/services/wspr/internal/lib/writer.go
@@ -19,6 +19,7 @@
ResponseLog = 9 // Sends a message to be logged.
ResponseGranterRequest = 10
ResponseBlessingsCacheMessage = 11 // Update the blessings cache
+ ResponseTypeMessage = 12
)
type Response struct {
diff --git a/services/wspr/internal/rpc/server/dispatcher_test.go b/services/wspr/internal/rpc/server/dispatcher_test.go
index 43039ae..df69179 100644
--- a/services/wspr/internal/rpc/server/dispatcher_test.go
+++ b/services/wspr/internal/rpc/server/dispatcher_test.go
@@ -122,7 +122,7 @@
},
},
}
- if err := testwriter.CheckResponses(&flowFactory.writer, expectedResponses, nil); err != nil {
+ if err := testwriter.CheckResponses(&flowFactory.writer, expectedResponses, nil, nil); err != nil {
t.Error(err)
}
}
@@ -171,7 +171,7 @@
},
},
}
- if err := testwriter.CheckResponses(&flowFactory.writer, expectedResponses, nil); err != nil {
+ if err := testwriter.CheckResponses(&flowFactory.writer, expectedResponses, nil, nil); err != nil {
t.Error(err)
}
}
@@ -205,7 +205,7 @@
},
},
}
- if err := testwriter.CheckResponses(&flowFactory.writer, expectedResponses, nil); err != nil {
+ if err := testwriter.CheckResponses(&flowFactory.writer, expectedResponses, nil, nil); err != nil {
t.Error(err)
}
}