Merge "websocket: avoid panic from invalid websocket requests"
diff --git a/runtime/internal/lib/websocket/listener.go b/runtime/internal/lib/websocket/listener.go
index e32a116..1206e8c 100644
--- a/runtime/internal/lib/websocket/listener.go
+++ b/runtime/internal/lib/websocket/listener.go
@@ -182,11 +182,17 @@
}
ws, err := websocket.Upgrade(w, r, nil, bufferSize, bufferSize)
if _, ok := err.(websocket.HandshakeError); ok {
+		// Close the connection so that no further HTTP requests are served on
+		// it; otherwise a panic from a negative httpReq counter can occur.
+		// Go's http.Server recovers handler panics (logging them and closing
+		// the connection), but it is better to avoid the panic altogether.
+ w.Header().Set("Connection", "close")
http.Error(w, "Not a websocket handshake", http.StatusBadRequest)
logger.Global().Errorf("Rejected a non-websocket request: %v", err)
return
}
if err != nil {
+ w.Header().Set("Connection", "close")
http.Error(w, "Internal Error", http.StatusInternalServerError)
logger.Global().Errorf("Rejected a non-websocket request: %v", err)
return
diff --git a/runtime/internal/lib/websocket/listener_test.go b/runtime/internal/lib/websocket/listener_test.go
index e112714..477965a 100644
--- a/runtime/internal/lib/websocket/listener_test.go
+++ b/runtime/internal/lib/websocket/listener_test.go
@@ -7,7 +7,10 @@
package websocket
import (
+ "bytes"
+ "log"
"net"
+ "strings"
"testing"
"time"
)
@@ -41,8 +44,7 @@
close(portscan)
// Keep the connection alive by blocking on a read. (The read
// should return once the test exits).
- var buf [1024]byte
- conn.Read(buf[:])
+ conn.Read(make([]byte, 1024))
}()
// Another client that dials a legitimate connection should not be
// blocked on the portscanner.
@@ -54,3 +56,39 @@
}
conn.Close()
}
+
+func TestNonWebsocketRequest(t *testing.T) {
+	// Regression test: non-websocket requests reusing one connection could
+	// drive the httpReq counter negative and panic. net/http recovers and logs
+	// handler panics (presumably via the standard logger), so check its output.
+	ln, err := HybridListener("wsh", "127.0.0.1:0")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() { go ln.Close() }()
+	// Accept (and drop) connections until the listener is closed.
+	go func() {
+		for {
+			if _, err := ln.Accept(); err != nil {
+				return
+			}
+		}
+	}()
+	var out bytes.Buffer
+	log.SetOutput(&out) // NOTE(review): never restored for later tests.
+	conn, err := net.Dial("tcp", ln.Addr().String())
+	if err != nil {
+		t.Fatal(err) // conn is nil; carrying on would nil-deref.
+	}
+	defer conn.Close()
+	conn.SetDeadline(time.Now().Add(10 * time.Second)) // don't hang forever
+	// Keep sending non-websocket requests; I/O errors are expected once the
+	// server closes the connection, so they are deliberately ignored.
+	for i := 0; i < 2; i++ {
+		conn.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
+		conn.Read(make([]byte, 1024))
+	}
+	if logs := out.String(); strings.Contains(logs, "panic") {
+		t.Errorf("Unexpected panic:\n%s", logs)
+	}
+}