From e1896c7eb577ca6ad3642eeb66a4a43347537b4c Mon Sep 17 00:00:00 2001 From: Zhi Wang Date: Thu, 6 Jan 2022 15:29:02 -0500 Subject: [PATCH] rewrite the ws handling according to command example in websocket lib. Basically adding timeout for receive and send --- assets/main.js | 1 + go.mod | 2 +- main.go | 208 +++++++++++++++++++++++++++++++++---------------- 3 files changed, 145 insertions(+), 66 deletions(-) diff --git a/assets/main.js b/assets/main.js index 0e737a7..4085b47 100644 --- a/assets/main.js +++ b/assets/main.js @@ -28,6 +28,7 @@ function createTerminal() { fontSize: 12, theme: baseTheme, convertEol: true, + cursorBlink: true, }); term.open(document.getElementById('terminal_view')); diff --git a/go.mod b/go.mod index 12838c7..cb43687 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/syssecfsu/wsterm +module github.com/syssecfsu/web_terminal go 1.17 diff --git a/main.go b/main.go index d935f7b..3fea16e 100644 --- a/main.go +++ b/main.go @@ -2,19 +2,37 @@ package main import ( "fmt" + "log" "net/http" "net/url" "os" "os/exec" "strings" + "time" "github.com/creack/pty" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - "golang.org/x/term" ) -func createPty(cmdline string) (*os.File, *term.State, error) { +const ( + // Time allowed to write a message to the peer. + writeWait = 5 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 30 * time.Second + + // Maximum message size allowed from peer. + maxMessageSize = 8192 + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Time to wait before force close on connection. + closeGracePeriod = 10 * time.Second +) + +func createPty(cmdline string) (*os.File, *exec.Cmd, error) { // Create a shell command. cmd := exec.Command(cmdline) @@ -33,16 +51,7 @@ func createPty(cmdline string) (*os.File, *term.State, error) { Rows: 36, }) - // Set stdin in raw mode. This might cause problems in ssh. - // ignore the error if it so happens - termState, err := term.MakeRaw(int(os.Stdin.Fd())) - - if err != nil { - fmt.Println(err) - return ptmx, nil, err - } - - return ptmx, termState, nil + return ptmx, cmd, nil } var host *string = nil @@ -59,80 +68,141 @@ var upgrader = websocket.Upgrader{ } if (host == nil) || (*host != h.Host) { - fmt.Println("failed origin check of ", org, "against", *host) + log.Println("Failed origin check of ", org) } return (host != nil) && (*host == h.Host) }, } -// handle websockets -func wsHandler(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - - if err != nil { - fmt.Println(err) - return - } - - fmt.Println("Created the websocket") - - ptmx, termState, err := createPty("bash") - - defer func() { - //close the terminal and restore the terminal state - ptmx.Close() - - if termState != nil { - term.Restore(int(os.Stdin.Fd()), termState) - } - }() - - if err != nil { - fmt.Println("failed to create PTY", err) - return - } - - // pipe the msgs from WS to pty, we need to use goroutine here - go func() { - for { - _, buf, err := conn.ReadMessage() +// Periodically send ping message to detect the status of the ws +func ping(ws *websocket.Conn, done chan struct{}) { + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + for { + select { + case <-ticker.C: + err := ws.WriteControl(websocket.PingMessage, + []byte{}, time.Now().Add(writeWait)) if err != nil { - fmt.Println(err) - // We need to close pty so the goroutine and this one can end - // using defer will cause problems - ptmx.Close() - return + log.Println("Failed to write ping message:", err) } - _, err = ptmx.Write(buf) - - if err != nil { - fmt.Println(err) - ptmx.Close() - return - } + case <-done: + log.Println("Exit ping routine as stdout is going away") + return } - }() + } +} +// shovel data from websocket to pty stdin +func toPtyStdin(ws *websocket.Conn, ptmx *os.File) { + ws.SetReadLimit(maxMessageSize) + + // set the readdeadline. The idea here is simple, + // as long as we keep receiving pong message, + // the readdeadline will keep updating. Otherwise + // read will timeout. + ws.SetReadDeadline(time.Now().Add(pongWait)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + + for { + _, buf, err := ws.ReadMessage() + + if err != nil { + log.Println("Failed to receive data from ws:", err) + break + } + + _, err = ptmx.Write(buf) + + if err != nil { + log.Println("Failed to send data to pty stdin: ", err) + break + } + } +} + +// shovel data from websocket to pty stdin +func fromPtyStdout(ws *websocket.Conn, ptmx *os.File, done chan struct{}) { readBuf := make([]byte, 4096) for { n, err := ptmx.Read(readBuf) if err != nil { - fmt.Println(err) - ptmx.Close() - return + log.Println("Failed to read from pty stdout: ", err) + break } - if err = conn.WriteMessage(websocket.BinaryMessage, readBuf[:n]); err != nil { - ptmx.Close() - fmt.Println(err) - return + ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err = ws.WriteMessage(websocket.BinaryMessage, readBuf[:n]); err != nil { + log.Println("Failed to write message: ", err) + break } } + + close(done) + + ws.SetWriteDeadline(time.Now().Add(writeWait)) + ws.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed")) + time.Sleep(closeGracePeriod) +} + +// handle websockets +func wsHandler(w http.ResponseWriter, r *http.Request) { + ws, err := upgrader.Upgrade(w, r, nil) + + if err != nil { + log.Println("Failed to create websocket: ", err) + return + } + + defer ws.Close() + + log.Println("\n\nCreated the websocket") + + ptmx, cmd, err := createPty("bash") + + if err != nil { + log.Println("Failed to create PTY: ", err) + return + } + + done := make(chan struct{}) + + go fromPtyStdout(ws, ptmx, done) + go ping(ws, done) + + toPtyStdin(ws, ptmx) + + // cleanup the pty and its related process + ptmx.Close() + proc := cmd.Process + + // send an interrupt, this will cause the shell process to + // return from syscalls if any is pending + if err := proc.Signal(os.Interrupt); err != nil { + log.Println("Failed to send Interrupt to shell process: ", err) + } + + // Wait for a second for shell process to interrupt before kill it + time.Sleep(time.Second) + + log.Printf("Try to kill the shell process") + + if err := proc.Signal(os.Kill); err != nil { + log.Println("Failed to send KILL to shell process: ", err) + } + + if _, err := proc.Wait(); err != nil { + log.Println("Failed to wait for shell process: ", err) + } } // return files @@ -144,7 +214,7 @@ func fileHandler(c *gin.Context, fname string) { } fname = fname[1:] //fname always starts with / - fmt.Println(fname) + log.Println("Sending ", fname) if strings.HasSuffix(fname, "html") { c.HTML(200, fname, nil) @@ -156,6 +226,14 @@ func fileHandler(c *gin.Context, fname string) { } func main() { + fp, err := os.OpenFile("web_term.log", os.O_RDWR|os.O_CREATE, 0644) + + if err == nil { + defer fp.Close() + log.SetOutput(fp) + gin.DefaultWriter = fp + } + rt := gin.Default() rt.SetTrustedProxies(nil)