ref/runtime/internal/flow/conn: Do not double encrypt encasulated connections.
Change-Id: I1621774ff58b6a0c4218876520c308baa16a4c88
diff --git a/runtime/internal/flow/conn/auth.go b/runtime/internal/flow/conn/auth.go
index 15eb1a9..3b2f121 100644
--- a/runtime/internal/flow/conn/auth.go
+++ b/runtime/internal/flow/conn/auth.go
@@ -125,7 +125,13 @@
if rSetup.PeerNaClPublicKey == nil {
return nil, NewErrMissingSetupOption(ctx, "peerNaClPublicKey")
}
- return c.mp.setupEncryption(ctx, pk, sk, rSetup.PeerNaClPublicKey), nil
+ binding := c.mp.setupEncryption(ctx, pk, sk, rSetup.PeerNaClPublicKey)
+ // if we're encapsulated in another flow, tell that flow to stop
+ // encrypting now that we've started.
+ if f, ok := c.mp.rw.(*flw); ok {
+ f.disableEncryption()
+ }
+ return binding, nil
}
func (c *Conn) readRemoteAuth(ctx *context.T, binding []byte) error {
diff --git a/runtime/internal/flow/conn/flow.go b/runtime/internal/flow/conn/flow.go
index a3207dc..820cba9 100644
--- a/runtime/internal/flow/conn/flow.go
+++ b/runtime/internal/flow/conn/flow.go
@@ -25,6 +25,7 @@
opened bool
q *readq
bkey, dkey uint64
+ noEncrypt bool
}
// Ensure that *flw implements flow.Flow.
@@ -46,6 +47,11 @@
return f
}
+// disableEncrytion should not be called concurrently with Write* methods.
+func (f *flw) disableEncryption() {
+ f.noEncrypt = false
+}
+
// Implement io.Reader.
// Read and ReadMsg should not be called concurrently with themselves
// or each other.
@@ -134,6 +140,9 @@
if alsoClose && done {
d.Flags |= message.CloseFlag
}
+ if f.noEncrypt {
+ d.Flags |= message.DisableEncryptionFlag
+ }
sent += size
return size, done, f.conn.mp.writeMsg(f.ctx, d)
})
diff --git a/runtime/internal/flow/conn/message.go b/runtime/internal/flow/conn/message.go
index dc59431..fa8fabf 100644
--- a/runtime/internal/flow/conn/message.go
+++ b/runtime/internal/flow/conn/message.go
@@ -58,10 +58,16 @@
if err = p.cipher.Seal(p.writeBuf); err != nil {
return err
}
- if _, err = p.rw.WriteMsg(p.writeBuf); err == nil {
- ctx.VI(2).Infof("Wrote low-level message: %#v", m)
+ if _, err = p.rw.WriteMsg(p.writeBuf); err != nil {
+ return err
}
- return err
+ if data, ok := m.(*message.Data); ok && (data.Flags&message.DisableEncryptionFlag != 0) {
+ if _, err = p.rw.WriteMsg(data.Payload...); err != nil {
+ return err
+ }
+ }
+ ctx.VI(2).Infof("Wrote low-level message: %#v", m)
+ return nil
}
func (p *messagePipe) readMsg(ctx *context.T) (message.Message, error) {