diff --git a/irc/ircconn.go b/irc/ircconn.go index 088909a2..baeeb5cf 100644 --- a/irc/ircconn.go +++ b/irc/ircconn.go @@ -5,6 +5,7 @@ package irc import ( "bytes" + "io" "net" "unicode/utf8" @@ -93,21 +94,25 @@ func (cc *IRCStreamConn) Close() (err error) { // IRCWSConn is an IRCConn over a websocket. type IRCWSConn struct { conn *websocket.Conn + buf []byte binary bool } -func NewIRCWSConn(conn *websocket.Conn) IRCWSConn { - binary := conn.Subprotocol() == "binary.ircv3.net" - return IRCWSConn{conn: conn, binary: binary} +func NewIRCWSConn(conn *websocket.Conn) *IRCWSConn { + return &IRCWSConn{ + conn: conn, + binary: conn.Subprotocol() == "binary.ircv3.net", + buf: make([]byte, initialBufferSize), + } } -func (wc IRCWSConn) UnderlyingConn() *utils.WrappedConn { +func (wc *IRCWSConn) UnderlyingConn() *utils.WrappedConn { // just assume that the type is OK wConn, _ := wc.conn.UnderlyingConn().(*utils.WrappedConn) return wConn } -func (wc IRCWSConn) WriteLine(buf []byte) (err error) { +func (wc *IRCWSConn) WriteLine(buf []byte) (err error) { buf = bytes.TrimSuffix(buf, crlf) // #1483: if we have websockets at all, then we're enforcing utf8 messageType := websocket.TextMessage @@ -117,7 +122,7 @@ func (wc IRCWSConn) WriteLine(buf []byte) (err error) { return wc.conn.WriteMessage(messageType, buf) } -func (wc IRCWSConn) WriteLines(buffers [][]byte) (err error) { +func (wc *IRCWSConn) WriteLines(buffers [][]byte) (err error) { for _, buf := range buffers { err = wc.WriteLine(buf) if err != nil { @@ -127,20 +132,47 @@ func (wc IRCWSConn) WriteLines(buffers [][]byte) (err error) { return } -func (wc IRCWSConn) ReadLine() (line []byte, err error) { - messageType, line, err := wc.conn.ReadMessage() - if err == nil { - if messageType == websocket.BinaryMessage && !utf8.Valid(line) { +func (wc *IRCWSConn) ReadLine() (line []byte, err error) { + _, reader, err := wc.conn.NextReader() + switch err { + case nil: + // OK + case websocket.ErrReadLimit: + return line, ircreader.ErrReadQ + default: + return line, err + } + + line, err = wc.readFull(reader) + switch err { + case io.ErrUnexpectedEOF, io.EOF: + // these are OK. io.ErrUnexpectedEOF is the good case: + // it means we read the full message and it consumed less than the full wc.buf + if !utf8.Valid(line) { return line, errInvalidUtf8 } return line, nil - } else if err == websocket.ErrReadLimit { + case nil, websocket.ErrReadLimit: + // nil means we filled wc.buf without exhausting the reader: return line, ircreader.ErrReadQ - } else { + default: return line, err } } -func (wc IRCWSConn) Close() (err error) { +func (wc *IRCWSConn) readFull(reader io.Reader) (line []byte, err error) { + // XXX this is io.ReadFull with a single attempt to resize upwards + n, err := io.ReadFull(reader, wc.buf) + if err == nil && len(wc.buf) < maxReadQBytes() { + newBuf := make([]byte, maxReadQBytes()) + copy(newBuf, wc.buf[:n]) + wc.buf = newBuf + n2, err := io.ReadFull(reader, wc.buf[n:]) + return wc.buf[:n+n2], err + } + return wc.buf[:n], err +} + +func (wc *IRCWSConn) Close() (err error) { return wc.conn.Close() }