third_party/mdns: pack messages into as few packets as possible
Change-Id: Iaaa5098a3222c4f17605b81bfefe1a40aca8e51f
diff --git a/go/src/github.com/presotto/go-mdns-sd/go_dns/msg.go b/go/src/github.com/presotto/go-mdns-sd/go_dns/msg.go
index 9133188..fa2b87e 100644
--- a/go/src/github.com/presotto/go-mdns-sd/go_dns/msg.go
+++ b/go/src/github.com/presotto/go-mdns-sd/go_dns/msg.go
@@ -782,27 +782,32 @@
Extra []RR
}
-func (dns *Msg) Pack() (msg []byte, ok bool) {
+func (dns *Msg) dnsHeaderBits() uint16 {
+ bits := uint16(dns.Opcode)<<11 | uint16(dns.Rcode)
+ if dns.RecursionAvailable {
+ bits |= _RA
+ }
+ if dns.RecursionDesired {
+ bits |= _RD
+ }
+ if dns.Truncated {
+ bits |= _TC
+ }
+ if dns.Authoritative {
+ bits |= _AA
+ }
+ if dns.Response {
+ bits |= _QR
+ }
+ return bits
+}
+
+func (dns *Msg) Pack() ([]byte, bool) {
var dh dnsHeader
// Convert convenient Msg into wire-like dnsHeader.
dh.Id = dns.ID
- dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode)
- if dns.RecursionAvailable {
- dh.Bits |= _RA
- }
- if dns.RecursionDesired {
- dh.Bits |= _RD
- }
- if dns.Truncated {
- dh.Bits |= _TC
- }
- if dns.Authoritative {
- dh.Bits |= _AA
- }
- if dns.Response {
- dh.Bits |= _QR
- }
+ dh.Bits = dns.dnsHeaderBits()
// Prepare variable sized arrays.
question := dns.Question
@@ -818,27 +823,62 @@
// Could work harder to calculate message size,
// but this is far more than we need and not
// big enough to hurt the allocator.
- msg = make([]byte, 2000)
+ msg := make([]byte, 1460)
// Pack it in: header and then the pieces.
- off := 0
- off, ok = packStruct(&dh, msg, off)
- for i := 0; i < len(question); i++ {
+ off, ok := packStruct(&dh, msg, 0)
+ for i := 0; ok && i < len(question); i++ {
off, ok = packStruct(&question[i], msg, off)
}
- for i := 0; i < len(answer); i++ {
+ for i := 0; ok && i < len(answer); i++ {
off, ok = packRR(answer[i], msg, off)
}
- for i := 0; i < len(ns); i++ {
+ for i := 0; ok && i < len(ns); i++ {
off, ok = packRR(ns[i], msg, off)
}
- for i := 0; i < len(extra); i++ {
+ for i := 0; ok && i < len(extra); i++ {
off, ok = packRR(extra[i], msg, off)
}
if !ok {
return nil, false
}
- return msg[0:off], true
+ return msg[:off], true
+}
+
+func (dns *Msg) PackTo(msg []byte) ([]byte, bool) {
+ var dh dnsHeader
+ if _, ok := unpackStruct(&dh, msg, 0); !ok {
+ return nil, false
+ }
+ // We only want to combine answers to the same question without breaking
+ // ordering in the DNS packet.
+ if dh.Bits != dns.dnsHeaderBits() || dh.Nscount > 0 || dh.Arcount > 0 {
+ return nil, false
+ }
+ if len(dns.Question) > 0 || len(dns.NS) > 0 || len(dns.Extra) > 0 {
+ return nil, false
+ }
+
+ answer := dns.Answer
+ dh.Ancount += uint16(len(answer))
+
+ off := len(msg)
+ msg = msg[:cap(msg)]
+
+ var ok bool
+ for i := 0; i < len(answer); i++ {
+ off, ok = packRR(answer[i], msg, off)
+ if !ok {
+ return nil, false
+ }
+ }
+
+ // Update the packed header.
+ _, ok = packStruct(&dh, msg, 0)
+ if !ok {
+ panic("pack failed") // should never happen!
+ }
+ return msg[:off], true
}
func (dns *Msg) Unpack(msg []byte) bool {
diff --git a/go/src/github.com/presotto/go-mdns-sd/go_dns/msg_test.go b/go/src/github.com/presotto/go-mdns-sd/go_dns/msg_test.go
index 93d9731..89f9980 100644
--- a/go/src/github.com/presotto/go-mdns-sd/go_dns/msg_test.go
+++ b/go/src/github.com/presotto/go-mdns-sd/go_dns/msg_test.go
@@ -35,6 +35,32 @@
}
}
+func TestDNSPack(t *testing.T) {
+ rrs := []RR{
+ &RR_TXT{RR_Header{"foo.local.", TypeTXT, ClassINET, 10, 0}, []string{"foo1", "foo2"}},
+ &RR_TXT{RR_Header{"bar.local.", TypeTXT, ClassINET, 20, 0}, []string{"bar1", "bar2"}},
+ &RR_TXT{RR_Header{"baz.local.", TypeTXT, ClassINET, 30, 0}, []string{"baz1", "baz2"}},
+ }
+
+ msg := &Msg{Answer: []RR{rrs[0]}}
+ data, ok := msg.Pack()
+ if !ok {
+ t.Fatalf("failed to pack message %v", msg)
+ }
+ for i := 1; i < len(rrs); i++ {
+ msg = &Msg{Answer: []RR{rrs[i]}}
+ data, ok = msg.PackTo(data)
+ if !ok {
+ t.Fatalf("failed to pack message %v", msg)
+ }
+ }
+ msg = new(Msg)
+ msg.Unpack(data)
+ if !reflect.DeepEqual(rrs, msg.Answer) {
+ t.Errorf("expected %v, but got %v", rrs, msg.Answer)
+ }
+}
+
func TestDNSParseSRVReply(t *testing.T) {
data, err := hex.DecodeString(dnsSRVReply)
if err != nil {
@@ -71,9 +97,9 @@
msg2.Unpack(data2)
switch {
case !ok:
- t.Errorf("failed to repack message")
+ t.Error("failed to repack message")
case !reflect.DeepEqual(msg, msg2):
- t.Errorf("repacked message differs from original")
+ t.Error("repacked message differs from original")
}
}
diff --git a/go/src/github.com/presotto/go-mdns-sd/mdns.go b/go/src/github.com/presotto/go-mdns-sd/mdns.go
index 62642e7..9a33fb4 100644
--- a/go/src/github.com/presotto/go-mdns-sd/mdns.go
+++ b/go/src/github.com/presotto/go-mdns-sd/mdns.go
@@ -144,27 +144,46 @@
}
// Send a message on a multicast net and cache it locally.
-func (m *multicastIfc) sendMessage(msg *dns.Msg) {
- if m.mdns.logLevel >= 2 {
- log.Printf("sending message %v\n", msg)
- }
- buf, ok := msg.Pack()
- if !ok {
- if m.mdns.logLevel >= 1 {
- log.Printf("can't pack address message\n")
+func (m *multicastIfc) sendMessages(msgs ...*dns.Msg) {
+ var ok bool
+ var buf []byte
+ var bufs [][]byte
+ for _, msg := range msgs {
+ if m.mdns.logLevel >= 2 {
+ log.Printf("sending message %v\n", msg)
}
- return
- }
- if _, err := m.conn.WriteTo(buf, m.addr); err != nil {
- if m.mdns.logLevel >= 1 {
- log.Printf("WriteTo failed %v %v", m.addr, err)
+ if buf != nil {
+ old := buf
+ buf, ok = msg.PackTo(buf)
+ if !ok {
+ // Probably due to space.
+ bufs = append(bufs, old)
+ }
+ }
+ if buf == nil {
+ buf, ok = msg.Pack()
+ if !ok {
+ if m.mdns.logLevel >= 1 {
+ log.Printf("can't pack message %v\n", msg)
+ }
+ continue
+ }
+ }
+ // Cache these RRs in case we ask about ourself.
+ for _, rr := range msg.Answer {
+ if m.cache.Add(rr) {
+ m.mdns.changedRR(rr)
+ }
}
}
-
- // Cache these RRs in case we ask about ourself.
- for _, rr := range msg.Answer {
- if m.cache.Add(rr) {
- m.mdns.changedRR(rr)
+ if buf != nil {
+ bufs = append(bufs, buf)
+ }
+ for _, buf := range bufs {
+ if _, err := m.conn.WriteTo(buf, m.addr); err != nil {
+ if m.mdns.logLevel >= 1 {
+ log.Printf("WriteTo failed %v %v", m.addr, err)
+ }
}
}
}
@@ -173,21 +192,21 @@
func (m *multicastIfc) announceHost(host string, ttl uint32) {
msg := newDnsMsg(0, true, true)
m.appendHostAddresses(msg, host, dns.TypeALL, ttl)
- m.sendMessage(msg)
+ m.sendMessages(msg)
}
// Announce a service and how to reach it.
func (m *multicastIfc) announceService(service, host string, port uint16, txt []string, ttl uint32) {
msg := newDnsMsg(0, true, true)
m.appendDiscoveryRecords(msg, service, host, port, txt, ttl)
- m.sendMessage(msg)
+ m.sendMessages(msg)
}
// Ask a question.
func (m *multicastIfc) sendQuestion(q []dns.Question) {
msg := newDnsMsg(0, false, false)
msg.Question = q
- m.sendMessage(msg)
+ m.sendMessages(msg)
}
type lookupRequest struct {
@@ -712,10 +731,8 @@
msgs = append(msgs, s.answerTXT(m, q)...)
}
}
- for _, msg := range msgs {
- if len(msg.Answer) > 0 {
- m.mifc.sendMessage(msg)
- }
+ if len(msgs) > 0 {
+ m.mifc.sendMessages(msgs...)
}
}