//This file contains code to relay traffic between websocket and pty
package term_conn

import (
	"encoding/json"
	"log"
	"net/http"
	"os"
	"os/exec"
	"strconv"
	"sync"
	"time"

	"github.com/creack/pty"
	"github.com/gorilla/websocket"
)

// Websocket deadlines.
const (
	// readWait is not referenced by the code visible in this file.
	readWait = time.Second * 10

	// writeWait bounds every write (data, ping, close) to the player's
	// websocket.
	writeWait = time.Second * 10

	// viewWait bounds a write to a viewer's websocket.
	viewWait = time.Second * 3

	// pongWait is how long a read may wait for the next message; the
	// deadline is pushed out again each time a pong arrives.
	pongWait = time.Second * 10

	// pingPeriod is the interval between pings sent to the peer. It must
	// stay below pongWait so a healthy peer always answers in time.
	pingPeriod = pongWait * 9 / 10
)

// Websocket sizes, in bytes.
const (
	// maxMessageSize is the read limit applied to the player's websocket.
	maxMessageSize = 4096

	// readBufferSize and WriteBufferSize configure the upgrader's buffers.
	readBufferSize  = 1024
	WriteBufferSize = 1024
)

// Commands carried on TermConn.recordChan.
const (
	stopCmd   = iota // 0: finish the current recording
	recordCmd        // 1: start a new recording
)

// checkOrigin reports whether the request's Origin header is exactly
// "https://" followed by the request's Host. The comparison is a plain,
// case-sensitive string match; any other origin (including a missing
// header or a non-https scheme) is logged and rejected.
func checkOrigin(r *http.Request) bool {
	origin := r.Header.Get("Origin")
	if origin == "https://"+r.Host {
		return true
	}

	log.Println("Failed origin check of ", origin, r.Host)
	return false
}

// upgrader is the shared websocket upgrader used for both player and
// viewer connections. It uses the package's buffer-size constants and
// checkOrigin as its origin check.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  readBufferSize,
	WriteBufferSize: WriteBufferSize,
	CheckOrigin:     checkOrigin,
}

// TermConn represents one player session: the connected websocket and the
// pty running the player's command. It is built by handlePlayer. Viewers do
// not get a TermConn of their own; their websockets arrive on viewChan and
// receive a copy of the pty output.
type TermConn struct {
	// Name identifies the session. It is the key passed to
	// registry.removePlayer and is part of the record file name.
	Name string
	// Ip is the player's remote address as reported by the websocket.
	Ip   string

	ws          *websocket.Conn
	ptmx        *os.File             // the pty that runs the command
	record      *os.File             // record session
	lastRecTime time.Time            // last time a record is written
	cmd         *exec.Cmd            // represents the process, we need it to terminate the process
	viewChan    chan *websocket.Conn // channel to receive viewers
	recordChan  chan int             // channel to start/stop recording
	ws_done     chan struct{}        // ws is closed, only close this chan in ws reader
	pty_done    chan struct{}        // pty is closed, close this chan in pty reader
}

// WriteRecord is one entry of a recorded session. Data is a chunk of pty
// output and Dur is the time elapsed since the previous entry was written.
// With encoding/json, Dur is emitted as an integer number of nanoseconds
// under "Duration" and Data as a base64 string under "Data".
type WriteRecord struct {
	Dur  time.Duration `json:"Duration"`
	Data []byte        `json:"Data"`
}

// createPty starts cmdline (program name followed by its arguments) attached
// to a new pseudo-terminal and stores the pty master and the command in tc.
//
// It returns an error when cmdline is empty or when the command cannot be
// started; in both cases tc is left untouched. Failing to set the initial
// window size is logged but does not abort the session.
func (tc *TermConn) createPty(cmdline []string) error {
	if len(cmdline) == 0 {
		// Indexing cmdline[0] would panic. Report the same error os/exec
		// produces for an empty program name instead.
		return &exec.Error{Name: "", Err: exec.ErrNotFound}
	}

	// Create a shell command.
	cmd := exec.Command(cmdline[0], cmdline[1:]...)

	// Start the command with a pty.
	ptmx, err := pty.Start(cmd)
	if err != nil {
		return err
	}

	// Use a fixed size: the xterm is initialized as 122x37, but the pty is
	// set to 120x36 because using the full size makes some programs
	// misbehave.
	if err := pty.Setsize(ptmx, &pty.Winsize{Cols: 120, Rows: 36}); err != nil {
		log.Printf("Failed to set pty size for %v: %v", cmdline, err)
	}

	tc.ptmx = ptmx
	tc.cmd = cmd

	log.Printf("Create shell process %v (%v)", cmdline, cmd.Process.Pid)
	return nil
}

