| // Copyright 2013 Gary Burd. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package websocket |
| |
| import ( |
| "bufio" |
| "encoding/binary" |
| "errors" |
| "io" |
| "io/ioutil" |
| "math/rand" |
| "net" |
| "strconv" |
| "time" |
| ) |
| |
// Close status codes, as defined in RFC 6455, section 11.7.
//
// A close code is carried as a big-endian 16-bit value at the start of a
// close message payload; see FormatCloseMessage.
const (
	CloseNormalClosure           = 1000
	CloseGoingAway               = 1001
	CloseProtocolError           = 1002
	CloseUnsupportedData         = 1003
	CloseNoStatusReceived        = 1005
	CloseAbnormalClosure         = 1006
	CloseInvalidFramePayloadData = 1007
	ClosePolicyViolation         = 1008
	CloseMessageTooBig           = 1009
	CloseMandatoryExtension      = 1010
	CloseInternalServerErr       = 1011
	CloseTLSHandshake            = 1015
)
| |
// Message types. The numeric values are the frame opcodes defined in
// RFC 6455, section 11.8.
const (
	// TextMessage denotes a text data message. The payload is interpreted
	// as UTF-8 encoded text.
	TextMessage = 1

	// BinaryMessage denotes a binary data message.
	BinaryMessage = 2

	// CloseMessage denotes a close control message. The optional payload
	// holds a numeric close code followed by text; build it with
	// FormatCloseMessage.
	CloseMessage = 8

	// PingMessage denotes a ping control message. The optional payload is
	// UTF-8 encoded text.
	PingMessage = 9

	// PongMessage denotes a pong control message. The optional payload is
	// UTF-8 encoded text.
	PongMessage = 10
)
| |
// Internal frame type values that are not exposed as message types.
//
// These are constants rather than variables so that they cannot be
// reassigned at run time.
const (
	// continuationFrame is the opcode of a continuation frame, i.e. a
	// non-first frame of a fragmented message.
	continuationFrame = 0

	// noFrame is a sentinel meaning "no frame": it marks that no message is
	// being written and is returned as the frame type alongside read errors.
	noFrame = -1
)
| |
// ErrCloseSent is returned by write operations once a close message has been
// written on the connection.
var ErrCloseSent = errors.New("websocket: close sent")

// ErrReadLimit is returned when an incoming message is larger than the limit
// configured with SetReadLimit.
var ErrReadLimit = errors.New("websocket: read limit exceeded")
| |
// websocketError is an error value that additionally reports whether it is
// temporary and whether it is a timeout, through the Temporary and Timeout
// methods.
type websocketError struct {
	msg       string // text returned by Error
	temporary bool   // value returned by Temporary
	timeout   bool   // value returned by Timeout
}

// Error returns the error message.
func (e *websocketError) Error() string {
	return e.msg
}

// Temporary reports whether the error was flagged as temporary.
func (e *websocketError) Temporary() bool {
	return e.temporary
}

