change the teardown process

This commit is contained in:
Zhi Wang 2022-01-10 09:35:05 -05:00
parent 598e01c1e9
commit 38f8f279a0
2 changed files with 57 additions and 44 deletions

25
reg.go
View File

@ -12,32 +12,31 @@ import (
// design this using channels, but it is simple enough with mutex // design this using channels, but it is simple enough with mutex
type Registry struct { type Registry struct {
mtx sync.Mutex mtx sync.Mutex
doers map[string]*TermConn players map[string]*TermConn
} }
var registry Registry var registry Registry
func (reg *Registry) init() { func (reg *Registry) init() {
reg.doers = make(map[string]*TermConn) reg.players = make(map[string]*TermConn)
} }
func (d *Registry) addDoer(name string, tc *TermConn) { func (d *Registry) addPlayer(name string, tc *TermConn) {
d.mtx.Lock() d.mtx.Lock()
if val, ok := d.doers[name]; ok { if _, ok := d.players[name]; ok {
log.Println(name, "already exist in the dispatcher", val, tc) log.Println(name, "already exist in the dispatcher, skip registration")
delete(d.doers, name) } else {
val.release(false) // do not unregister in release, otherwise it is a deadlock d.players[name] = tc
} }
d.doers[name] = tc
d.mtx.Unlock() d.mtx.Unlock()
} }
func (d *Registry) delDoer(name string) error { func (d *Registry) removePlayer(name string) error {
d.mtx.Lock() d.mtx.Lock()
var err error = errors.New("not found") var err error = errors.New("not found")
if _, ok := d.doers[name]; ok { if _, ok := d.players[name]; ok {
delete(d.doers, name) delete(d.players, name)
err = nil err = nil
} }
@ -46,9 +45,9 @@ func (d *Registry) delDoer(name string) error {
} }
// we do not want to return the channel to viewer so it won't be used out of the critical section // we do not want to return the channel to viewer so it won't be used out of the critical section
func (d *Registry) sendToDoer(name string, ws *websocket.Conn) bool { func (d *Registry) sendToPlayer(name string, ws *websocket.Conn) bool {
d.mtx.Lock() d.mtx.Lock()
tc, ok := d.doers[name] tc, ok := d.players[name]
if ok { if ok {
tc.vchan <- ws tc.vchan <- ws

View File

@ -15,10 +15,12 @@ import (
const ( const (
// Time allowed to write a message to the peer. // Time allowed to write a message to the peer.
writeWait = 5 * time.Second readWait = 10 * time.Second
writeWait = 10 * time.Second
viewWait = 3 * time.Second
// Time allowed to read the next pong message from the peer. // Time allowed to read the next pong message from the peer.
pongWait = 30 * time.Second pongWait = 10 * time.Second
// Maximum message size allowed from peer. // Maximum message size allowed from peer.
maxMessageSize = 8192 maxMessageSize = 8192
@ -102,10 +104,11 @@ func (tc *TermConn) ping() {
if err != nil { if err != nil {
log.Println("Failed to write ping message:", err) log.Println("Failed to write ping message:", err)
return
} }
case <-tc.done: case <-tc.done:
log.Println("Exit ping routine as stdout is going away") log.Println("Exit ping routine as pty/ws is going away")
return return
} }
} }
@ -147,7 +150,9 @@ func (tc *TermConn) wsToPtyStdin() {
func (tc *TermConn) ptyStdoutToWs() { func (tc *TermConn) ptyStdoutToWs() {
var viewers []*websocket.Conn var viewers []*websocket.Conn
readBuf := make([]byte, 4096) readBuf := make([]byte, 4096)
closed := false
out:
for { for {
n, err := tc.ptmx.Read(readBuf) n, err := tc.ptmx.Read(readBuf)
@ -158,17 +163,22 @@ func (tc *TermConn) ptyStdoutToWs() {
// handle viewers, we want to use non-blocking receive // handle viewers, we want to use non-blocking receive
select { select {
case watcher := <-tc.vchan: case viewer := <-tc.vchan:
log.Println("Received viewer", watcher) log.Println("Received viewer")
viewers = append(viewers, watcher) viewers = append(viewers, viewer)
default:
//log.Println("no viewer received") case <-tc.done:
log.Println("Websocket is closed by main goroutine")
closed = true
break out
default: // do not block on these two channels
} }
// We could add ws to viewers as well (then we can use io.MultiWriter), // We could add ws to viewers as well (then we can use io.MultiWriter),
// but we want to handle errors differently // but we want to handle errors differently
tc.ws.SetWriteDeadline(time.Now().Add(writeWait)) tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
if tc.ws.WriteMessage(websocket.BinaryMessage, readBuf[0:n]); err != nil { if err = tc.ws.WriteMessage(websocket.BinaryMessage, readBuf[0:n]); err != nil {
log.Println("Failed to write message: ", err) log.Println("Failed to write message: ", err)
break break
} }
@ -179,11 +189,12 @@ func (tc *TermConn) ptyStdoutToWs() {
} }
// if the viewer exits, we will just ignore the error // if the viewer exits, we will just ignore the error
w.SetWriteDeadline(time.Now().Add(viewWait))
if err = w.WriteMessage(websocket.BinaryMessage, readBuf[0:n]); err != nil { if err = w.WriteMessage(websocket.BinaryMessage, readBuf[0:n]); err != nil {
log.Println("Failed to write message to watcher: ", err) log.Println("Failed to write message to viewer: ", err)
viewers[i] = nil viewers[i] = nil
w.Close() w.Close() // we own the socket and need to close it
} }
} }
} }
@ -195,20 +206,20 @@ func (tc *TermConn) ptyStdoutToWs() {
} }
} }
close(tc.done) if !closed { // If the error is caused by pty, try to close the socket
tc.ws.SetWriteDeadline(time.Now().Add(writeWait)) tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
tc.ws.WriteMessage(websocket.CloseMessage, tc.ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed")) websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed"))
time.Sleep(closeGracePeriod) time.Sleep(closeGracePeriod)
} }
func (tc *TermConn) release(unreg bool) {
log.Println("releasing", tc.name)
if unreg {
registry.delDoer(tc.name)
} }
// this function should be executed by the main goroutine for the connection
func (tc *TermConn) release() {
log.Println("Releasing terminal connection", tc.name)
registry.removePlayer(tc.name)
if tc.ptmx != nil { if tc.ptmx != nil {
// cleanup the pty and its related process // cleanup the pty and its related process
tc.ptmx.Close() tc.ptmx.Close()
@ -236,14 +247,14 @@ func (tc *TermConn) release(unreg bool) {
} }
close(tc.vchan) close(tc.vchan)
close(tc.done)
} }
tc.ws.Close() tc.ws.Close()
} }
// handle websockets // handle websockets
func wsHandleDoer(w http.ResponseWriter, r *http.Request) { func wsHandlePlayer(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil) ws, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
@ -256,7 +267,7 @@ func wsHandleDoer(w http.ResponseWriter, r *http.Request) {
name: "main", name: "main",
} }
defer tc.release(true) defer tc.release()
log.Println("\n\nCreated the websocket") log.Println("\n\nCreated the websocket")
if err := tc.createPty(cmdToExec); err != nil { if err := tc.createPty(cmdToExec); err != nil {
@ -267,13 +278,16 @@ func wsHandleDoer(w http.ResponseWriter, r *http.Request) {
tc.done = make(chan struct{}) tc.done = make(chan struct{})
tc.vchan = make(chan *websocket.Conn) tc.vchan = make(chan *websocket.Conn)
registry.addDoer("main", &tc) registry.addPlayer("main", &tc)
// main event loop to shovel data between ws and pty // main event loop to shovel data between ws and pty
// do not call ptyStdoutToWs in this goroutine, otherwise
// the websocket will not close. This is because ptyStdoutToWs
// is usually blocked in the pty.Read
go tc.ping() go tc.ping()
go tc.wsToPtyStdin() go tc.ptyStdoutToWs()
tc.ptyStdoutToWs() tc.wsToPtyStdin()
} }
// handle websockets // handle websockets
@ -286,15 +300,15 @@ func wsHandleViewer(w http.ResponseWriter, r *http.Request) {
} }
log.Println("\n\nCreated the websocket") log.Println("\n\nCreated the websocket")
if !registry.sendToDoer("main", ws) { if !registry.sendToPlayer("main", ws) {
log.Println("Failed to send websocket to doer, close it") log.Println("Failed to send websocket to player, close it")
ws.Close() ws.Close()
} }
} }
func wsHandler(w http.ResponseWriter, r *http.Request, isViewer bool) { func wsHandler(w http.ResponseWriter, r *http.Request, isViewer bool) {
if !isViewer { if !isViewer {
wsHandleDoer(w, r) wsHandlePlayer(w, r)
} else { } else {
wsHandleViewer(w, r) wsHandleViewer(w, r)
} }