Merge "websocket: avoid panic from invalid websocket requests"
diff --git a/runtime/internal/lib/websocket/listener.go b/runtime/internal/lib/websocket/listener.go
index e32a116..1206e8c 100644
--- a/runtime/internal/lib/websocket/listener.go
+++ b/runtime/internal/lib/websocket/listener.go
@@ -182,11 +182,17 @@
 	}
 	ws, err := websocket.Upgrade(w, r, nil, bufferSize, bufferSize)
 	if _, ok := err.(websocket.HandshakeError); ok {
+		// Close the connection so that no further HTTP requests are served on
+		// it. Otherwise a panic from a negative httpReq counter can occur.
+		// Although Go's http.Server recovers from a panic in a handler,
+		// it is better to avoid triggering one.
+		w.Header().Set("Connection", "close")
 		http.Error(w, "Not a websocket handshake", http.StatusBadRequest)
 		logger.Global().Errorf("Rejected a non-websocket request: %v", err)
 		return
 	}
 	if err != nil {
+		w.Header().Set("Connection", "close")
 		http.Error(w, "Internal Error", http.StatusInternalServerError)
 		logger.Global().Errorf("Rejected a non-websocket request: %v", err)
 		return
diff --git a/runtime/internal/lib/websocket/listener_test.go b/runtime/internal/lib/websocket/listener_test.go
index e112714..477965a 100644
--- a/runtime/internal/lib/websocket/listener_test.go
+++ b/runtime/internal/lib/websocket/listener_test.go
@@ -7,7 +7,10 @@
 package websocket
 
 import (
+	"bytes"
+	"log"
 	"net"
+	"strings"
 	"testing"
 	"time"
 )
@@ -41,8 +44,7 @@
 		close(portscan)
 		// Keep the connection alive by blocking on a read.  (The read
 		// should return once the test exits).
-		var buf [1024]byte
-		conn.Read(buf[:])
+		conn.Read(make([]byte, 1024))
 	}()
 	// Another client that dials a legitimate connection should not be
 	// blocked on the portscanner.
@@ -54,3 +56,39 @@
 	}
 	conn.Close()
 }
+
+func TestNonWebsocketRequest(t *testing.T) {
+	ln, err := HybridListener("wsh", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() { go ln.Close() }()
+
+	// Goroutine that continuously accepts connections.
+	go func() {
+		for {
+			if _, err := ln.Accept(); err != nil {
+				return
+			}
+		}
+	}()
+
+	// A recovered handler panic is expected to show up in the std log.
+	var out bytes.Buffer
+	log.SetOutput(&out)
+
+	// Imagine some client keeps sending non-websocket requests.
+	conn, err := net.Dial("tcp", ln.Addr().String())
+	if err != nil {
+		t.Fatal(err) // Must stop here: conn is nil on error.
+	}
+	defer conn.Close()
+	for i := 0; i < 2; i++ {
+		conn.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
+		conn.Read(make([]byte, 1024))
+	}
+	logs := out.String()
+	if strings.Contains(logs, "panic") {
+		t.Errorf("Unexpected panic:\n%s", logs)
+	}
+}