From dd858347155aabf2b81750fa5cf143423e215c94 Mon Sep 17 00:00:00 2001 From: Zhi Wang Date: Mon, 10 Jan 2022 13:04:09 -0500 Subject: [PATCH] properly close the connection with two chans: pty and ws chans --- relay.go | 246 ++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 154 insertions(+), 92 deletions(-) diff --git a/relay.go b/relay.go index 7f9ba5a..ba27773 100644 --- a/relay.go +++ b/relay.go @@ -7,6 +7,7 @@ import ( "net/url" "os" "os/exec" + "sync" "time" "github.com/creack/pty" @@ -23,7 +24,9 @@ const ( pongWait = 10 * time.Second // Maximum message size allowed from peer. - maxMessageSize = 8192 + maxMessageSize = 4096 + readBufferSize = 1024 + WriteBufferSize = 1024 // Send pings to peer with this period. Must be less than pongWait. pingPeriod = (pongWait * 9) / 10 @@ -32,6 +35,27 @@ const ( closeGracePeriod = 10 * time.Second ) +var host *string = nil + +var upgrader = websocket.Upgrader{ + ReadBufferSize: readBufferSize, + WriteBufferSize: WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { + org := r.Header.Get("Origin") + h, err := url.Parse(org) + + if err != nil { + return false + } + + if (host == nil) || (*host != h.Host) { + log.Println("Failed origin check of ", org) + } + + return (host != nil) && (*host == h.Host) + }, +} + // TermConn represents the connected websocket and pty. // if isViewer is true type TermConn struct { @@ -39,10 +63,11 @@ type TermConn struct { name string // only valid for doers - ptmx *os.File // the pty that runs the command - cmd *exec.Cmd // represents the process, we need it to terminate the process - vchan chan *websocket.Conn // channel to receive viewers - done chan struct{} + ptmx *os.File // the pty that runs the command + cmd *exec.Cmd // represents the process, we need it to terminate the process + vchan chan *websocket.Conn // channel to receive viewers + ws_done chan struct{} // ws is closed, only close this chan in ws reader + pty_done chan struct{} // pty is closed, close this chan in pty reader } func (tc *TermConn) createPty(cmdline []string) error { @@ -71,31 +96,14 @@ func (tc *TermConn) createPty(cmdline []string) error { return nil } -var host *string = nil - -var upgrader = websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - CheckOrigin: func(r *http.Request) bool { - org := r.Header.Get("Origin") - h, err := url.Parse(org) - - if err != nil { - return false - } - - if (host == nil) || (*host != h.Host) { - log.Println("Failed origin check of ", org) - } - - return (host != nil) && (*host == h.Host) - }, -} - // Periodically send ping message to detect the status of the ws -func (tc *TermConn) ping() { +func (tc *TermConn) ping(wg *sync.WaitGroup) { + defer wg.Done() + ticker := time.NewTicker(pingPeriod) defer ticker.Stop() + +out: for { select { case <-ticker.C: @@ -104,18 +112,25 @@ func (tc *TermConn) ping() { if err != nil { log.Println("Failed to write ping message:", err) - return + break out } + case <-tc.pty_done: + log.Println("Exit ping routine as pty is going away") + break out - case <-tc.done: - log.Println("Exit ping routine as pty/ws is going away") - return + case <-tc.ws_done: + log.Println("Exit ping routine as ws is going away") + break out } } + + log.Println("Ping routine exited") } // shovel data from websocket to pty stdin -func (tc *TermConn) wsToPtyStdin() { +func (tc *TermConn) wsToPtyStdin(wg *sync.WaitGroup) { + defer wg.Done() + tc.ws.SetReadLimit(maxMessageSize) // set the readdeadline. The idea here is simple, @@ -128,75 +143,122 @@ func (tc *TermConn) wsToPtyStdin() { return nil }) - // we do not need to forward user input to viewers, only the stdout - for { - _, buf, err := tc.ws.ReadMessage() + bufChan := make(chan []byte) - if err != nil { - log.Println("Failed to receive data from ws:", err) - break + go func() { //create a goroutine to read from ws + for { + _, buf, err := tc.ws.ReadMessage() + + if err != nil { + log.Println("Failed to receive data from ws:", err) + close(bufChan) // close chan by producer + close(tc.ws_done) + break + } + + bufChan <- buf } + }() + // we do not need to forward user input to viewers, only the stdout +out: + for { + select { + case buf, ok := <-bufChan: + if !ok { + log.Println("Exit wsToPtyStdin routine pty stdin error") + break out + } + _, err := tc.ptmx.Write(buf) - _, err = tc.ptmx.Write(buf) - - if err != nil { - log.Println("Failed to send data to pty stdin: ", err) - break + if err != nil { + log.Println("Failed to send data to pty stdin: ", err) + break out + } + case <-tc.ws_done: + log.Println("Exit wsToPtyStdin routine as ws is going away") + break out + case <-tc.pty_done: + log.Println("Exit wsToPtyStdin routine as pty is going away") + break out } } + + log.Println("wsToPtyStdin routine exited") } // shovel data from pty Stdout to WS -func (tc *TermConn) ptyStdoutToWs() { +func (tc *TermConn) ptyStdoutToWs(wg *sync.WaitGroup) { var viewers []*websocket.Conn - readBuf := make([]byte, 4096) - closed := false + + defer wg.Done() + bufChan := make(chan []byte) + + go func() { //create a goroutine to read from pty + for { + readBuf := make([]byte, 1024) //pty reads in 1024 blocks + n, err := tc.ptmx.Read(readBuf) + + if err != nil { + log.Println("Failed to read from pty stdout: ", err) + close(bufChan) + close(tc.pty_done) + break + } + + readBuf = readBuf[:n] // slice the buffer so that it is exact the size of data read. + bufChan <- readBuf + } + }() out: for { - n, err := tc.ptmx.Read(readBuf) - - if err != nil { - log.Println("Failed to read from pty stdout: ", err) - break - } - // handle viewers, we want to use non-blocking receive select { + case buf, ok := <-bufChan: + if !ok { + tc.ws.SetWriteDeadline(time.Now().Add(writeWait)) + tc.ws.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed")) + + break out + } + // We could add ws to viewers as well (then we can use io.MultiWriter), + // but we want to handle errors differently + tc.ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := tc.ws.WriteMessage(websocket.BinaryMessage, buf); err != nil { + log.Println("Failed to write message: ", err) + break out + } + + //write to the viewer + for i, w := range viewers { + if w == nil { + continue + } + + // if the viewer exits, we will just ignore the error + w.SetWriteDeadline(time.Now().Add(viewWait)) + if err := w.WriteMessage(websocket.BinaryMessage, buf); err != nil { + log.Println("Failed to write message to viewer: ", err) + + viewers[i] = nil + w.Close() // we own the socket and need to close it + } + } + case viewer := <-tc.vchan: log.Println("Received viewer") viewers = append(viewers, viewer) - case <-tc.done: - log.Println("Websocket is closed by main goroutine") - closed = true + case <-tc.ws_done: + log.Println("Exit ptyStdoutToWs routine as ws is going away") break out - default: // do not block on these two channels + case <-tc.pty_done: + log.Println("Exit ptyStdoutToWs routine as pty is going away") + break out // do not block on these two channels } - // We could add ws to viewers as well (then we can use io.MultiWriter), - // but we want to handle errors differently - tc.ws.SetWriteDeadline(time.Now().Add(writeWait)) - if err = tc.ws.WriteMessage(websocket.BinaryMessage, readBuf[0:n]); err != nil { - log.Println("Failed to write message: ", err) - break - } - - for i, w := range viewers { - if w == nil { - continue - } - - // if the viewer exits, we will just ignore the error - w.SetWriteDeadline(time.Now().Add(viewWait)) - if err = w.WriteMessage(websocket.BinaryMessage, readBuf[0:n]); err != nil { - log.Println("Failed to write message to viewer: ", err) - - viewers[i] = nil - w.Close() // we own the socket and need to close it - } - } } // close the watcher @@ -206,12 +268,7 @@ out: } } - if !closed { // If the error is caused by pty, try to close the socket - tc.ws.SetWriteDeadline(time.Now().Add(writeWait)) - tc.ws.WriteMessage(websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed")) - time.Sleep(closeGracePeriod) - } + log.Println("ptyStdoutToWs routine exited") } // this function should be executed by the main goroutine for the connection @@ -247,7 +304,6 @@ func (tc *TermConn) release() { } close(tc.vchan) - close(tc.done) } tc.ws.Close() @@ -270,24 +326,30 @@ func wsHandlePlayer(w http.ResponseWriter, r *http.Request) { defer tc.release() log.Println("\n\nCreated the websocket") + tc.ws_done = make(chan struct{}) + tc.pty_done = make(chan struct{}) + tc.vchan = make(chan *websocket.Conn) + if err := tc.createPty(cmdToExec); err != nil { log.Println("Failed to create PTY: ", err) return } - tc.done = make(chan struct{}) - tc.vchan = make(chan *websocket.Conn) - registry.addPlayer("main", &tc) // main event loop to shovel data between ws and pty // do not call ptyStdoutToWs in this goroutine, otherwise // the websocket will not close. This is because ptyStdoutToWs // is usually blocked in the pty.Read - go tc.ping() - go tc.ptyStdoutToWs() + var wg sync.WaitGroup + wg.Add(3) - tc.wsToPtyStdin() + go tc.ping(&wg) + go tc.ptyStdoutToWs(&wg) + go tc.wsToPtyStdin(&wg) + + wg.Wait() + log.Println("Wait returned") } // handle websockets