properly close the connection with two chans: pty and ws chans

This commit is contained in:
Zhi Wang 2022-01-10 13:04:09 -05:00
parent 38f8f279a0
commit dd85834715

188
relay.go
View File

@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"os" "os"
"os/exec" "os/exec"
"sync"
"time" "time"
"github.com/creack/pty" "github.com/creack/pty"
@ -23,7 +24,9 @@ const (
pongWait = 10 * time.Second pongWait = 10 * time.Second
// Maximum message size allowed from peer. // Maximum message size allowed from peer.
maxMessageSize = 8192 maxMessageSize = 4096
readBufferSize = 1024
WriteBufferSize = 1024
// Send pings to peer with this period. Must be less than pongWait. // Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10 pingPeriod = (pongWait * 9) / 10
@ -32,6 +35,27 @@ const (
closeGracePeriod = 10 * time.Second closeGracePeriod = 10 * time.Second
) )
var host *string = nil
var upgrader = websocket.Upgrader{
ReadBufferSize: readBufferSize,
WriteBufferSize: WriteBufferSize,
CheckOrigin: func(r *http.Request) bool {
org := r.Header.Get("Origin")
h, err := url.Parse(org)
if err != nil {
return false
}
if (host == nil) || (*host != h.Host) {
log.Println("Failed origin check of ", org)
}
return (host != nil) && (*host == h.Host)
},
}
// TermConn represents the connected websocket and pty. // TermConn represents the connected websocket and pty.
// if isViewer is true // if isViewer is true
type TermConn struct { type TermConn struct {
@ -42,7 +66,8 @@ type TermConn struct {
ptmx *os.File // the pty that runs the command ptmx *os.File // the pty that runs the command
cmd *exec.Cmd // represents the process, we need it to terminate the process cmd *exec.Cmd // represents the process, we need it to terminate the process
vchan chan *websocket.Conn // channel to receive viewers vchan chan *websocket.Conn // channel to receive viewers
done chan struct{} ws_done chan struct{} // ws is closed, only close this chan in ws reader
pty_done chan struct{} // pty is closed, close this chan in pty reader
} }
func (tc *TermConn) createPty(cmdline []string) error { func (tc *TermConn) createPty(cmdline []string) error {
@ -71,31 +96,14 @@ func (tc *TermConn) createPty(cmdline []string) error {
return nil return nil
} }
var host *string = nil
var upgrader = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
org := r.Header.Get("Origin")
h, err := url.Parse(org)
if err != nil {
return false
}
if (host == nil) || (*host != h.Host) {
log.Println("Failed origin check of ", org)
}
return (host != nil) && (*host == h.Host)
},
}
// Periodically send ping message to detect the status of the ws // Periodically send ping message to detect the status of the ws
func (tc *TermConn) ping() { func (tc *TermConn) ping(wg *sync.WaitGroup) {
defer wg.Done()
ticker := time.NewTicker(pingPeriod) ticker := time.NewTicker(pingPeriod)
defer ticker.Stop() defer ticker.Stop()
out:
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
@ -104,18 +112,25 @@ func (tc *TermConn) ping() {
if err != nil { if err != nil {
log.Println("Failed to write ping message:", err) log.Println("Failed to write ping message:", err)
return break out
}
case <-tc.pty_done:
log.Println("Exit ping routine as pty is going away")
break out
case <-tc.ws_done:
log.Println("Exit ping routine as ws is going away")
break out
}
} }
case <-tc.done: log.Println("Ping routine exited")
log.Println("Exit ping routine as pty/ws is going away")
return
}
}
} }
// shovel data from websocket to pty stdin // shovel data from websocket to pty stdin
func (tc *TermConn) wsToPtyStdin() { func (tc *TermConn) wsToPtyStdin(wg *sync.WaitGroup) {
defer wg.Done()
tc.ws.SetReadLimit(maxMessageSize) tc.ws.SetReadLimit(maxMessageSize)
// set the readdeadline. The idea here is simple, // set the readdeadline. The idea here is simple,
@ -128,61 +143,94 @@ func (tc *TermConn) wsToPtyStdin() {
return nil return nil
}) })
// we do not need to forward user input to viewers, only the stdout bufChan := make(chan []byte)
go func() { //create a goroutine to read from ws
for { for {
_, buf, err := tc.ws.ReadMessage() _, buf, err := tc.ws.ReadMessage()
if err != nil { if err != nil {
log.Println("Failed to receive data from ws:", err) log.Println("Failed to receive data from ws:", err)
close(bufChan) // close chan by producer
close(tc.ws_done)
break break
} }
_, err = tc.ptmx.Write(buf) bufChan <- buf
}
}()
// we do not need to forward user input to viewers, only the stdout
out:
for {
select {
case buf, ok := <-bufChan:
if !ok {
log.Println("Exit wsToPtyStdin routine pty stdin error")
break out
}
_, err := tc.ptmx.Write(buf)
if err != nil { if err != nil {
log.Println("Failed to send data to pty stdin: ", err) log.Println("Failed to send data to pty stdin: ", err)
break break out
}
case <-tc.ws_done:
log.Println("Exit wsToPtyStdin routine as ws is going away")
break out
case <-tc.pty_done:
log.Println("Exit wsToPtyStdin routine as pty is going away")
break out
} }
} }
log.Println("wsToPtyStdin routine exited")
} }
// shovel data from pty Stdout to WS // shovel data from pty Stdout to WS
func (tc *TermConn) ptyStdoutToWs() { func (tc *TermConn) ptyStdoutToWs(wg *sync.WaitGroup) {
var viewers []*websocket.Conn var viewers []*websocket.Conn
readBuf := make([]byte, 4096)
closed := false
out: defer wg.Done()
bufChan := make(chan []byte)
go func() { //create a goroutine to read from pty
for { for {
readBuf := make([]byte, 1024) //pty reads in 1024 blocks
n, err := tc.ptmx.Read(readBuf) n, err := tc.ptmx.Read(readBuf)
if err != nil { if err != nil {
log.Println("Failed to read from pty stdout: ", err) log.Println("Failed to read from pty stdout: ", err)
close(bufChan)
close(tc.pty_done)
break break
} }
readBuf = readBuf[:n] // slice the buffer so that it is exact the size of data read.
bufChan <- readBuf
}
}()
out:
for {
// handle viewers, we want to use non-blocking receive // handle viewers, we want to use non-blocking receive
select { select {
case viewer := <-tc.vchan: case buf, ok := <-bufChan:
log.Println("Received viewer") if !ok {
viewers = append(viewers, viewer) tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
tc.ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed"))
case <-tc.done:
log.Println("Websocket is closed by main goroutine")
closed = true
break out break out
default: // do not block on these two channels
} }
// We could add ws to viewers as well (then we can use io.MultiWriter), // We could add ws to viewers as well (then we can use io.MultiWriter),
// but we want to handle errors differently // but we want to handle errors differently
tc.ws.SetWriteDeadline(time.Now().Add(writeWait)) tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
if err = tc.ws.WriteMessage(websocket.BinaryMessage, readBuf[0:n]); err != nil { if err := tc.ws.WriteMessage(websocket.BinaryMessage, buf); err != nil {
log.Println("Failed to write message: ", err) log.Println("Failed to write message: ", err)
break break out
} }
//write to the viewer
for i, w := range viewers { for i, w := range viewers {
if w == nil { if w == nil {
continue continue
@ -190,13 +238,27 @@ out:
// if the viewer exits, we will just ignore the error // if the viewer exits, we will just ignore the error
w.SetWriteDeadline(time.Now().Add(viewWait)) w.SetWriteDeadline(time.Now().Add(viewWait))
if err = w.WriteMessage(websocket.BinaryMessage, readBuf[0:n]); err != nil { if err := w.WriteMessage(websocket.BinaryMessage, buf); err != nil {
log.Println("Failed to write message to viewer: ", err) log.Println("Failed to write message to viewer: ", err)
viewers[i] = nil viewers[i] = nil
w.Close() // we own the socket and need to close it w.Close() // we own the socket and need to close it
} }
} }
case viewer := <-tc.vchan:
log.Println("Received viewer")
viewers = append(viewers, viewer)
case <-tc.ws_done:
log.Println("Exit ptyStdoutToWs routine as ws is going away")
break out
case <-tc.pty_done:
log.Println("Exit ptyStdoutToWs routine as pty is going away")
break out // do not block on these two channels
}
} }
// close the watcher // close the watcher
@ -206,12 +268,7 @@ out:
} }
} }
if !closed { // If the error is caused by pty, try to close the socket log.Println("ptyStdoutToWs routine exited")
tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
tc.ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed"))
time.Sleep(closeGracePeriod)
}
} }
// this function should be executed by the main goroutine for the connection // this function should be executed by the main goroutine for the connection
@ -247,7 +304,6 @@ func (tc *TermConn) release() {
} }
close(tc.vchan) close(tc.vchan)
close(tc.done)
} }
tc.ws.Close() tc.ws.Close()
@ -270,24 +326,30 @@ func wsHandlePlayer(w http.ResponseWriter, r *http.Request) {
defer tc.release() defer tc.release()
log.Println("\n\nCreated the websocket") log.Println("\n\nCreated the websocket")
tc.ws_done = make(chan struct{})
tc.pty_done = make(chan struct{})
tc.vchan = make(chan *websocket.Conn)
if err := tc.createPty(cmdToExec); err != nil { if err := tc.createPty(cmdToExec); err != nil {
log.Println("Failed to create PTY: ", err) log.Println("Failed to create PTY: ", err)
return return
} }
tc.done = make(chan struct{})
tc.vchan = make(chan *websocket.Conn)
registry.addPlayer("main", &tc) registry.addPlayer("main", &tc)
// main event loop to shovel data between ws and pty // main event loop to shovel data between ws and pty
// do not call ptyStdoutToWs in this goroutine, otherwise // do not call ptyStdoutToWs in this goroutine, otherwise
// the websocket will not close. This is because ptyStdoutToWs // the websocket will not close. This is because ptyStdoutToWs
// is usually blocked in the pty.Read // is usually blocked in the pty.Read
go tc.ping() var wg sync.WaitGroup
go tc.ptyStdoutToWs() wg.Add(3)
tc.wsToPtyStdin() go tc.ping(&wg)
go tc.ptyStdoutToWs(&wg)
go tc.wsToPtyStdin(&wg)
wg.Wait()
log.Println("Wait returned")
} }
// handle websockets // handle websockets