third_party/mdns: pack messages into as few packets as possible

Change-Id: Iaaa5098a3222c4f17605b81bfefe1a40aca8e51f
diff --git a/go/src/github.com/presotto/go-mdns-sd/go_dns/msg.go b/go/src/github.com/presotto/go-mdns-sd/go_dns/msg.go
index 9133188..fa2b87e 100644
--- a/go/src/github.com/presotto/go-mdns-sd/go_dns/msg.go
+++ b/go/src/github.com/presotto/go-mdns-sd/go_dns/msg.go
@@ -782,27 +782,32 @@
 	Extra    []RR
 }
 
-func (dns *Msg) Pack() (msg []byte, ok bool) {
+func (dns *Msg) dnsHeaderBits() uint16 {
+	bits := uint16(dns.Opcode)<<11 | uint16(dns.Rcode)
+	if dns.RecursionAvailable {
+		bits |= _RA
+	}
+	if dns.RecursionDesired {
+		bits |= _RD
+	}
+	if dns.Truncated {
+		bits |= _TC
+	}
+	if dns.Authoritative {
+		bits |= _AA
+	}
+	if dns.Response {
+		bits |= _QR
+	}
+	return bits
+}
+
+func (dns *Msg) Pack() ([]byte, bool) {
 	var dh dnsHeader
 
 	// Convert convenient Msg into wire-like dnsHeader.
 	dh.Id = dns.ID
-	dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode)
-	if dns.RecursionAvailable {
-		dh.Bits |= _RA
-	}
-	if dns.RecursionDesired {
-		dh.Bits |= _RD
-	}
-	if dns.Truncated {
-		dh.Bits |= _TC
-	}
-	if dns.Authoritative {
-		dh.Bits |= _AA
-	}
-	if dns.Response {
-		dh.Bits |= _QR
-	}
+	dh.Bits = dns.dnsHeaderBits()
 
 	// Prepare variable sized arrays.
 	question := dns.Question
@@ -818,27 +823,62 @@
 	// Could work harder to calculate message size,
 	// but this is far more than we need and not
 	// big enough to hurt the allocator.
-	msg = make([]byte, 2000)
+	msg := make([]byte, 1460)
 
 	// Pack it in: header and then the pieces.
-	off := 0
-	off, ok = packStruct(&dh, msg, off)
-	for i := 0; i < len(question); i++ {
+	off, ok := packStruct(&dh, msg, 0)
+	for i := 0; ok && i < len(question); i++ {
 		off, ok = packStruct(&question[i], msg, off)
 	}
-	for i := 0; i < len(answer); i++ {
+	for i := 0; ok && i < len(answer); i++ {
 		off, ok = packRR(answer[i], msg, off)
 	}
-	for i := 0; i < len(ns); i++ {
+	for i := 0; ok && i < len(ns); i++ {
 		off, ok = packRR(ns[i], msg, off)
 	}
-	for i := 0; i < len(extra); i++ {
+	for i := 0; ok && i < len(extra); i++ {
 		off, ok = packRR(extra[i], msg, off)
 	}
 	if !ok {
 		return nil, false
 	}
-	return msg[0:off], true
+	return msg[:off], true
+}
+
+func (dns *Msg) PackTo(msg []byte) ([]byte, bool) {
+	var dh dnsHeader
+	if _, ok := unpackStruct(&dh, msg, 0); !ok {
+		return nil, false
+	}
+	// We only want to combine answers to the same question without breaking
+	// ordering in the DNS packet.
+	if dh.Bits != dns.dnsHeaderBits() || dh.Nscount > 0 || dh.Arcount > 0 {
+		return nil, false
+	}
+	if len(dns.Question) > 0 || len(dns.NS) > 0 || len(dns.Extra) > 0 {
+		return nil, false
+	}
+
+	answer := dns.Answer
+	dh.Ancount += uint16(len(answer))
+
+	off := len(msg)
+	msg = msg[:cap(msg)]
+
+	var ok bool
+	for i := 0; i < len(answer); i++ {
+		off, ok = packRR(answer[i], msg, off)
+		if !ok {
+			return nil, false
+		}
+	}
+
+	// Update the packed header.
+	_, ok = packStruct(&dh, msg, 0)
+	if !ok {
+		panic("pack failed") // should never happen!
+	}
+	return msg[:off], true
 }
 
 func (dns *Msg) Unpack(msg []byte) bool {
diff --git a/go/src/github.com/presotto/go-mdns-sd/go_dns/msg_test.go b/go/src/github.com/presotto/go-mdns-sd/go_dns/msg_test.go
index 93d9731..89f9980 100644
--- a/go/src/github.com/presotto/go-mdns-sd/go_dns/msg_test.go
+++ b/go/src/github.com/presotto/go-mdns-sd/go_dns/msg_test.go
@@ -35,6 +35,32 @@
 	}
 }
 
+func TestDNSPack(t *testing.T) {
+	rrs := []RR{
+		&RR_TXT{RR_Header{"foo.local.", TypeTXT, ClassINET, 10, 0}, []string{"foo1", "foo2"}},
+		&RR_TXT{RR_Header{"bar.local.", TypeTXT, ClassINET, 20, 0}, []string{"bar1", "bar2"}},
+		&RR_TXT{RR_Header{"baz.local.", TypeTXT, ClassINET, 30, 0}, []string{"baz1", "baz2"}},
+	}
+
+	msg := &Msg{Answer: []RR{rrs[0]}}
+	data, ok := msg.Pack()
+	if !ok {
+		t.Fatalf("failed to pack message %v", msg)
+	}
+	for i := 1; i < len(rrs); i++ {
+		msg = &Msg{Answer: []RR{rrs[i]}}
+		data, ok = msg.PackTo(data)
+		if !ok {
+			t.Fatalf("failed to pack message %v", msg)
+		}
+	}
+	msg = new(Msg)
+	msg.Unpack(data)
+	if !reflect.DeepEqual(rrs, msg.Answer) {
+		t.Errorf("expected %v, but got %v", rrs, msg.Answer)
+	}
+}
+
 func TestDNSParseSRVReply(t *testing.T) {
 	data, err := hex.DecodeString(dnsSRVReply)
 	if err != nil {
@@ -71,9 +97,9 @@
 	msg2.Unpack(data2)
 	switch {
 	case !ok:
-		t.Errorf("failed to repack message")
+		t.Error("failed to repack message")
 	case !reflect.DeepEqual(msg, msg2):
-		t.Errorf("repacked message differs from original")
+		t.Error("repacked message differs from original")
 	}
 }
 
diff --git a/go/src/github.com/presotto/go-mdns-sd/mdns.go b/go/src/github.com/presotto/go-mdns-sd/mdns.go
index 62642e7..9a33fb4 100644
--- a/go/src/github.com/presotto/go-mdns-sd/mdns.go
+++ b/go/src/github.com/presotto/go-mdns-sd/mdns.go
@@ -144,27 +144,46 @@
 }
 
 // Send a message on a multicast net and cache it locally.
-func (m *multicastIfc) sendMessage(msg *dns.Msg) {
-	if m.mdns.logLevel >= 2 {
-		log.Printf("sending message %v\n", msg)
-	}
-	buf, ok := msg.Pack()
-	if !ok {
-		if m.mdns.logLevel >= 1 {
-			log.Printf("can't pack address message\n")
+func (m *multicastIfc) sendMessages(msgs ...*dns.Msg) {
+	var ok bool
+	var buf []byte
+	var bufs [][]byte
+	for _, msg := range msgs {
+		if m.mdns.logLevel >= 2 {
+			log.Printf("sending message %v\n", msg)
 		}
-		return
-	}
-	if _, err := m.conn.WriteTo(buf, m.addr); err != nil {
-		if m.mdns.logLevel >= 1 {
-			log.Printf("WriteTo failed %v %v", m.addr, err)
+		if buf != nil {
+			old := buf
+			buf, ok = msg.PackTo(buf)
+			if !ok {
+				// Probably due to space.
+				bufs = append(bufs, old)
+			}
+		}
+		if buf == nil {
+			buf, ok = msg.Pack()
+			if !ok {
+				if m.mdns.logLevel >= 1 {
+					log.Printf("can't pack message %v\n", msg)
+				}
+				continue
+			}
+		}
+		// Cache these RRs in case we ask about ourself.
+		for _, rr := range msg.Answer {
+			if m.cache.Add(rr) {
+				m.mdns.changedRR(rr)
+			}
 		}
 	}
-
-	// Cache these RRs in case we ask about ourself.
-	for _, rr := range msg.Answer {
-		if m.cache.Add(rr) {
-			m.mdns.changedRR(rr)
+	if buf != nil {
+		bufs = append(bufs, buf)
+	}
+	for _, buf := range bufs {
+		if _, err := m.conn.WriteTo(buf, m.addr); err != nil {
+			if m.mdns.logLevel >= 1 {
+				log.Printf("WriteTo failed %v %v", m.addr, err)
+			}
 		}
 	}
 }
@@ -173,21 +192,21 @@
 func (m *multicastIfc) announceHost(host string, ttl uint32) {
 	msg := newDnsMsg(0, true, true)
 	m.appendHostAddresses(msg, host, dns.TypeALL, ttl)
-	m.sendMessage(msg)
+	m.sendMessages(msg)
 }
 
 // Announce a service and how to reach it.
 func (m *multicastIfc) announceService(service, host string, port uint16, txt []string, ttl uint32) {
 	msg := newDnsMsg(0, true, true)
 	m.appendDiscoveryRecords(msg, service, host, port, txt, ttl)
-	m.sendMessage(msg)
+	m.sendMessages(msg)
 }
 
 // Ask a question.
 func (m *multicastIfc) sendQuestion(q []dns.Question) {
 	msg := newDnsMsg(0, false, false)
 	msg.Question = q
-	m.sendMessage(msg)
+	m.sendMessages(msg)
 }
 
 type lookupRequest struct {
@@ -712,10 +731,8 @@
 			msgs = append(msgs, s.answerTXT(m, q)...)
 		}
 	}
-	for _, msg := range msgs {
-		if len(msg.Answer) > 0 {
-			m.mifc.sendMessage(msg)
-		}
+	if len(msgs) > 0 {
+		m.mifc.sendMessages(msgs...)
 	}
 }