blob: a3e086911c730b6dfee625214bb227d952a8f3a6 [file] [log] [blame]
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rpc
import (
"fmt"
"net"
"sort"
"strings"
"sync"
"v.io/v23/naming"
"v.io/v23/verror"
"v.io/x/lib/netstate"
)
var (
// These errors are intended to be used as arguments to higher
// level errors and hence {1}{2} is omitted from their format
// strings to avoid repeating these n-times in the final error
// message visible to the user.
//
// reg is defined elsewhere in this package; presumably it registers each
// ID/format pair with verror — TODO confirm.
errMalformedEndpoint = reg(".errMalformedEndpoint", "malformed endpoint{:3}")
errUndesiredProtocol = reg(".errUndesiredProtocol", "undesired protocol{:3}")
errIncompatibleEndpointVersions = reg(".errIncompatibleEndpointVersions", "incompatible endpoint versions{:3}")
// NOTE(review): the ID string below is misspelled ("Comaptible"). It is left
// as-is because error IDs may be compared by other code; fix deliberately
// once nothing depends on the current spelling.
errNoCompatibleServers = reg(".errNoComaptibleServers", "failed to find any compatible servers{:3}")
)
// defaultPreferredProtocolOrder is the rank map used when the caller supplies
// no protocol preference. Earlier entries get higher ranks (unixfd=5 down to
// "*"=1), and the "*" entry lets any other protocol through at the lowest rank.
var defaultPreferredProtocolOrder = mkProtocolRankMap([]string{"unixfd", "wsh", "tcp4", "tcp", "*"})
// serverLocality describes where a server sits relative to a set of IP
// networks. When servers are sorted, larger values are preferred.
type serverLocality int

