third_party/mdns: add RemoveService
RemoveService() removes the announced service from the mdns instance by
sending a goodbye packet, which is mdns response with a TTL of zero.
Also added UnsubscribeFromService() removes the service from 'subscribed' map.
Change-Id: I44a41334f63091c6c570134ffc0962828456cf40
diff --git a/go/src/github.com/presotto/go-mdns-sd/mdns.go b/go/src/github.com/presotto/go-mdns-sd/mdns.go
index a3a32db..daf492d 100644
--- a/go/src/github.com/presotto/go-mdns-sd/mdns.go
+++ b/go/src/github.com/presotto/go-mdns-sd/mdns.go
@@ -29,13 +29,14 @@
import (
"errors"
"fmt"
- "github.com/presotto/go-mdns-sd/go_dns"
"log"
"net"
"reflect"
"strings"
"sync"
"time"
+
+ "github.com/presotto/go-mdns-sd/go_dns"
)
// All incoming network messages carries enough context for a network appropriate response.
@@ -78,7 +79,7 @@
ifc: ifc,
addr: addr,
addresses: addresses,
- cache: newRRCache(),
+ cache: newRRCache(mdns.debug),
mdns: mdns,
ipver: ipver,
}
@@ -137,7 +138,11 @@
msg.Answer = append(msg.Answer, NewPtrRR(serviceDN, dns.ClassINET, ttl, uniqueServiceDN))
m.appendTxtRR(msg, service, host, txt, ttl)
m.appendSrvRR(msg, service, host, port, ttl)
- m.appendHostAddresses(msg, host, dns.TypeALL, ttl)
+ if ttl > 0 {
+ // Do not append host address in a goodbye packet, since host may be
+ // shared by other services and we do not want to delete it.
+ m.appendHostAddresses(msg, host, dns.TypeALL, ttl)
+ }
}
// Send a message on a multicast net and cache it locally.
@@ -196,6 +201,11 @@
txt []string
}
+type goodbyeRequest struct {
+ service string
+ host string
+}
+
type watchedService struct {
c *sync.Cond
gen int
@@ -218,6 +228,7 @@
// All access methods turn into channel requests to the main loop to make synchronization trivial.
announce chan announceRequest
+ goodbye chan goodbyeRequest
lookup chan lookupRequest
refreshAlarm chan struct{}
cleanupAlarm chan struct{}
@@ -315,6 +326,7 @@
// Allocate channels for communications internal to MDNS
s.fromNet = make(chan *msgFromNet, 10)
s.announce = make(chan announceRequest)
+ s.goodbye = make(chan goodbyeRequest)
s.lookup = make(chan lookupRequest)
s.refreshAlarm = make(chan struct{})
s.cleanupAlarm = make(chan struct{})
@@ -679,6 +691,15 @@
for _, mifc := range s.mifcs {
mifc.announceService(req.service, req.host, req.port, req.txt, s.ttl)
}
+ case req := <-s.goodbye:
+ // Removing a service
+ delete(s.services, req.service)
+ log.Printf("removing service %s %s\n", req.service, req.host)
+
+ // Tell all the networks about the goodbye
+ for _, mifc := range s.mifcs {
+ mifc.announceService(req.service, req.host, 0, nil, 0)
+ }
case req := <-s.lookup:
// Reply with all matching requests from all interfaces and then close the channel.
for _, mifc := range s.mifcs {
@@ -746,6 +767,23 @@
return nil
}
+// Remove a service. If the host name is empty, we just use the host name from NewMDNS. If the host name ends in .local. we strip it off.
+func (s *MDNS) RemoveService(service, host string) error {
+ if len(service) == 0 {
+ return errors.New("service name cannot be null")
+ }
+ if len(host) == 0 {
+ if s.hostName == "" {
+ return errors.New("RemoveService requires a host name")
+ }
+ host = s.hostName
+ } else {
+ host = hostUnqualify(host)
+ }
+ s.goodbye <- goodbyeRequest{service, host}
+ return nil
+}
+
// Resolve a particular RR type.
func (s *MDNS) ResolveRR(dn string, rrtype uint16) []dns.RR {
dn = hostFQDN(dn)
@@ -818,7 +856,7 @@
return ips, minttl
}
-// SubscriberToService declare our interest in a service. This should elicit responses from everyone implementing that service. This is
+// SubscriberToService declares our interest in a service. This should elicit responses from everyone implementing that service. This is
// orthogonal to offering the service ourselves.
func (s *MDNS) SubscribeToService(service string) {
serviceDN := serviceFQDN(service)
@@ -833,6 +871,14 @@
}
}
+// UnsubscribeFromService withholds our interest in a service.
+func (s *MDNS) UnsubscribeFromService(service string) {
+ serviceDN := serviceFQDN(service)
+ s.watchedLock.Lock()
+ delete(s.subscribed, serviceDN)
+ s.watchedLock.Unlock()
+}
+
type ServiceInstance struct {
Name string
SrvRRs []*dns.RR_SRV
diff --git a/go/src/github.com/presotto/go-mdns-sd/mdns_test.go b/go/src/github.com/presotto/go-mdns-sd/mdns_test.go
index 7eb70d4..ba71794 100644
--- a/go/src/github.com/presotto/go-mdns-sd/mdns_test.go
+++ b/go/src/github.com/presotto/go-mdns-sd/mdns_test.go
@@ -131,6 +131,18 @@
watchFor(t, c, instances[1])
close(c)
+ // Remove a service from one of the mdns instances.
+ s1.RemoveService("veyronns", instances[0].host)
+
+ // Wait for a goodbye message to get out and get reflected back.
+ time.Sleep(3 * time.Second)
+
+ // Make sure service discovery doesn't return the removed service.
+ discovered = s1.ServiceDiscovery("veyronns")
+ checkDiscovered(t, instances[0].host, discovered, []instance{instances[1]})
+ discovered = s2.ServiceDiscovery("veyronns")
+ checkDiscovered(t, instances[1].host, discovered, []instance{instances[1]})
+
s1.Stop()
s2.Stop()
}
diff --git a/go/src/github.com/presotto/go-mdns-sd/rrcache.go b/go/src/github.com/presotto/go-mdns-sd/rrcache.go
index f842896..f57c0c9 100644
--- a/go/src/github.com/presotto/go-mdns-sd/rrcache.go
+++ b/go/src/github.com/presotto/go-mdns-sd/rrcache.go
@@ -6,10 +6,11 @@
// A cache of DNS RRs (resource records).
import (
- "github.com/presotto/go-mdns-sd/go_dns"
"log"
"reflect"
"time"
+
+ "github.com/presotto/go-mdns-sd/go_dns"
)
type rrCacheEntry struct {
@@ -25,9 +26,10 @@
}
// Create a new rr cache. Make sure at least the top level map exists.
-func newRRCache() *rrCache {
+func newRRCache(debug bool) *rrCache {
rrcache := new(rrCache)
rrcache.cache = make(map[string]map[uint16][]*rrCacheEntry, 0)
+ rrcache.debug = debug
return rrcache
}
@@ -42,8 +44,8 @@
// Create an entry for the domain name if none exists.
dnmap, ok := c.cache[rr.Header().Name]
if !ok {
- c.cache[rr.Header().Name] = make(map[uint16][]*rrCacheEntry, 0)
- dnmap = c.cache[rr.Header().Name]
+ dnmap = make(map[uint16][]*rrCacheEntry, 0)
+ c.cache[rr.Header().Name] = dnmap
}
// Remove all rr's matching this one's type if a cache flush is requested.
@@ -54,9 +56,16 @@
dnmap[rr.Header().Rrtype] = make([]*rrCacheEntry, 0)
}
- // Don't believe TTLs greater than 75 minutes. Entries should refresh much faster than this.
- if rr.Header().Ttl > 4500 {
+ switch {
+ case rr.Header().Ttl > 4500:
+ // Don't believe TTLs greater than 75 minutes. Entries should refresh much faster than this.
rr.Header().Ttl = 4500
+ case rr.Header().Ttl == 0:
+ // This is a goodbye packet. RFC 6762 specifies that queriers receiving a multicast DNS
+ // response with a TTL of zero should record a TTL of 1 and then delete the record one
+ // second later. This gives the other cooperating responders one second to rescue the
+ // records when the goodbye packet was sent incorrectly.
+ rr.Header().Ttl = 1
}
// Add absolute expiration time to the entry.
diff --git a/go/src/github.com/presotto/go-mdns-sd/rrcache_test.go b/go/src/github.com/presotto/go-mdns-sd/rrcache_test.go
index 7916e12..2a926e6 100644
--- a/go/src/github.com/presotto/go-mdns-sd/rrcache_test.go
+++ b/go/src/github.com/presotto/go-mdns-sd/rrcache_test.go
@@ -23,6 +23,10 @@
&dns.RR_TXT{dns.RR_Header{"x.local.", dns.TypeTXT, dns.ClassINET | 0x8000, 10000, 0}, []string{"except on tuesday"}},
&dns.RR_PTR{dns.RR_Header{"x.local.", dns.TypePTR, dns.ClassINET | 0x8000, 10000, 0}, "q.local."},
}
+ goodbye []dns.RR = []dns.RR{
+ &dns.RR_TXT{dns.RR_Header{"x.local.", dns.TypeTXT, dns.ClassINET, 0, 0}, []string{"except on tuesday"}},
+ &dns.RR_PTR{dns.RR_Header{"x.local.", dns.TypePTR, dns.ClassINET, 0, 0}, "q.local."},
+ }
)
// Lookup RRs that match.
@@ -77,7 +81,7 @@
}
func TestRRCache(t *testing.T) {
- cache := newRRCache()
+ cache := newRRCache(*debugFlag)
// Cache a number of RRs with short TTLs.
for _, rr := range short {
cache.Add(rr)
@@ -115,4 +119,14 @@
if !compare(x, override) {
t.Errorf("%v != %v", x, override)
}
+
+ // Make sure goodbye works. The entries should be deleted after one second.
+ for _, rr := range goodbye {
+ cache.Add(rr)
+ }
+ time.Sleep(2 * time.Second)
+ x = lookup(cache, "x.local.", dns.TypeALL)
+ if len(x) != 0 {
+ t.Errorf("%v != []", x)
+ }
}