// ping sends a websocket ping every pingPeriod so that a dead peer is
// noticed. It returns when a ping cannot be written or when either the
// websocket reader or the pty reader has signalled that it is done, and
// marks wg as done on the way out.
func (tc *TermConn) ping(wg *sync.WaitGroup) {
	defer wg.Done()

	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()

	for alive := true; alive; {
		select {
		case <-tc.ws_done:
			log.Println("Exit ping routine as ws is going away")
			alive = false

		case <-tc.pty_done:
			log.Println("Exit ping routine as pty is going away")
			alive = false

		case <-ticker.C:
			deadline := time.Now().Add(writeWait)
			if err := tc.ws.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
				log.Println("Failed to write ping message:", err)
				alive = false
			}
		}
	}

	log.Println("Ping routine exited")
}

// wsToPtyStdin shovels data from the websocket to the pty's stdin.
//
// A helper goroutine blocks in ws.ReadMessage and hands each message to the
// loop below, which writes it to the pty. The loop ends when the websocket
// reader fails, when a write to the pty fails, or when ws_done or pty_done
// is closed. wg is marked done on return.
//
// The helper goroutine is the only place that closes ws_done, and it does so
// only when reading from the websocket fails. If the loop ends first, the
// helper is told to stop and exits without closing ws_done.
func (tc *TermConn) wsToPtyStdin(wg *sync.WaitGroup) {
	defer wg.Done()

	tc.ws.SetReadLimit(maxMessageSize)

	// Set the read deadline. As long as pong messages keep arriving the
	// deadline keeps moving forward; otherwise the read times out.
	tc.ws.SetReadDeadline(time.Now().Add(pongWait))
	tc.ws.SetPongHandler(func(string) error {
		tc.ws.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	bufChan := make(chan []byte)

	// stop is closed when this function returns so the reader goroutine
	// never stays blocked on the unbuffered bufChan with nobody receiving.
	stop := make(chan struct{})
	defer close(stop)

	go func() { // reads from ws
		for {
			_, buf, err := tc.ws.ReadMessage()
			if err != nil {
				log.Println("Failed to receive data from ws:", err)
				close(bufChan) // close chan by producer
				close(tc.ws_done)
				return
			}

			select {
			case bufChan <- buf:
			case <-stop:
				return
			}
		}
	}()

	// User input is not forwarded to viewers, only the pty output is.
out:
	for {
		select {
		case buf, ok := <-bufChan:
			if !ok {
				// bufChan is closed only after a websocket read failure.
				log.Println("Exit wsToPtyStdin routine as ws reader stopped")
				break out
			}

			if _, err := tc.ptmx.Write(buf); err != nil {
				log.Println("Failed to send data to pty stdin: ", err)
				break out
			}
		case <-tc.ws_done:
			log.Println("Exit wsToPtyStdin routine as ws is going away")
			break out
		case <-tc.pty_done:
			log.Println("Exit wsToPtyStdin routine as pty is going away")
			break out
		}
	}

	log.Println("wsToPtyStdin routine exited")
}

// ptyStdoutToWs shovels data from the pty's stdout to the player's
// websocket, to every attached viewer, and to the record file when a
// recording is active.
//
// A helper goroutine blocks in ptmx.Read and hands each chunk to the loop
// below. The same loop also accepts new viewers from viewChan and
// start/stop commands from recordChan, so viewers and the record file are
// only ever touched from this goroutine while it runs.
//
// The loop ends when the pty reader fails (a close frame is then sent to
// the player), when a write to the player's websocket fails, or when
// ws_done or pty_done is closed. All viewers are closed on exit and wg is
// marked done. The helper goroutine is the only place that closes pty_done,
// and it does so only when reading from the pty fails.
func (tc *TermConn) ptyStdoutToWs(wg *sync.WaitGroup) {
	var viewers []*websocket.Conn

	defer wg.Done()

	bufChan := make(chan []byte)

	// stop is closed when this function returns so the reader goroutine
	// never stays blocked on the unbuffered bufChan with nobody receiving.
	stop := make(chan struct{})
	defer close(stop)

	go func() { // reads from pty
		for {
			readBuf := make([]byte, 1024) // pty is read in 1024-byte blocks
			n, err := tc.ptmx.Read(readBuf)
			if err != nil {
				log.Println("Failed to read from pty stdout: ", err)
				close(bufChan)
				close(tc.pty_done)
				return
			}

			// Hand over exactly the bytes that were read.
			select {
			case bufChan <- readBuf[:n]:
			case <-stop:
				return
			}
		}
	}()

out:
	for {
		select {
		case buf, ok := <-bufChan:
			if !ok {
				// The pty reader stopped; tell the player why we are leaving.
				tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
				tc.ws.WriteMessage(websocket.CloseMessage,
					websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Pty closed"))

				break out
			}

			// The player's socket is not kept in viewers because its errors
			// are fatal for the session, while viewer errors are not.
			tc.ws.SetWriteDeadline(time.Now().Add(writeWait))
			if err := tc.ws.WriteMessage(websocket.BinaryMessage, buf); err != nil {
				log.Println("Failed to write message: ", err)
				break out
			}

			// Forward to the viewers. A viewer whose write fails is closed
			// and dropped from the list so the list does not grow forever.
			live := viewers[:0]
			for _, v := range viewers {
				v.SetWriteDeadline(time.Now().Add(viewWait))
				if err := v.WriteMessage(websocket.BinaryMessage, buf); err != nil {
					log.Println("Failed to write message to viewer: ", err)
					v.Close() // we own the socket and need to close it
					continue
				}
				live = append(live, v)
			}
			for i := len(live); i < len(viewers); i++ {
				viewers[i] = nil // drop references to closed sockets
			}
			viewers = live

			tc.appendRecordEntry(buf)

		case cmd := <-tc.recordChan:
			if cmd == recordCmd {
				tc.beginRecordFile()
			} else {
				tc.endRecordFile()
			}

		case viewer := <-tc.viewChan:
			log.Println("Received viewer", viewer.RemoteAddr().String())
			viewers = append(viewers, viewer)

		case <-tc.ws_done:
			log.Println("Exit ptyStdoutToWs routine as ws is going away")
			break out

		case <-tc.pty_done:
			log.Println("Exit ptyStdoutToWs routine as pty is going away")
			break out
		}
	}

	// Close the remaining viewers; their sockets are owned here.
	for _, v := range viewers {
		v.Close()
	}

	log.Println("ptyStdoutToWs routine exited")
}

// beginRecordFile opens a new record file named after the session and the
// current Unix time (in hex) and writes the start of a JSON array followed by
// an empty first entry. If a recording is already active it is finished
// first, so its file is not leaked or left unterminated. On any failure the
// session simply continues without recording.
func (tc *TermConn) beginRecordFile() {
	tc.endRecordFile()

	fname := "./records/" + tc.Name + "_" + strconv.FormatInt(time.Now().Unix(), 16) + ".scr"

	f, err := os.OpenFile(fname, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		log.Println("Failed to create record file", fname, err)
		return
	}

	// The empty first entry lets every later entry be written as ",{...}".
	tc.lastRecTime = time.Now()
	jbuf, err := json.Marshal(WriteRecord{Dur: time.Since(tc.lastRecTime), Data: []byte("")})
	if err != nil {
		log.Println("Failed to marshal record", err)
		f.Close()
		return
	}

	if _, err := f.Write(append([]byte("["), jbuf...)); err != nil {
		log.Println("Failed to write record file", fname, err)
		f.Close()
		return
	}

	tc.record = f
}

// endRecordFile terminates the JSON array and closes the active record file.
// It does nothing when no recording is active.
func (tc *TermConn) endRecordFile() {
	if tc.record == nil {
		return
	}

	if _, err := tc.record.Write([]byte("]")); err != nil {
		log.Println("Failed to terminate record file", err)
	}
	if err := tc.record.Close(); err != nil {
		log.Println("Failed to close record file", err)
	}
	tc.record = nil
}

// appendRecordEntry appends data to the active record file as one
// WriteRecord, preceded by a "," delimiter. It does nothing when no
// recording is active.
func (tc *TermConn) appendRecordEntry(data []byte) {
	if tc.record == nil {
		return
	}

	jbuf, err := json.Marshal(WriteRecord{Dur: time.Since(tc.lastRecTime), Data: data})
	if err != nil {
		log.Println("Failed to marshal record", err)
	} else if _, err := tc.record.Write(append([]byte(","), jbuf...)); err != nil {
		log.Println("Failed to write record", err)
	}

	tc.lastRecTime = time.Now()
}

// release tears down the session: it removes the player from the registry,
// closes the pty, stops and reaps the command, closes viewChan and
// recordChan, finishes any active recording and closes the websocket.
//
// It should be executed by the main goroutine for the connection, after the
// relay goroutines have returned, because it closes channels and the record
// file that those goroutines use. When no pty was created only the registry
// removal and the websocket close are performed.
func (tc *TermConn) release() {
	log.Println("Releasing terminal connection", tc.Name)

	// NOTE(review): presumably removing the player first keeps the registry
	// from sending on the channels closed below; confirm against registry.
	registry.removePlayer(tc.Name)

	if tc.ptmx != nil {
		// cleanup the pty and its related process
		tc.ptmx.Close()

		// terminate the command line process
		proc := tc.cmd.Process

		// send an interrupt, this will cause the shell process to
		// return from syscalls if any is pending
		if err := proc.Signal(os.Interrupt); err != nil {
			log.Printf("Failed to send Interrupt to shell process(%v): %v ", proc.Pid, err)
		}

		// Reap in the background so that a process which has already
		// exited is collected immediately instead of after a fixed sleep.
		reaped := make(chan error, 1)
		go func() {
			_, err := proc.Wait()
			reaped <- err
		}()

		var waitErr error
		select {
		case waitErr = <-reaped:
		case <-time.After(time.Second):
			// Give the process one second to exit before killing it.
			log.Printf("Try to kill the shell process(%v)", proc.Pid)

			if err := proc.Signal(os.Kill); err != nil {
				log.Printf("Failed to send KILL to shell process(%v): %v", proc.Pid, err)
			}

			waitErr = <-reaped
		}

		if waitErr != nil {
			log.Printf("Failed to wait for shell process(%v): %v", proc.Pid, waitErr)
		}

		close(tc.viewChan)
		close(tc.recordChan)

		if tc.record != nil {
			// write a ] to terminate the JSON array and close the file
			if _, err := tc.record.Write([]byte("]")); err != nil {
				log.Printf("Failed to terminate record file for %v: %v", tc.Name, err)
			}
			if err := tc.record.Close(); err != nil {
				log.Printf("Failed to close record file for %v: %v", tc.Name, err)
			}
			tc.record = nil
		}
	}

	tc.ws.Close()
}

// handlePlayer serves a player connection: it upgrades the request to a
// websocket, starts cmdline on a pty, registers the session under name and
// then blocks until the relay goroutines have all returned. The session is
// released on the way out, including when the pty cannot be created.
func handlePlayer(w http.ResponseWriter, r *http.Request, name string, cmdline []string) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("Failed to create websocket: ", err)
		return
	}

	player := &TermConn{
		Name:       name,
		Ip:         conn.RemoteAddr().String(),
		ws:         conn,
		ws_done:    make(chan struct{}),
		pty_done:   make(chan struct{}),
		viewChan:   make(chan *websocket.Conn),
		recordChan: make(chan int),
	}
	defer player.release()

	log.Println("Created the websocket to", conn.RemoteAddr().String())

	if err := player.createPty(cmdline); err != nil {
		log.Println("Failed to create PTY: ", err)
		return
	}

	registry.addPlayer(player)

	// Every relay runs in its own goroutine. ptyStdoutToWs in particular
	// must not run on this goroutine: it is usually blocked reading the
	// pty, which would keep the websocket from being closed.
	relays := []func(*sync.WaitGroup){
		player.ping,
		player.ptyStdoutToWs,
		player.wsToPtyStdin,
	}

	var wg sync.WaitGroup
	wg.Add(len(relays))
	for _, relay := range relays {
		go relay(&wg)
	}

	wg.Wait()
	log.Println("Wait returned")
}

// handleViewer serves a viewer connection: it upgrades the request to a
// websocket and hands that websocket to the player session identified by
// path. If the registry reports that the hand-over failed, the websocket is
// closed here.
func handleViewer(w http.ResponseWriter, r *http.Request, path string) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("Failed to create websocket: ", err)
		return
	}

	log.Println("Created the websocket to", conn.RemoteAddr().String())

	if delivered := registry.sendToPlayer(path, conn); delivered {
		return
	}

	log.Println("Failed to send websocket to player, close it")
	conn.Close()
}

func ConnectTerm(w http.ResponseWriter, r *http.Request, isViewer bool, name string, cmdline []string) {
	if !isViewer {
		handlePlayer(w, r, name, cmdline)
	} else {
		handleViewer(w, r, name)
	}
}

// Init initializes the package-level session registry by calling
// registry.init.
// NOTE(review): presumably this must run once before ConnectTerm is used;
// confirm against the registry implementation.
func Init() {
	registry.init()
}

// StartRecord passes recordCmd to registry.recordSession for the session
// identified by id.
// NOTE(review): presumably the registry forwards the command on that
// session's recordChan, which starts a new record file; confirm against
// the registry implementation.
func StartRecord(id string) {
	registry.recordSession(id, recordCmd)
}

// StopRecord passes stopCmd to registry.recordSession for the session
// identified by id.
// NOTE(review): presumably the registry forwards the command on that
// session's recordChan, which finishes the active record file; confirm
// against the registry implementation.
func StopRecord(id string) {
	registry.recordSession(id, stopCmd)
}