rewrite the ws handling according to command example in websocket lib. Basically adding timeout for receive and send

This commit is contained in:
Zhi Wang 2022-01-06 15:29:02 -05:00
parent 43bff23d6c
commit e1896c7eb5
3 changed files with 145 additions and 66 deletions

View File

@ -28,6 +28,7 @@ function createTerminal() {
fontSize: 12, fontSize: 12,
theme: baseTheme, theme: baseTheme,
convertEol: true, convertEol: true,
cursorBlink: true,
}); });
term.open(document.getElementById('terminal_view')); term.open(document.getElementById('terminal_view'));

2
go.mod
View File

@ -1,4 +1,4 @@
module github.com/syssecfsu/wsterm module github.com/syssecfsu/web_terminal
go 1.17 go 1.17

192
main.go
View File

@ -2,19 +2,37 @@ package main
import ( import (
"fmt" "fmt"
"log"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"time"
"github.com/creack/pty" "github.com/creack/pty"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"golang.org/x/term"
) )
func createPty(cmdline string) (*os.File, *term.State, error) { const (
// Time allowed to write a message to the peer.
writeWait = 5 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 30 * time.Second
// Maximum message size allowed from peer.
maxMessageSize = 8192
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Time to wait before force close on connection.
closeGracePeriod = 10 * time.Second
)
func createPty(cmdline string) (*os.File, *exec.Cmd, error) {
// Create a shell command. // Create a shell command.
cmd := exec.Command(cmdline) cmd := exec.Command(cmdline)
@ -33,16 +51,7 @@ func createPty(cmdline string) (*os.File, *term.State, error) {
Rows: 36, Rows: 36,
}) })
// Set stdin in raw mode. This might cause problems in ssh. return ptmx, cmd, nil
// ignore the error if it so happens
termState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
fmt.Println(err)
return ptmx, nil, err
}
return ptmx, termState, nil
} }
var host *string = nil var host *string = nil
@ -59,79 +68,140 @@ var upgrader = websocket.Upgrader{
} }
if (host == nil) || (*host != h.Host) { if (host == nil) || (*host != h.Host) {
fmt.Println("failed origin check of ", org, "against", *host) log.Println("Failed origin check of ", org)
} }
return (host != nil) && (*host == h.Host) return (host != nil) && (*host == h.Host)
}, },
} }
// handle websockets // Periodically send ping message to detect the status of the ws
func wsHandler(w http.ResponseWriter, r *http.Request) { func ping(ws *websocket.Conn, done chan struct{}) {
conn, err := upgrader.Upgrade(w, r, nil) ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
if err != nil {
fmt.Println(err)
return
}
fmt.Println("Created the websocket")
ptmx, termState, err := createPty("bash")
defer func() {
//close the terminal and restore the terminal state
ptmx.Close()
if termState != nil {
term.Restore(int(os.Stdin.Fd()), termState)
}
}()
if err != nil {
fmt.Println("failed to create PTY", err)
return
}
// pipe the msgs from WS to pty, we need to use goroutine here
go func() {
for { for {
_, buf, err := conn.ReadMessage() select {
case <-ticker.C:
err := ws.WriteControl(websocket.PingMessage,
[]byte{}, time.Now().Add(writeWait))
if err != nil { if err != nil {
fmt.Println(err) log.Println("Failed to write ping message:", err)
// We need to close pty so the goroutine and this one can end }
// using defer will cause problems
ptmx.Close() case <-done:
log.Println("Exit ping routine as stdout is going away")
return return
} }
}
}
// shovel data from websocket to pty stdin
func toPtyStdin(ws *websocket.Conn, ptmx *os.File) {
ws.SetReadLimit(maxMessageSize)
// set the readdeadline. The idea here is simple,
// as long as we keep receiving pong message,
// the readdeadline will keep updating. Otherwise
// read will timeout.
ws.SetReadDeadline(time.Now().Add(pongWait))
ws.SetPongHandler(func(string) error {
ws.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for {
_, buf, err := ws.ReadMessage()
if err != nil {
log.Println("Failed to receive data from ws:", err)
break
}
_, err = ptmx.Write(buf) _, err = ptmx.Write(buf)
if err != nil { if err != nil {
fmt.Println(err) log.Println("Failed to send data to pty stdin: ", err)
ptmx.Close() break
return
} }
} }
}() }
// shovel data from websocket to pty stdin
func fromPtyStdout(ws *websocket.Conn, ptmx *os.File, done chan struct{}) {
readBuf := make([]byte, 4096) readBuf := make([]byte, 4096)
for { for {
n, err := ptmx.Read(readBuf) n, err := ptmx.Read(readBuf)
if err != nil { if err != nil {
fmt.Println(err) log.Println("Failed to read from pty stdout: ", err)
ptmx.Close() break
}
ws.SetWriteDeadline(time.Now().Add(writeWait))
if err = ws.WriteMessage(websocket.BinaryMessage, readBuf[:n]); err != nil {
log.Println("Failed to write message: ", err)
break
}
}
close(done)
ws.SetWriteDeadline(time.Now().Add(writeWait))
ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed"))
time.Sleep(closeGracePeriod)
}
// handle websockets
func wsHandler(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println("Failed to create websocket: ", err)
return return
} }
if err = conn.WriteMessage(websocket.BinaryMessage, readBuf[:n]); err != nil { defer ws.Close()
ptmx.Close()
fmt.Println(err) log.Println("\n\nCreated the websocket")
ptmx, cmd, err := createPty("bash")
if err != nil {
log.Println("Failed to create PTY: ", err)
return return
} }
done := make(chan struct{})
go fromPtyStdout(ws, ptmx, done)
go ping(ws, done)
toPtyStdin(ws, ptmx)
// cleanup the pty and its related process
ptmx.Close()
proc := cmd.Process
// send an interrupt, this will cause the shell process to
// return from syscalls if any is pending
if err := proc.Signal(os.Interrupt); err != nil {
log.Println("Failed to send Interrupt to shell process: ", err)
}
// Wait for a second for shell process to interrupt before kill it
time.Sleep(time.Second)
log.Printf("Try to kill the shell process")
if err := proc.Signal(os.Kill); err != nil {
log.Println("Failed to send KILL to shell process: ", err)
}
if _, err := proc.Wait(); err != nil {
log.Println("Failed to wait for shell process: ", err)
} }
} }
@ -144,7 +214,7 @@ func fileHandler(c *gin.Context, fname string) {
} }
fname = fname[1:] //fname always starts with / fname = fname[1:] //fname always starts with /
fmt.Println(fname) log.Println("Sending ", fname)
if strings.HasSuffix(fname, "html") { if strings.HasSuffix(fname, "html") {
c.HTML(200, fname, nil) c.HTML(200, fname, nil)
@ -156,6 +226,14 @@ func fileHandler(c *gin.Context, fname string) {
} }
func main() { func main() {
fp, err := os.OpenFile("web_term.log", os.O_RDWR|os.O_CREATE, 0644)
if err == nil {
defer fp.Close()
log.SetOutput(fp)
gin.DefaultWriter = fp
}
rt := gin.Default() rt := gin.Default()
rt.SetTrustedProxies(nil) rt.SetTrustedProxies(nil)