diff --git a/main.go b/main.go index 8db5381..b7278a8 100644 --- a/main.go +++ b/main.go @@ -3,14 +3,34 @@ package main import ( "log" "net/http" + "net/url" "os" "github.com/gin-gonic/gin" + "github.com/syssecfsu/web_terminal/term_conn" ) // command line options var cmdToExec = []string{"bash"} +var host *string = nil + +// simple function to check origin +func checkOrigin(r *http.Request) bool { + org := r.Header.Get("Origin") + h, err := url.Parse(org) + + if err != nil { + return false + } + + if (host == nil) || (*host != h.Host) { + log.Println("Failed origin check of ", org) + } + + return (host != nil) && (*host == h.Host) +} + func main() { fp, err := os.OpenFile("web_term.log", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) @@ -31,8 +51,6 @@ func main() { log.Println(cmdToExec) } - registry.init() - rt := gin.Default() rt.SetTrustedProxies(nil) @@ -40,17 +58,17 @@ func main() { rt.GET("/view/*sname", func(c *gin.Context) { c.HTML(http.StatusOK, "index.html", gin.H{ - "title": "Watcher terminal", + "title": "Viewer terminal", "path": "/ws_view", }) }) rt.GET("/ws_do", func(c *gin.Context) { - wsHandler(c.Writer, c.Request, false) + term_conn.ConnectTerm(c.Writer, c.Request, false, cmdToExec) }) rt.GET("/ws_view", func(c *gin.Context) { - wsHandler(c.Writer, c.Request, true) + term_conn.ConnectTerm(c.Writer, c.Request, true, nil) }) // handle static files @@ -58,11 +76,14 @@ func main() { rt.GET("/", func(c *gin.Context) { c.HTML(http.StatusOK, "index.html", gin.H{ - "title": "Master terminal", + "title": "Interactive terminal", "path": "/ws_do", }) + host = &c.Request.Host }) + term_conn.Init(checkOrigin) + rt.RunTLS(":8080", "./tls/cert.pem", "./tls/private-key.pem") } diff --git a/reg.go b/term_conn/reg.go similarity index 98% rename from reg.go rename to term_conn/reg.go index 654da24..15cee7c 100644 --- a/reg.go +++ b/term_conn/reg.go @@ -1,4 +1,4 @@ -package main +package term_conn import ( "errors" diff --git a/relay.go b/term_conn/relay.go similarity index 93% rename from relay.go rename to term_conn/relay.go index ba27773..470cd19 100644 --- a/relay.go +++ b/term_conn/relay.go @@ -1,10 +1,9 @@ //This file contains code to relay traffic between websocket and pty -package main +package term_conn import ( "log" "net/http" - "net/url" "os" "os/exec" "sync" @@ -35,24 +34,11 @@ const ( closeGracePeriod = 10 * time.Second ) -var host *string = nil - var upgrader = websocket.Upgrader{ ReadBufferSize: readBufferSize, WriteBufferSize: WriteBufferSize, CheckOrigin: func(r *http.Request) bool { - org := r.Header.Get("Origin") - h, err := url.Parse(org) - - if err != nil { - return false - } - - if (host == nil) || (*host != h.Host) { - log.Println("Failed origin check of ", org) - } - - return (host != nil) && (*host == h.Host) + return true }, } @@ -310,7 +296,7 @@ func (tc *TermConn) release() { } // handle websockets -func wsHandlePlayer(w http.ResponseWriter, r *http.Request) { +func handlePlayer(w http.ResponseWriter, r *http.Request, cmdline []string) { ws, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -330,7 +316,7 @@ func wsHandlePlayer(w http.ResponseWriter, r *http.Request) { tc.pty_done = make(chan struct{}) tc.vchan = make(chan *websocket.Conn) - if err := tc.createPty(cmdToExec); err != nil { + if err := tc.createPty(cmdline); err != nil { log.Println("Failed to create PTY: ", err) return } @@ -353,7 +339,7 @@ func wsHandlePlayer(w http.ResponseWriter, r *http.Request) { } // handle websockets -func wsHandleViewer(w http.ResponseWriter, r *http.Request) { +func handleViewer(w http.ResponseWriter, r *http.Request) { ws, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -368,10 +354,15 @@ func wsHandleViewer(w http.ResponseWriter, r *http.Request) { } } -func wsHandler(w http.ResponseWriter, r *http.Request, isViewer bool) { +func ConnectTerm(w http.ResponseWriter, r *http.Request, isViewer bool, cmdline []string) { if !isViewer { - wsHandlePlayer(w, r) + handlePlayer(w, r, cmdline) } else { - wsHandleViewer(w, r) + handleViewer(w, r) } } + +func Init(checkOrigin func(r *http.Request) bool) { + upgrader.CheckOrigin = checkOrigin + registry.init() +}