// Timeout reports whether the error was flagged as a timeout.
func (e *websocketError) Timeout() bool {
	return e.timeout
}
| |
| var ( |
| errWriteTimeout = &websocketError{msg: "websocket: write timeout", timeout: true} |
| errBadWriteOpCode = errors.New("websocket: bad write message type") |
| errWriteClosed = errors.New("websocket: write closed") |
| errInvalidControlFrame = errors.New("websocket: invalid control frame") |
| ) |
| |
// Framing constants.
const (
	// maxFrameHeaderSize is the largest possible frame header: 2 fixed
	// bytes, 8 bytes of extended length and a 4-byte masking key.
	maxFrameHeaderSize = 2 + 8 + 4

	// maxControlFramePayloadSize is the largest payload a control frame may
	// carry.
	maxControlFramePayloadSize = 125

	// finalBit is the FIN flag in the first header byte.
	finalBit = 0x80

	// maskBit is the MASK flag in the second header byte.
	maskBit = 0x80

	// writeWait is how long the connection waits when it writes a control
	// message on its own (close replies and the default pong).
	writeWait = time.Second
)
| |
| func isControl(frameType int) bool { |
| return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage |
| } |
| |
| func isData(frameType int) bool { |
| return frameType == TextMessage || frameType == BinaryMessage |
| } |
| |
// maskBytes XORs b in place with the 4-byte key, starting at key offset
// pos&3 and cycling through the key. It returns the key offset to use for the
// byte that follows b, so that masking can be continued across calls.
func maskBytes(key [4]byte, pos int, b []byte) int {
	for i := range b {
		b[i] ^= key[(pos+i)&3]
	}
	return (pos + len(b)) & 3
}
| |
// newMaskKey returns a fresh 4-byte masking key for a client-to-server frame.
//
// All four bytes are taken from one 32-bit random value. The last byte uses a
// shift of 24: shifting a uint32 by 32 always yields zero, which would leave
// that key byte constant.
//
// NOTE(review): the key comes from math/rand, which is not a cryptographic
// source; RFC 6455 asks for an unpredictable masking key.
func newMaskKey() [4]byte {
	n := rand.Uint32()
	return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
}
| |
// Conn represents a WebSocket connection.
//
// It wraps a net.Conn and keeps separate state for the write side and the
// read side of the protocol.
type Conn struct {
	conn        net.Conn // underlying network connection
	isServer    bool     // true for the server end; selects masking direction
	subprotocol string   // value returned by Subprotocol

	// Write lock. mu holds a single token and is used as a mutex: receiving
	// acquires it and sending releases it. It guards writes to conn and
	// closeSent.
	mu        chan bool
	closeSent bool // set once a close message has been written

	// Message writer state.
	writeErr       error     // error recorded by the last failed flush
	writeBuf       []byte    // frame under construction; header space comes first
	writePos       int       // end of buffered data in writeBuf
	writeFrameType int       // type of the frame being built, or noFrame
	writeSeq       int       // incremented to invalidate message writers
	writeDeadline  time.Time // deadline applied to flushed frames

	// Message reader state.
	readErr       error         // error recorded by the read loop
	br            *bufio.Reader // buffered reader over conn
	readRemaining int64         // unread payload bytes in the current frame
	readFinal     bool          // true if the last data frame read had the final bit set
	readSeq       int           // incremented to invalidate message readers
	readLength    int64         // payload bytes counted for the current message
	readLimit     int64         // maximum message size; enforced only when > 0
	readMaskPos   int           // offset into readMaskKey for the next payload byte
	readMaskKey   [4]byte       // masking key of the current frame
	handlePong    func(string) error
	handlePing    func(string) error
}
| |
| func newConn(conn net.Conn, isServer bool, readBufSize, writeBufSize int) *Conn { |
| mu := make(chan bool, 1) |
| mu <- true |
| |
| c := &Conn{ |
| isServer: isServer, |
| br: bufio.NewReaderSize(conn, readBufSize), |
| conn: conn, |
| mu: mu, |
| readFinal: true, |
| writeBuf: make([]byte, writeBufSize+maxFrameHeaderSize), |
| writeFrameType: noFrame, |
| writePos: maxFrameHeaderSize, |
| } |
| c.SetPingHandler(nil) |
| c.SetPongHandler(nil) |
| return c |
| } |
| |
| // Subprotocol returns the negotiated protocol for the connection. |
| func (c *Conn) Subprotocol() string { |
| return c.subprotocol |
| } |
| |
| // Close closes the underlying network connection without sending or waiting for a close frame. |
| func (c *Conn) Close() error { |
| return c.conn.Close() |
| } |
| |
| // LocalAddr returns the local network address. |
| func (c *Conn) LocalAddr() net.Addr { |
| return c.conn.LocalAddr() |
| } |
| |
| // RemoteAddr returns the remote network address. |
| func (c *Conn) RemoteAddr() net.Addr { |
| return c.conn.RemoteAddr() |
| } |
| |
| // Write methods |
| |
| func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { |
| <-c.mu |
| defer func() { c.mu <- true }() |
| |
| if c.closeSent { |
| return ErrCloseSent |
| } else if frameType == CloseMessage { |
| c.closeSent = true |
| } |
| |
| c.conn.SetWriteDeadline(deadline) |
| for _, buf := range bufs { |
| if len(buf) > 0 { |
| n, err := c.conn.Write(buf) |
| if n != len(buf) { |
| // Close on partial write. |
| c.conn.Close() |
| } |
| if err != nil { |
| return err |
| } |
| } |
| } |
| return nil |
| } |
| |
| // WriteControl writes a control message with the given deadline. The allowed |
| // message types are CloseMessage, PingMessage and PongMessage. |
| func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { |
| if !isControl(messageType) { |
| return errBadWriteOpCode |
| } |
| if len(data) > maxControlFramePayloadSize { |
| return errInvalidControlFrame |
| } |
| |
| b0 := byte(messageType) | finalBit |
| b1 := byte(len(data)) |
| if !c.isServer { |
| b1 |= maskBit |
| } |
| |
| buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) |
| buf = append(buf, b0, b1) |
| |
| if c.isServer { |
| buf = append(buf, data...) |
| } else { |
| key := newMaskKey() |
| buf = append(buf, key[:]...) |
| buf = append(buf, data...) |
| maskBytes(key, 0, buf[6:]) |
| } |
| |
| d := time.Hour * 1000 |
| if !deadline.IsZero() { |
| d = deadline.Sub(time.Now()) |
| if d < 0 { |
| return errWriteTimeout |
| } |
| } |
| |
| timer := time.NewTimer(d) |
| select { |
| case <-c.mu: |
| timer.Stop() |
| case <-timer.C: |
| return errWriteTimeout |
| } |
| defer func() { c.mu <- true }() |
| |
| if c.closeSent { |
| return ErrCloseSent |
| } else if messageType == CloseMessage { |
| c.closeSent = true |
| } |
| |
| c.conn.SetWriteDeadline(deadline) |
| n, err := c.conn.Write(buf) |
| if n != 0 && n != len(buf) { |
| c.conn.Close() |
| } |
| return err |
| } |
| |
| // NextWriter returns a writer for the next message to send. The writer's |
| // Close method flushes the complete message to the network. |
| // |
| // There can be at most one open writer on a connection. NextWriter closes the |
| // previous writer if the application has not already done so. |
| // |
| // The NextWriter method and the writers returned from the method cannot be |
| // accessed by more than one goroutine at a time. |
| func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { |
| if c.writeErr != nil { |
| return nil, c.writeErr |
| } |
| |
| if c.writeFrameType != noFrame { |
| if err := c.flushFrame(true, nil); err != nil { |
| return nil, err |
| } |
| } |
| |
| if !isControl(messageType) && !isData(messageType) { |
| return nil, errBadWriteOpCode |
| } |
| |
| c.writeFrameType = messageType |
| return messageWriter{c, c.writeSeq}, nil |
| } |
| |
| func (c *Conn) flushFrame(final bool, extra []byte) error { |
| length := c.writePos - maxFrameHeaderSize + len(extra) |
| |
| // Check for invalid control frames. |
| if isControl(c.writeFrameType) && |
| (!final || length > maxControlFramePayloadSize) { |
| c.writeSeq += 1 |
| c.writeFrameType = noFrame |
| c.writePos = maxFrameHeaderSize |
| return errInvalidControlFrame |
| } |
| |
| b0 := byte(c.writeFrameType) |
| if final { |
| b0 |= finalBit |
| } |
| b1 := byte(0) |
| if !c.isServer { |
| b1 |= maskBit |
| } |
| |
| // Assume that the frame starts at beginning of c.writeBuf. |
| framePos := 0 |
| if c.isServer { |
| // Adjust up if mask not included in the header. |
| framePos = 4 |
| } |
| |
| switch { |
| case length >= 65536: |
| c.writeBuf[framePos] = b0 |
| c.writeBuf[framePos+1] = b1 | 127 |
| binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) |
| case length > 125: |
| framePos += 6 |
| c.writeBuf[framePos] = b0 |
| c.writeBuf[framePos+1] = b1 | 126 |
| binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) |
| default: |
| framePos += 8 |
| c.writeBuf[framePos] = b0 |
| c.writeBuf[framePos+1] = b1 | byte(length) |
| } |
| |
| if !c.isServer { |
| key := newMaskKey() |
| copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) |
| maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos]) |
| if len(extra) > 0 { |
| c.writeErr = errors.New("websocket: internal error, extra used in client mode") |
| return c.writeErr |
| } |
| } |
| |
| // Write the buffers to the connection. |
| c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) |
| |
| // Setup for next frame. |
| c.writePos = maxFrameHeaderSize |
| c.writeFrameType = continuationFrame |
| if final { |
| c.writeSeq += 1 |
| c.writeFrameType = noFrame |
| } |
| return c.writeErr |
| } |
| |
| type messageWriter struct { |
| c *Conn |
| seq int |
| } |
| |
| func (w messageWriter) err() error { |
| c := w.c |
| if c.writeSeq != w.seq { |
| return errWriteClosed |
| } |
| if c.writeErr != nil { |
| return c.writeErr |
| } |
| return nil |
| } |
| |
| func (w messageWriter) ncopy(max int) (int, error) { |
| n := len(w.c.writeBuf) - w.c.writePos |
| if n <= 0 { |
| if err := w.c.flushFrame(false, nil); err != nil { |
| return 0, err |
| } |
| n = len(w.c.writeBuf) - w.c.writePos |
| } |
| if n > max { |
| n = max |
| } |
| return n, nil |
| } |
| |
| func (w messageWriter) write(final bool, p []byte) (int, error) { |
| if err := w.err(); err != nil { |
| return 0, err |
| } |
| |
| if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { |
| // Don't buffer large messages. |
| err := w.c.flushFrame(final, p) |
| if err != nil { |
| return 0, err |
| } |
| return len(p), nil |
| } |
| |
| nn := len(p) |
| for len(p) > 0 { |
| n, err := w.ncopy(len(p)) |
| if err != nil { |
| return 0, err |
| } |
| copy(w.c.writeBuf[w.c.writePos:], p[:n]) |
| w.c.writePos += n |
| p = p[n:] |
| } |
| return nn, nil |
| } |
| |
| func (w messageWriter) Write(p []byte) (int, error) { |
| return w.write(false, p) |
| } |
| |
| func (w messageWriter) WriteString(p string) (int, error) { |
| if err := w.err(); err != nil { |
| return 0, err |
| } |
| |
| nn := len(p) |
| for len(p) > 0 { |
| n, err := w.ncopy(len(p)) |
| if err != nil { |
| return 0, err |
| } |
| copy(w.c.writeBuf[w.c.writePos:], p[:n]) |
| w.c.writePos += n |
| p = p[n:] |
| } |
| return nn, nil |
| } |
| |
| func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { |
| if err := w.err(); err != nil { |
| return 0, err |
| } |
| for { |
| if w.c.writePos == len(w.c.writeBuf) { |
| err = w.c.flushFrame(false, nil) |
| if err != nil { |
| break |
| } |
| } |
| var n int |
| n, err = r.Read(w.c.writeBuf[w.c.writePos:]) |
| w.c.writePos += n |
| nn += int64(n) |
| if err != nil { |
| if err == io.EOF { |
| err = nil |
| } |
| break |
| } |
| } |
| return nn, err |
| } |
| |
| func (w messageWriter) Close() error { |
| if err := w.err(); err != nil { |
| return err |
| } |
| return w.c.flushFrame(true, nil) |
| } |
| |
| // WriteMessage is a helper method for getting a writer using NextWriter, |
| // writing the message and closing the writer. |
| func (c *Conn) WriteMessage(messageType int, data []byte) error { |
| wr, err := c.NextWriter(messageType) |
| if err != nil { |
| return err |
| } |
| w := wr.(messageWriter) |
| if _, err := w.write(true, data); err != nil { |
| return err |
| } |
| if c.writeSeq == w.seq { |
| if err := c.flushFrame(true, nil); err != nil { |
| return err |
| } |
| } |
| return nil |
| } |
| |
| // SetWriteDeadline sets the deadline for future calls to NextWriter and the |
| // io.WriteCloser returned from NextWriter. If the deadline is reached, the |
| // call will fail with a timeout instead of blocking. A zero value for t means |
| // Write will not time out. Even if Write times out, it may return n > 0, |
| // indicating that some of the data was successfully written. |
| func (c *Conn) SetWriteDeadline(t time.Time) error { |
| c.writeDeadline = t |
| return nil |
| } |
| |
| // Read methods |
| |
| func (c *Conn) advanceFrame() (int, error) { |
| |
| // 1. Skip remainder of previous frame. |
| |
| if c.readRemaining > 0 { |
| if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { |
| return noFrame, err |
| } |
| } |
| |
| // 2. Read and parse first two bytes of frame header. |
| |
| var b [8]byte |
| if err := c.read(b[:2]); err != nil { |
| return noFrame, err |
| } |
| |
| final := b[0]&finalBit != 0 |
| frameType := int(b[0] & 0xf) |
| reserved := int((b[0] >> 4) & 0x7) |
| mask := b[1]&maskBit != 0 |
| c.readRemaining = int64(b[1] & 0x7f) |
| |
| if reserved != 0 { |
| return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) |
| } |
| |
| switch frameType { |
| case CloseMessage, PingMessage, PongMessage: |
| if c.readRemaining > maxControlFramePayloadSize { |
| return noFrame, c.handleProtocolError("control frame length > 125") |
| } |
| if !final { |
| return noFrame, c.handleProtocolError("control frame not final") |
| } |
| case TextMessage, BinaryMessage: |
| if !c.readFinal { |
| return noFrame, c.handleProtocolError("message start before final message frame") |
| } |
| c.readFinal = final |
| case continuationFrame: |
| if c.readFinal { |
| return noFrame, c.handleProtocolError("continuation after final message frame") |
| } |
| c.readFinal = final |
| default: |
| return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) |
| } |
| |
| // 3. Read and parse frame length. |
| |
| switch c.readRemaining { |
| case 126: |
| if err := c.read(b[:2]); err != nil { |
| return noFrame, err |
| } |
| c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) |
| case 127: |
| if err := c.read(b[:8]); err != nil { |
| return noFrame, err |
| } |
| c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) |
| } |
| |
| // 4. Handle frame masking. |
| |
| if mask != c.isServer { |
| return noFrame, c.handleProtocolError("incorrect mask flag") |
| } |
| |
| if mask { |
| c.readMaskPos = 0 |
| if err := c.read(c.readMaskKey[:]); err != nil { |
| return noFrame, err |
| } |
| } |
| |
| // 5. For text and binary messages, enforce read limit and return. |
| |
| if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { |
| |
| c.readLength += c.readRemaining |
| if c.readLimit > 0 && c.readLength > c.readLimit { |
| c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) |
| return noFrame, ErrReadLimit |
| } |
| |
| return frameType, nil |
| } |
| |
| // 6. Read control frame payload. |
| |
| payload := make([]byte, c.readRemaining) |
| c.readRemaining = 0 |
| if err := c.read(payload); err != nil { |
| return noFrame, err |
| } |
| maskBytes(c.readMaskKey, 0, payload) |
| |
| // 7. Process control frame payload. |
| |
| switch frameType { |
| case PongMessage: |
| if err := c.handlePong(string(payload)); err != nil { |
| return noFrame, err |
| } |
| case PingMessage: |
| if err := c.handlePing(string(payload)); err != nil { |
| return noFrame, err |
| } |
| case CloseMessage: |
| c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait)) |
| if len(payload) < 2 { |
| return noFrame, io.EOF |
| } |
| closeCode := binary.BigEndian.Uint16(payload) |
| switch closeCode { |
| case CloseNormalClosure, CloseGoingAway: |
| return noFrame, io.EOF |
| default: |
| return noFrame, errors.New("websocket: close " + |
| strconv.Itoa(int(closeCode)) + " " + |
| string(payload[2:])) |
| } |
| } |
| |
| return frameType, nil |
| } |
| |
| func (c *Conn) handleProtocolError(message string) error { |
| c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) |
| return errors.New("websocket: " + message) |
| } |
| |
| func (c *Conn) read(buf []byte) error { |
| var err error |
| for len(buf) > 0 && err == nil { |
| var nn int |
| nn, err = c.br.Read(buf) |
| buf = buf[nn:] |
| } |
| if err == io.EOF { |
| if len(buf) == 0 { |
| err = nil |
| } else { |
| err = io.ErrUnexpectedEOF |
| } |
| } |
| return err |
| } |
| |
| // NextReader returns the next data message received from the peer. The |
| // returned messageType is either TextMessage or BinaryMessage. |
| // |
| // There can be at most one open reader on a connection. NextReader discards |
| // the previous message if the application has not already consumed it. |
| // |
| // The NextReader method and the readers returned from the method cannot be |
| // accessed by more than one goroutine at a time. |
| func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { |
| |
| c.readSeq += 1 |
| c.readLength = 0 |
| |
| for c.readErr == nil { |
| var frameType int |
| frameType, c.readErr = c.advanceFrame() |
| if frameType == TextMessage || frameType == BinaryMessage { |
| return frameType, messageReader{c, c.readSeq}, nil |
| } |
| } |
| return noFrame, nil, c.readErr |
| } |
| |
| type messageReader struct { |
| c *Conn |
| seq int |
| } |
| |
| func (r messageReader) Read(b []byte) (n int, err error) { |
| |
| if r.seq != r.c.readSeq { |
| return 0, io.EOF |
| } |
| |
| for r.c.readErr == nil { |
| |
| if r.c.readRemaining > 0 { |
| if int64(len(b)) > r.c.readRemaining { |
| b = b[:r.c.readRemaining] |
| } |
| r.c.readErr = r.c.read(b) |
| r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b) |
| r.c.readRemaining -= int64(len(b)) |
| return len(b), r.c.readErr |
| } |
| |
| if r.c.readFinal { |
| r.c.readSeq += 1 |
| return 0, io.EOF |
| } |
| |
| var frameType int |
| frameType, r.c.readErr = r.c.advanceFrame() |
| |
| if frameType == TextMessage || frameType == BinaryMessage { |
| r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") |
| } |
| } |
| return 0, r.c.readErr |
| } |
| |
| // ReadMessage is a helper method for getting a reader using NextReader and |
| // reading from that reader to a buffer. |
| func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { |
| var r io.Reader |
| messageType, r, err = c.NextReader() |
| if err != nil { |
| return messageType, nil, err |
| } |
| p, err = ioutil.ReadAll(r) |
| return messageType, p, err |
| } |
| |
| // SetReadDeadline sets the deadline for future calls to NextReader and the |
| // io.Reader returned from NextReader. If the deadline is reached, the call |
| // will fail with a timeout instead of blocking. A zero value for t means that |
| // the methods will not time out. |
| func (c *Conn) SetReadDeadline(t time.Time) error { |
| return c.conn.SetReadDeadline(t) |
| } |
| |
| // SetReadLimit sets the maximum size for a message read from the peer. If a |
| // message exceeds the limit, the connection sends a close frame to the peer |
| // and returns ErrReadLimit to the application. |
| func (c *Conn) SetReadLimit(limit int64) { |
| c.readLimit = limit |
| } |
| |
| // SetPingHandler sets the handler for ping messages received from the peer. |
| // The default ping handler sends a pong to the peer. |
| func (c *Conn) SetPingHandler(h func(string) error) { |
| if h == nil { |
| h = func(message string) error { |
| c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) |
| return nil |
| } |
| } |
| c.handlePing = h |
| } |
| |
| // SetPongHandler sets then handler for pong messages received from the peer. |
| // The default pong handler does nothing. |
| func (c *Conn) SetPongHandler(h func(string) error) { |
| if h == nil { |
| h = func(string) error { return nil } |
| } |
| c.handlePong = h |
| } |
| |
// FormatCloseMessage formats closeCode and text as a WebSocket close message
// payload: the code as a big-endian 16-bit value followed by the bytes of
// text.
func FormatCloseMessage(closeCode int, text string) []byte {
	payload := make([]byte, 2, 2+len(text))
	binary.BigEndian.PutUint16(payload, uint16(closeCode))
	return append(payload, text...)
}