properly close the connection with two chans: pty and ws chans

This commit is contained in:
Zhi Wang 2022-01-10 13:04:09 -05:00
parent 38f8f279a0
commit dd85834715

246
relay.go
View File

@ -7,6 +7,7 @@ import (
"net/url"
"os"
"os/exec"
"sync"
"time"
"github.com/creack/pty"
@ -23,7 +24,9 @@ const (
pongWait = 10 * time.Second
// Maximum message size allowed from peer.
maxMessageSize = 8192
maxMessageSize = 4096
readBufferSize = 1024
WriteBufferSize = 1024
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
@ -32,6 +35,27 @@ const (
closeGracePeriod = 10 * time.Second
)
var host *string = nil
var upgrader = websocket.Upgrader{
ReadBufferSize: readBufferSize,
WriteBufferSize: WriteBufferSize,
CheckOrigin: func(r *http.Request) bool {
org := r.Header.Get("Origin")
h, err := url.Parse(org)
if err != nil {
return false
}
if (host == nil) || (*host != h.Host) {
log.Println("Failed origin check of ", org)
}
return (host != nil) && (*host == h.Host)
},
}
// TermConn represents the connected websocket and pty.
// if isViewer is true
type TermConn struct {
@ -39,10 +63,11 @@ type TermConn struct {
name string
// only valid for doers
ptmx *os.File // the pty that runs the command
cmd *exec.Cmd // represents the process, we need it to terminate the process
vchan chan *websocket.Conn // channel to receive viewers
done chan struct{}
ptmx *os.File // the pty that runs the command
cmd *exec.Cmd // represents the process, we need it to terminate the process
vchan chan *websocket.Conn // channel to receive viewers
ws_done chan struct{} // ws is closed, only close this chan in ws reader
pty_done chan struct{} // pty is closed, close this chan in pty reader
}
func (tc *TermConn) createPty(cmdline []string) error {
@ -71,31 +96,14 @@ func (tc *TermConn) createPty(cmdline []string) error {
return nil
}
var host *string = nil
var upgrader = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
org := r.Header.Get("Origin")
h, err := url.Parse(org)
if err != nil {
return false
}
if (host == nil) || (*host != h.Host) {
log.Println("Failed origin check of ", org)
}
return (host != nil) && (*host == h.Host)
},
}
// Periodically send ping message to detect the status of the ws
func (tc *TermConn) ping() {
func (tc *TermConn) ping(wg *sync.WaitGroup) {
defer wg.Done()
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
out:
for {
select {
case <-ticker.C:
@ -104,18 +112,25 @@ func (tc *TermConn) ping() {
if err != nil {
log.Println("Failed to write ping message:", err)
return
break out
}
case <-tc.pty_done:
log.Println("Exit ping routine as pty is going away")
break out
case <-tc.done:
log.Println("Exit ping routine as pty/ws is going away")
return
case <-tc.ws_done:
log.Println("Exit ping routine as ws is going away")
break out
}
}
log.Println("Ping routine exited")
}
// shovel data from websocket to pty stdin
func (tc *TermConn) wsToPtyStdin() {
func (tc *TermConn) wsToPtyStdin(wg *sync.WaitGroup) {
defer wg.Done()
tc.ws.SetReadLimit(maxMessageSize)
// set the readdeadline. The idea here is simple,
@ -128,75 +143,122 @@ func (tc *TermConn) wsToPtyStdin() {
return nil
})
// we do not need to forward user input to viewers, only the stdout
for {
_, buf, err := tc.ws.ReadMessage()
bufChan := make(chan []byte)
if err != nil {
log.Println("Failed to receive data from ws:", err)
break
go func() { //create a goroutine to read from ws
for {
_, buf, err := tc.ws.ReadMessage()
if err != nil {
log.Println("Failed to receive data from ws:", err)
close(bufChan) // close chan by producer
close(tc.ws_done)
break
}
bufChan <- buf
}
}()
// we do not need to forward user input to viewers, only the stdout
out:
for {
select {
case buf, ok := <-bufChan:
if !ok {
log.Println("Exit wsToPtyStdin routine pty stdin error")
break out
}
_, err := tc.ptmx.Write(buf)
_, err = tc.ptmx.Write(buf)
if err != nil {
log.Println("Failed to send data to pty stdin: ", err)
break
if err != nil {
log.Println("Failed to send data to pty stdin: ", err)
break out
}
case <-tc.ws_done:
log.Println("Exit wsToPtyStdin routine as ws is going away")
break out
case <-tc.pty_done:
log.Println("Exit wsToPtyStdin routine as pty is going away")
break out
}
}
log.Println("wsToPtyStdin routine exited")
}
// shovel data from pty Stdout to WS
func (tc *TermConn) ptyStdoutToWs() {
func (tc *TermConn) ptyStdoutToWs(wg *sync.WaitGroup) {
var viewers []*websocket.Conn
readBuf := make([]byte, 4096)
closed := false
defer wg.Done()
bufChan := make(chan []byte)
go func() { //create a goroutine to read from pty
for {
readBuf := make([]byte, 1024) //pty reads in 1024 blocks
n, err := tc.ptmx.Read(readBuf)
if err != nil {
log.Println("Failed to read from pty stdout: ", err)
close(bufChan)
close(tc.pty_done)
break
}
readBuf = readBuf[:n] // slice the buffer so that it is exact the size of data read.
bufChan <- readBuf
}
}()
out:
for {
n, err := tc.ptmx.Read(readBuf)
if err != nil {
log.Println("Failed to read from pty stdout: ", err)
break
}
// handle viewers, we want to use non-blocking receive
select {
case buf, ok := <-bufChan:
if !ok {
tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
tc.ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed"))
break out
}
// We could add ws to viewers as well (then we can use io.MultiWriter),
// but we want to handle errors differently
tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
if err := tc.ws.WriteMessage(websocket.BinaryMessage, buf); err != nil {
log.Println("Failed to write message: ", err)
break out
}
//write to the viewer
for i, w := range viewers {
if w == nil {
continue
}
// if the viewer exits, we will just ignore the error
w.SetWriteDeadline(time.Now().Add(viewWait))
if err := w.WriteMessage(websocket.BinaryMessage, buf); err != nil {
log.Println("Failed to write message to viewer: ", err)
viewers[i] = nil
w.Close() // we own the socket and need to close it
}
}
case viewer := <-tc.vchan:
log.Println("Received viewer")
viewers = append(viewers, viewer)
case <-tc.done:
log.Println("Websocket is closed by main goroutine")
closed = true
case <-tc.ws_done:
log.Println("Exit ptyStdoutToWs routine as ws is going away")
break out
default: // do not block on these two channels
case <-tc.pty_done:
log.Println("Exit ptyStdoutToWs routine as pty is going away")
break out // do not block on these two channels
}
// We could add ws to viewers as well (then we can use io.MultiWriter),
// but we want to handle errors differently
tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
if err = tc.ws.WriteMessage(websocket.BinaryMessage, readBuf[0:n]); err != nil {
log.Println("Failed to write message: ", err)
break
}
for i, w := range viewers {
if w == nil {
continue
}
// if the viewer exits, we will just ignore the error
w.SetWriteDeadline(time.Now().Add(viewWait))
if err = w.WriteMessage(websocket.BinaryMessage, readBuf[0:n]); err != nil {
log.Println("Failed to write message to viewer: ", err)
viewers[i] = nil
w.Close() // we own the socket and need to close it
}
}
}
// close the watcher
@ -206,12 +268,7 @@ out:
}
}
if !closed { // If the error is caused by pty, try to close the socket
tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
tc.ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed"))
time.Sleep(closeGracePeriod)
}
log.Println("ptyStdoutToWs routine exited")
}
// this function should be executed by the main goroutine for the connection
@ -247,7 +304,6 @@ func (tc *TermConn) release() {
}
close(tc.vchan)
close(tc.done)
}
tc.ws.Close()
@ -270,24 +326,30 @@ func wsHandlePlayer(w http.ResponseWriter, r *http.Request) {
defer tc.release()
log.Println("\n\nCreated the websocket")
tc.ws_done = make(chan struct{})
tc.pty_done = make(chan struct{})
tc.vchan = make(chan *websocket.Conn)
if err := tc.createPty(cmdToExec); err != nil {
log.Println("Failed to create PTY: ", err)
return
}
tc.done = make(chan struct{})
tc.vchan = make(chan *websocket.Conn)
registry.addPlayer("main", &tc)
// main event loop to shovel data between ws and pty
// do not call ptyStdoutToWs in this goroutine, otherwise
// the websocket will not close. This is because ptyStdoutToWs
// is usually blocked in the pty.Read
go tc.ping()
go tc.ptyStdoutToWs()
var wg sync.WaitGroup
wg.Add(3)
tc.wsToPtyStdin()
go tc.ping(&wg)
go tc.ptyStdoutToWs(&wg)
go tc.wsToPtyStdin(&wg)
wg.Wait()
log.Println("Wait returned")
}
// handle websockets