blob: d463ef2ef30ec9dfe937e37d7c8ae9513b435ae0 [file] [log] [blame]
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package conn
import (
"reflect"
"sync"
"time"
"v.io/v23"
"v.io/v23/context"
"v.io/v23/flow"
"v.io/v23/flow/message"
"v.io/v23/naming"
"v.io/v23/rpc/version"
"v.io/v23/security"
"v.io/v23/verror"
"v.io/x/ref/runtime/internal/flow/flowcontrol"
)
// Flow IDs identify the flows multiplexed over a single Conn; each flow on a
// given conn has its own ID. The lowest IDs are set aside for special purposes
// and are never handed out by Dial, whose counter starts at reservedFlows.
const (
	// invalidFlowID is the zero ID.
	invalidFlowID = 0
	// blessingsFlowID is a reserved ID; presumably the one used by the
	// blessings flow (NOTE(review): confirm against blessingsFlow).
	blessingsFlowID = 1
	// reservedFlows is the number of IDs below which flows are reserved;
	// locally dialed flow IDs begin at or just above this value.
	reservedFlows = 10
)
// Sizing parameters passed to flowcontrol.New when a Conn is created.
const (
	// mtu is the maximum transmission unit given to the flow controller (64 KiB).
	mtu = 1 << 16
	// defaultBufferSize is the buffer size given to the flow controller (1 MiB).
	defaultBufferSize = 1 << 20
)