const (
	// unknownNetwork means locality could not be determined.
	unknownNetwork serverLocality = 0
	// remoteNetwork means the server's IP lies outside every known network.
	remoteNetwork serverLocality = 1
	// localNetwork means the server's IP lies inside one of the known networks.
	localNetwork serverLocality = 2
)
// maxCacheSize bounds the number of entries held in serversCache.
const maxCacheSize = 2048
// Package-level caches used by filterAndOrderServers and its helpers.
// cacheValid, ipNetworksCache and serversCache are all accessed under cacheMu.
var (
cacheMu sync.Mutex
// cacheValid is closed once the cached IP networks are stale; refreshCache
// then rebuilds ipNetworksCache and empties serversCache.
cacheValid <-chan struct{}
ipNetworksCache []*net.IPNet
serversCache = make(map[string]sortableServer) // keyed by server name + "," + comma-joined protocols.
)
// init starts cacheValid out as an already-closed channel, so the first
// refreshCache call treats the cache as stale and populates it.
func init() {
	stale := make(chan struct{})
	close(stale)
	cacheValid = stale
}
// filterAndOrderServers returns a set of servers that are compatible with
// the current client in order of 'preference' specified by the supplied
// protocols and a notion of 'locality' according to the supplied protocol
// list as follows:
// - if the protocol parameter is non-empty, then only servers matching those
// protocols are returned and the endpoints are ordered first by protocol
// and then by locality within each protocol. If tcp4 and unixfd are requested
// for example then only protocols that match tcp4 and unixfd will returned
// with the tcp4 ones preceeding the unixfd ones.
// - if the protocol parameter is empty, then a default protocol ordering
// will be used, but unlike the previous case, any servers that don't support
// these protocols will be returned also, but following the default
// preferences.
func filterAndOrderServers(servers []naming.MountedServer, protocols []string, ipnets ...*net.IPNet) ([]naming.MountedServer, error) {
if ipnets == nil {
if err := refreshCache(); err != nil {
return nil, err
}
cacheMu.Lock()
ipnets = ipNetworksCache
cacheMu.Unlock()
}
var (
errs = verror.SubErrs{}
list = make(sortableServerList, 0, len(servers))
protoRanks = mkProtocolRankMap(protocols)
// TODO(suharshs): We can sort protocols before concatenating to increase
// cache usage.
// We prefix cached server information by protocols because different
// preferred protocols will have different ranks.
protocolsKey = strings.Join(protocols, ",")
)
if len(protoRanks) == 0 {
protoRanks = defaultPreferredProtocolOrder
}
adderr := func(name string, err error) {
errs = append(errs, verror.SubErr{Name: "server=" + name, Err: err, Options: verror.Print})
}
for _, server := range servers {
ss, err := mkSortableServer(server, protoRanks, protocolsKey, ipnets)
if err != nil {
adderr(server.Server, err)
continue
}
list = append(list, ss)
}
if len(list) == 0 {
return nil, verror.AddSubErrs(verror.New(errNoCompatibleServers, nil), nil, errs...)
}
// TODO(ashankar): Don't have to use stable sorting, could
// just use sort.Sort. The only problem with that is the
// unittest.
sort.Stable(list)
// Convert to []naming.MountedServer
ret := make([]naming.MountedServer, len(list))
for idx, item := range list {
ret[idx] = item.server
}
return ret, nil
}
// mkSortableServer wraps server with its protocol rank and locality so that
// filterAndOrderServers can sort it. Successful results are memoized in
// serversCache under the key name + "," + protocolsKey. Failures are not
// cached.
//
// It returns errMalformedEndpoint if server.Server does not contain a
// parseable endpoint. If the endpoint's protocol is not acceptable under
// protoRanks, it returns the error from protocol2rank.
//
// NOTE(review): the cache key does not include ipnets. A locality computed
// from caller-supplied ipnets is therefore reused by later calls that use
// different networks, until refreshCache clears the cache. Confirm that
// callers passing explicit ipnets can tolerate this.
func mkSortableServer(server naming.MountedServer, protoRanks map[string]int, protocolsKey string, ipnets []*net.IPNet) (sortableServer, error) {
	name := server.Server
	key := name + "," + protocolsKey

	cacheMu.Lock()
	cached, ok := serversCache[key]
	cacheMu.Unlock()
	if ok {
		return cached, nil
	}

	// Do the parsing and classification without holding the lock. Concurrent
	// misses on the same key only repeat this work; the last writer wins.
	ep, err := name2endpoint(name)
	if err != nil {
		return sortableServer{}, verror.New(errMalformedEndpoint, nil, err)
	}
	rank, err := protocol2rank(ep.Addr().Network(), protoRanks)
	if err != nil {
		return sortableServer{}, err
	}
	ss := sortableServer{
		server:       server,
		protocolRank: rank,
		locality:     locality(ep, ipnets),
	}

	cacheMu.Lock()
	defer cacheMu.Unlock()
	// Keep the cache bounded: when it is full, evict one arbitrary entry
	// (map iteration order is unspecified).
	if len(serversCache) >= maxCacheSize {
		for k := range serversCache {
			delete(serversCache, k)
			break
		}
	}
	serversCache[key] = ss
	return ss, nil
}
// name2endpoint parses the naming.Endpoint encoded in name. A rooted name has
// its address portion split off first. Any other name is parsed as an
// endpoint string as-is.
func name2endpoint(name string) (naming.Endpoint, error) {
	if !naming.Rooted(name) {
		return naming.ParseEndpoint(name)
	}
	address, _ := naming.SplitAddressName(name)
	return naming.ParseEndpoint(address)
}
// protocol2rank reports the rank of protocol according to ranks. Higher ranks
// are more preferable. The lookup tries, in order:
//  1. an exact entry for protocol;
//  2. for the "4"/"6"-suffixed variants wsh4, wsh6, tcp4, tcp6, ws4 and ws6,
//     the entry for the unsuffixed name ("wsh", "tcp" or "ws");
//  3. the wildcard entry "*", which accepts any protocol;
//  4. rank -1 with no error, for naming.UnknownProtocol.
//
// Any other protocol yields an errUndesiredProtocol error.
func protocol2rank(protocol string, ranks map[string]int) (int, error) {
	if rank, found := ranks[protocol]; found {
		return rank, nil
	}

	switch protocol {
	case "wsh4", "wsh6", "tcp4", "tcp6", "ws4", "ws6":
		// TODO(jhahn): We have similar protocol equivalency checks at a few
		// places. Figure out a way for this mapping to be shared.
		base := protocol[:len(protocol)-1]
		if rank, found := ranks[base]; found {
			return rank, nil
		}
	}

	if rank, found := ranks["*"]; found {
		return rank, nil
	}

	// UnknownProtocol should be rare. It typically appears when the endpoint
	// is written as <host>:<port> rather than in the full
	// @<version>@<protocol>@... form.
	if protocol == naming.UnknownProtocol {
		return -1, nil
	}
	return 0, verror.New(errUndesiredProtocol, nil, protocol)
}
// locality classifies ep against ipnets, normally the IP networks configured
// on this machine:
//   - unknownNetwork if ipnets is empty or ep's host is not an IP literal;
//   - localNetwork if any network in ipnets contains ep's IP;
//   - remoteNetwork otherwise.
func locality(ep naming.Endpoint, ipnets []*net.IPNet) serverLocality {
	if len(ipnets) == 0 {
		// With no networks to compare against, locality does not matter.
		return unknownNetwork
	}

	addr := ep.Addr().String()
	host := addr
	if h, _, err := net.SplitHostPort(addr); err == nil {
		host = h
	}

	ip := net.ParseIP(host)
	if ip == nil {
		// Not an IP address (possibly not an IP network at all).
		return unknownNetwork
	}
	for _, n := range ipnets {
		if n.Contains(ip) {
			return localNetwork
		}
	}
	return remoteNetwork
}
// ipNetworks returns the IP networks of the addresses reported by
// netstate.GetAllAddresses, i.e. the networks on this machine. The returned
// channel is closed when those networks change. If any address cannot be
// parsed as CIDR, the whole call fails with that error.
func ipNetworks() ([]*net.IPNet, <-chan struct{}, error) {
	addrs, changed, err := netstate.GetAllAddresses()
	if err != nil {
		return nil, nil, err
	}
	nets := make([]*net.IPNet, len(addrs))
	for i, addr := range addrs {
		_, network, perr := net.ParseCIDR(addr.String())
		if perr != nil {
			return nil, nil, perr
		}
		nets[i] = network
	}
	return nets, changed, nil
}
// refreshCache rebuilds ipNetworksCache and empties serversCache when the
// cached networks have been invalidated (cacheValid is closed). Otherwise it
// does nothing.
//
// On error the existing cache state is left untouched, so the next call
// retries the refresh. The mutex is always released, even on the error path.
func refreshCache() error {
	cacheMu.Lock()
	defer cacheMu.Unlock()
	select {
	case <-cacheValid:
		// Fetch into locals and commit only on success. Assigning directly
		// would set cacheValid to nil on failure. A nil channel is never ready
		// in a select, so that would silently disable all future refreshes.
		nets, valid, err := ipNetworks()
		if err != nil {
			return err
		}
		ipNetworksCache, cacheValid = nets, valid
		serversCache = make(map[string]sortableServer)
	default:
		// The cached networks are still valid.
	}
	return nil
}
// sortableServer pairs a mounted server with the attributes used to order it.
// For both protocolRank and locality, larger values are preferred.
type sortableServer struct {
	server       naming.MountedServer
	protocolRank int
	locality     serverLocality
}

// String formats the wrapped mounted server using its default (%v) form.
func (s *sortableServer) String() string {
	return fmt.Sprint(s.server)
}
// sortableServerList implements sort.Interface and puts more preferable
// servers first. A higher protocolRank wins. Ties are broken by the higher
// locality.
type sortableServerList []sortableServer

func (l sortableServerList) Len() int { return len(l) }

func (l sortableServerList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }

func (l sortableServerList) Less(i, j int) bool {
	a, b := &l[i], &l[j]
	if a.protocolRank != b.protocolRank {
		return a.protocolRank > b.protocolRank
	}
	return a.locality > b.locality
}
// mkProtocolRankMap gives each protocol in list a rank equal to its distance
// from the end of the list: the first entry gets len(list) and the last gets 1.
// If a protocol appears more than once, its last occurrence sets the rank.
// An empty or nil list yields a nil map.
func mkProtocolRankMap(list []string) map[string]int {
	n := len(list)
	if n == 0 {
		return nil
	}
	ranks := make(map[string]int, n)
	for i, p := range list {
		ranks[p] = n - i
	}
	return ranks
}