// Priorities passed to FlowController.Run. In this file, connection-level
// control traffic (release and teardown messages) is sent with expressPriority.
const (
	expressPriority  = 0
	flowPriority     = 1
	tearDownPriority = 2
)
// FlowHandler processes flows that the remote end opens on a Conn.
type FlowHandler interface {
	// HandleFlow is called once for each flow accepted on the Conn. It is
	// invoked synchronously from the Conn's read loop, so a slow
	// implementation delays processing of later messages on that Conn.
	HandleFlow(flow.Flow) error
}
// Conn is a multiplexing, encrypted channel that can host Flows.
//
// A Conn is created by NewDialed or NewAccepted, which run the handshake and
// start a read loop goroutine. It is shut down either explicitly via Close or
// when the read loop stops because reading or handling a message failed.
type Conn struct {
	// fc schedules every write to mp; see the fc.Run calls.
	fc *flowcontrol.FlowController
	// mp carries messages over the underlying flow.MsgReadWriteCloser.
	mp *messagePipe
	// handler receives flows opened by the remote end. If nil, an OpenFlow
	// message from the remote end is treated as unexpected.
	handler FlowHandler
	// version is not written in this file; presumably set by the handshake.
	version version.RPCVersion
	// lBlessings are the local principal's default blessings. rBlessings are
	// the remote end's blessings; Dial refuses to open flows while they are zero.
	lBlessings, rBlessings security.Blessings
	// local and remote are the endpoints of the two ends of the Conn.
	local, remote naming.Endpoint
	// closed is closed once teardown has completed.
	closed chan struct{}
	// blessingsFlow is used by Dial to look up the remote end's discharges and
	// to register the blessings presented on each new flow.
	blessingsFlow *blessingsFlow
	// loopWG tracks the read loop goroutine; teardown waits on it.
	loopWG sync.WaitGroup

	// mu guards the fields below.
	mu sync.Mutex
	// nextFid is the ID the next locally dialed flow will receive.
	nextFid uint64
	// flows holds the live flows keyed by ID. It is set to nil when teardown
	// begins, which is how the shutdown paths agree on a single closer.
	flows map[uint64]*flw
	// dischargeTimer, if non-nil, is stopped when the Conn is closed.
	dischargeTimer *time.Timer
	// lastUsedTime is reported by LastUsedTime and updated by markUsed.
	lastUsedTime time.Time
}
// Compile-time check that *Conn implements flow.ManagedConn. A typed nil
// pointer is enough for the check and avoids allocating a Conn value.
var _ flow.ManagedConn = (*Conn)(nil)
// NewDialed creates a Conn over conn from the dialing side.
//
// It runs the dialer half of the handshake, negotiating within versions, and
// on success starts the Conn's read loop. If the handshake fails, the
// partially built Conn is closed with the handshake error, and that error is
// returned with a nil Conn. handler receives flows later opened by the
// remote end; it may be nil to refuse them.
func NewDialed(
	ctx *context.T,
	conn flow.MsgReadWriteCloser,
	local, remote naming.Endpoint,
	versions version.RPCVersionRange,
	handler FlowHandler) (*Conn, error) {
	c := &Conn{
		fc:         flowcontrol.New(defaultBufferSize, mtu),
		mp:         newMessagePipe(conn),
		handler:    handler,
		lBlessings: v23.GetPrincipal(ctx).BlessingStore().Default(),
		local:      local,
		remote:     remote,
		closed:     make(chan struct{}),
		// The dialing side's flow IDs begin at reservedFlows; the accepting
		// side's begin one higher.
		nextFid:      reservedFlows,
		flows:        make(map[uint64]*flw),
		lastUsedTime: time.Now(),
	}
	handshakeErr := c.dialHandshake(ctx, versions)
	if handshakeErr != nil {
		c.Close(ctx, handshakeErr)
		return nil, handshakeErr
	}
	c.loopWG.Add(1)
	go c.readLoop(ctx)
	return c, nil
}
// NewAccepted creates a Conn over conn from the accepting side.
//
// It runs the accepting half of the handshake, negotiating within versions,
// and on success starts the Conn's read loop. If the handshake fails, the
// Conn is closed with the handshake error, and that error is returned with a
// nil Conn. Unlike NewDialed, no remote endpoint is supplied here.
// NOTE(review): c.remote is presumably filled in by acceptHandshake — confirm.
func NewAccepted(
	ctx *context.T,
	conn flow.MsgReadWriteCloser,
	local naming.Endpoint,
	versions version.RPCVersionRange,
	handler FlowHandler) (*Conn, error) {
	c := &Conn{
		fc:         flowcontrol.New(defaultBufferSize, mtu),
		mp:         newMessagePipe(conn),
		handler:    handler,
		lBlessings: v23.GetPrincipal(ctx).BlessingStore().Default(),
		local:      local,
		closed:     make(chan struct{}),
		// The accepting side's flow IDs begin one above the dialing side's.
		nextFid:      reservedFlows + 1,
		flows:        make(map[uint64]*flw),
		lastUsedTime: time.Now(),
	}
	handshakeErr := c.acceptHandshake(ctx, versions)
	if handshakeErr != nil {
		c.Close(ctx, handshakeErr)
		return nil, handshakeErr
	}
	c.loopWG.Add(1)
	go c.readLoop(ctx)
	return c, nil
}
// Dial opens a new flow on the Conn.
//
// The remote end's latest discharges are looked up and passed, along with the
// endpoints and remote blessings, to fn. fn chooses the blessings and
// discharges to present on the new flow; they are handed to the blessings
// flow, and the keys it returns are attached to the flow.
//
// Dial fails with ErrDialingNonServer if the remote blessings are zero (the
// remote end is not acting as a server). It fails with ErrConnectionClosed if
// the Conn is already being torn down. Errors from the discharge lookup, fn,
// or the blessings flow are returned unchanged.
func (c *Conn) Dial(ctx *context.T, fn flow.BlessingsForPeer) (flow.Flow, error) {
	if c.rBlessings.IsZero() {
		return nil, NewErrDialingNonServer(ctx)
	}
	rDischarges, err := c.blessingsFlow.getLatestDischarges(ctx, c.rBlessings)
	if err != nil {
		return nil, err
	}
	blessings, discharges, err := fn(ctx, c.local, c.remote, c.rBlessings, rDischarges)
	if err != nil {
		return nil, err
	}
	bkey, dkey, err := c.blessingsFlow.put(ctx, blessings, discharges)
	if err != nil {
		return nil, err
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.flows == nil {
		return nil, NewErrConnectionClosed(ctx)
	}
	// The dialing side starts at reservedFlows and the accepting side at
	// reservedFlows+1. Stepping by two keeps the two ends on disjoint
	// (even/odd) ID sequences, so flows opened from both ends never share an
	// ID. Stepping by one would let the dialer's second flow collide with the
	// acceptor's first.
	id := c.nextFid
	c.nextFid += 2
	return c.newFlowLocked(ctx, id, bkey, dkey, true, false), nil
}
// LocalEndpoint returns the vanadium Endpoint of the local end of the Conn.
func (c *Conn) LocalEndpoint() naming.Endpoint {
	return c.local
}
// RemoteEndpoint returns the vanadium Endpoint of the remote end of the Conn.
func (c *Conn) RemoteEndpoint() naming.Endpoint {
	return c.remote
}
// LastUsedTime reports when bytes were last read or written on the Conn, as
// recorded by markUsed. Before any such update it is the Conn's creation time.
func (c *Conn) LastUsedTime() time.Time {
	c.mu.Lock()
	t := c.lastUsedTime
	c.mu.Unlock()
	return t
}
// Closed returns a channel that is closed once the Conn has been shut down.
// From that point on, every Dial call fails with an error and the
// FlowHandler receives no further flows.
func (c *Conn) Closed() <-chan struct{} {
	return c.closed
}
// Close shuts down the Conn, giving err as the reason.
//
// The first caller to detach c.flows performs the teardown. That caller may
// be Close or the read loop, which detaches them the same way. Any later or
// concurrent Close call blocks until teardown has finished and Closed() is
// closed. The discharge timer, if any, is stopped in every case.
func (c *Conn) Close(ctx *context.T, err error) {
	c.mu.Lock()
	flows := c.flows
	c.flows = nil
	if t := c.dischargeTimer; t != nil {
		t.Stop()
		c.dischargeTimer = nil
	}
	c.mu.Unlock()
	if flows != nil {
		c.internalClose(ctx, err, flows)
		return
	}
	// Someone else is already tearing the conn down; wait for them.
	<-c.closed
}
// internalClose performs the actual teardown of the Conn, closing every flow
// in flows. It must run at most once, since it closes c.closed. Callers
// guarantee this by only invoking it after detaching a non-nil c.flows.
//
// Unless err reports that the remote end closed the connection, a TearDown
// message carrying err's text is sent first. Failures to send it, or to close
// the message pipe, are logged rather than returned.
func (c *Conn) internalClose(ctx *context.T, err error, flows map[uint64]*flw) {
	ctx.VI(2).Infof("Closing connection: %v", err)
	closedByPeer := verror.ErrorID(err) == ErrConnClosedRemotely.ID
	if !closedByPeer {
		var reason string
		if err != nil {
			reason = err.Error()
		}
		sendTearDown := func(_ int) (int, bool, error) {
			return 0, true, c.mp.writeMsg(ctx, &message.TearDown{Message: reason})
		}
		if cerr := c.fc.Run(ctx, "close", expressPriority, sendTearDown); cerr != nil {
			ctx.Errorf("Error sending tearDown on connection to %s: %v", c.remote, cerr)
		}
	}
	for _, f := range flows {
		f.close(ctx, NewErrConnectionClosed(ctx))
	}
	if cerr := c.mp.close(); cerr != nil {
		ctx.Errorf("Error closing underlying connection for %s: %v", c.remote, cerr)
	}
	// Closing the pipe presumably unblocks the read loop's readMsg. Waiting
	// for the loop ensures nothing is dispatched after Closed() fires.
	c.loopWG.Wait()
	close(c.closed)
}
// release gathers the positive per-flow counts reported by each flow's read
// queue (f.q.release) and sends them to the remote end in a single Release
// message. On the receiving side these counts are returned to the flow
// controller as tokens. Nothing is sent if no flow reports a positive count.
// A failure to send closes the Conn.
func (c *Conn) release(ctx *context.T) {
	counts := make(map[uint64]uint64)
	c.mu.Lock()
	for fid, f := range c.flows {
		n := f.q.release()
		if n <= 0 {
			continue
		}
		counts[fid] = uint64(n)
	}
	c.mu.Unlock()
	if len(counts) > 0 {
		writeRelease := func(_ int) (int, bool, error) {
			return 0, true, c.mp.writeMsg(ctx, &message.Release{Counters: counts})
		}
		if err := c.fc.Run(ctx, "release", expressPriority, writeRelease); err != nil {
			c.Close(ctx, NewErrSend(ctx, "release", c.remote.String(), err))
		}
	}
}
// handleMessage dispatches a single message read from the remote end.
//
// A non-nil return value stops the read loop, which then tears the Conn down
// if it is not already being closed. TearDown yields ErrConnClosedRemotely.
// OpenFlow without a FlowHandler, and any unrecognized message type, yield
// ErrUnexpectedMsg. OpenFlow after teardown has begun yields
// ErrConnectionClosed. Errors from enqueueing data payloads or from releasing
// flow-control tokens are returned as-is.
func (c *Conn) handleMessage(ctx *context.T, m message.Message) error {
	switch msg := m.(type) {
	case *message.TearDown:
		return NewErrConnClosedRemotely(ctx, msg.Message)

	case *message.OpenFlow:
		if c.handler == nil {
			return NewErrUnexpectedMsg(ctx, "openFlow")
		}
		c.mu.Lock()
		if c.flows == nil {
			// Teardown has already snapshotted the flows it will close, so a
			// flow created now would never be closed. Mirror Dial's check and
			// refuse it rather than registering it on a closing Conn.
			c.mu.Unlock()
			return NewErrConnectionClosed(ctx)
		}
		f := c.newFlowLocked(ctx, msg.ID, msg.BlessingsKey, msg.DischargeKey, false, true)
		c.mu.Unlock()
		if err := c.handler.HandleFlow(f); err != nil {
			// The handler's failure is confined to this flow; record it
			// instead of silently discarding it, but keep the Conn alive.
			ctx.Errorf("Error handling flow %d on connection to %s: %v", msg.ID, c.remote, err)
		}

	case *message.Release:
		release := make([]flowcontrol.Release, 0, len(msg.Counters))
		c.mu.Lock()
		for fid, val := range msg.Counters {
			if f := c.flows[fid]; f != nil {
				release = append(release, flowcontrol.Release{
					Worker: f.worker,
					Tokens: int(val),
				})
			}
		}
		c.mu.Unlock()
		if err := c.fc.Release(ctx, release); err != nil {
			return err
		}

	case *message.Data:
		c.mu.Lock()
		f := c.flows[msg.ID]
		c.mu.Unlock()
		if f == nil {
			ctx.Infof("Ignoring data message for unknown flow on connection to %s: %d", c.remote, msg.ID)
			return nil
		}
		if err := f.q.put(ctx, msg.Payload); err != nil {
			return err
		}
		if msg.Flags&message.CloseFlag != 0 {
			f.close(ctx, NewErrFlowClosedRemotely(f.ctx))
		}

	default:
		return NewErrUnexpectedMsg(ctx, reflect.TypeOf(msg).String())
	}
	return nil
}
// readLoop reads messages from the remote end and dispatches them via
// handleMessage until reading or handling fails. It then detaches c.flows.
// If no one else had already begun teardown, it performs the teardown itself,
// using the error that stopped the loop as the reason.
func (c *Conn) readLoop(ctx *context.T) {
	var err error
	for {
		msg, rerr := c.mp.readMsg(ctx)
		if rerr != nil {
			err = NewErrRecv(ctx, c.remote.String(), rerr)
			break
		}
		if err = c.handleMessage(ctx, msg); err != nil {
			break
		}
	}
	c.mu.Lock()
	flows := c.flows
	c.flows = nil
	// Stop the discharge timer here as Close does. Otherwise a Conn torn down
	// by the read loop (e.g. after the peer hangs up) would leave it running.
	if c.dischargeTimer != nil {
		c.dischargeTimer.Stop()
		c.dischargeTimer = nil
	}
	c.mu.Unlock()
	// Signal exit before internalClose, which waits on loopWG.
	c.loopWG.Done()
	if flows != nil {
		c.internalClose(ctx, err, flows)
	}
}
// markUsed records the current time as the Conn's last-used time, which is
// what LastUsedTime reports. The timestamp is taken while holding mu.
func (c *Conn) markUsed() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.lastUsedTime = time.Now()
}