// Copyright (c) 2020 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license

package utils

import (
	"crypto/tls"
	"encoding/binary"
	"io"
	"net"
	"strings"
	"sync"
	"time"
)

// maxProxyLineLenV1 is the longest PROXY v1 header we are willing to read,
// counting the terminating CRLF. From
// https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt:
// "a 108-byte buffer is always enough to store all the line and a trailing zero
// for string processing." -- leaving 107 bytes for the line itself.
const maxProxyLineLenV1 = 107

// proxyLineError is the concrete type behind ErrBadProxyLine. It implements
// net.Error and reports itself as temporary: otherwise, ErrBadProxyLine would
// cause (*http.Server).Serve() to exit, since Serve stops accepting when
// Accept() returns an error that is not a temporary net.Error. A bad PROXY
// header is a failure of one connection, not of the listener.
type proxyLineError struct{}

// Compile-time guarantee that proxyLineError keeps satisfying net.Error,
// which the Temporary() contract described above depends on.
var _ net.Error = (*proxyLineError)(nil)

// Error returns the fixed message "invalid PROXY line".
func (p *proxyLineError) Error() string {
	return "invalid PROXY line"
}

// Timeout reports false: a malformed header is not a deadline expiry.
func (p *proxyLineError) Timeout() bool {
	return false
}

// Temporary reports true so that accept loops retry instead of exiting.
func (p *proxyLineError) Temporary() bool {
	return true
}

// ErrBadProxyLine is returned when a connection does not begin with a PROXY
// header, or when its PROXY header is too long, malformed, or uses values
// this package rejects. It is a temporary net.Error.
var ErrBadProxyLine error = &proxyLineError{}

// ListenerConfig is all the information about how to process
// incoming IRC connections on a listener.
//
// ReloadableListener.Accept acts on the first three fields:
//   - TLSConfig: if non-nil, each accepted connection is wrapped with
//     tls.Server using this config (after any PROXY header has been read).
//   - ProxyDeadline: the time allowed for reading the PROXY header.
//   - RequireProxy: if set, every connection must begin with a PROXY (v1 or
//     v2) header; connections whose header cannot be read or parsed are
//     closed and never returned to the caller.
type ListenerConfig struct {
	TLSConfig     *tls.Config
	ProxyDeadline time.Duration
	RequireProxy  bool
	// these are just metadata for easier tracking,
	// they are not used by ReloadableListener:
	Tor       bool
	STSOnly   bool
	WebSocket bool
	HideSTS   bool
}

// readRawProxyLine reads a PROXY header (either v1 or v2) from conn and returns
// its raw bytes, ensuring we don't read anything beyond the header into a
// buffer (this would break, e.g., the TLS handshake that follows it).
//
// The whole read must finish within deadline. The connection's deadline is
// cleared again before returning. A connection that does not start with 'P'
// (v1) or '\r' (v2) yields ErrBadProxyLine; I/O errors are returned as-is.
//
// NOTE(review): the shortest legal v1 header, "PROXY UNKNOWN\r\n", is 15
// bytes, so the initial 16-byte read consumes one byte past it.
// ParseProxyLineV1 rejects such lines anyway (it requires six fields).
func readRawProxyLine(conn net.Conn, deadline time.Duration) (result []byte, err error) {
	// normally this is covered by ping timeouts, but we're doing this outside
	// of the normal client goroutine. If the deadline cannot be armed, fail
	// now rather than risk blocking the caller indefinitely on a silent peer.
	if err = conn.SetDeadline(time.Now().Add(deadline)); err != nil {
		return nil, err
	}
	// Best-effort reset; a failure to clear the deadline is not actionable here.
	defer conn.SetDeadline(time.Time{})

	// 16 bytes is the fixed-size part of a v2 header (its address-block length
	// lives in bytes 15-16); the spare capacity holds the longest v1 line.
	buf := make([]byte, 16, maxProxyLineLenV1)
	if _, err = io.ReadFull(conn, buf); err != nil {
		return nil, err
	}

	switch buf[0] {
	case 'P':
		// PROXY v1: starts with "PROXY"
		return readRawProxyLineV1(conn, buf)
	case '\r':
		// PROXY v2: starts with "\r\n\r\n"
		return readRawProxyLineV2(conn, buf)
	default:
		return nil, ErrBadProxyLine
	}
}

// readRawProxyLineV1 finishes reading a text (v1) PROXY header whose first
// bytes are already in buf. It pulls one byte at a time, so nothing after the
// terminating '\n' is consumed. buf must have capacity for
// maxProxyLineLenV1 bytes. If no '\n' appears within that many bytes, the
// result is ErrBadProxyLine. On success the returned slice is buf extended
// through the '\n'.
func readRawProxyLineV1(conn net.Conn, buf []byte) (result []byte, err error) {
	for n := len(buf); n < maxProxyLineLenV1; n++ {
		// grow by exactly one byte and fill only that byte
		buf = buf[:n+1]
		if _, err = io.ReadFull(conn, buf[n:]); err != nil {
			return nil, err
		}
		if buf[n] == '\n' {
			return buf, nil
		}
	}
	// did not find the line terminator within the length limit
	return nil, ErrBadProxyLine
}

// readRawProxyLineV2 completes a binary (v2) PROXY header whose fixed 16-byte
// prefix is already in buf. It reads exactly the address block that the prefix
// announces and nothing more. The returned slice reuses buf's spare capacity
// when the block fits in it, and is a fresh allocation otherwise.
func readRawProxyLineV2(conn net.Conn, buf []byte) (result []byte, err error) {
	// "The 15th and 16th bytes is the address length in bytes in network endian order."
	blockLen := int(binary.BigEndian.Uint16(buf[14:16]))
	if blockLen == 0 {
		return buf[:16], nil
	}

	total := 16 + blockLen
	var header []byte
	if total <= cap(buf) {
		header = buf[:total]
	} else {
		// larger than buf can hold (e.g. a unix-domain proxy source, which we
		// don't really handle): copy the prefix into a buffer of the right size
		header = make([]byte, total)
		copy(header, buf[:16])
	}

	if _, err = io.ReadFull(conn, header[16:]); err != nil {
		return nil, err
	}
	return header, nil
}

// ParseProxyLine parses a PROXY protocol (v1 or v2) line and returns the remote IP.
// The version is chosen from the first byte: 'P' selects the text (v1) parser
// and '\r' the binary (v2) parser. Empty input or any other leading byte
// yields ErrBadProxyLine.
func ParseProxyLine(line []byte) (ip net.IP, err error) {
	switch {
	case len(line) == 0:
		return nil, ErrBadProxyLine
	case line[0] == 'P':
		// text header, e.g. "PROXY TCP4 ..."
		return ParseProxyLineV1(string(line))
	case line[0] == '\r':
		// binary header; its 12-byte signature begins with CR LF
		return parseProxyLineV2(line)
	}
	return nil, ErrBadProxyLine
}

// ParseProxyLineV1 parses a PROXY protocol (v1) line and returns the remote IP.
// The line must split into exactly six whitespace-separated fields, the first
// being "PROXY". The third field (the source address) must parse as an IP, and
// it is returned in 16-byte form. Anything else yields ErrBadProxyLine.
func ParseProxyLineV1(line string) (ip net.IP, err error) {
	// PROXY <proto> <src-addr> <dst-addr> <src-port> <dst-port>
	fields := strings.Fields(line)
	if len(fields) == 6 && fields[0] == "PROXY" {
		if src := net.ParseIP(fields[2]); src != nil {
			return src.To16(), nil
		}
	}
	return nil, ErrBadProxyLine
}

// parseProxyLineV2 parses a binary (v2) PROXY header and returns the source IP
// in 16-byte form. It returns (nil, nil) for a LOCAL command and for address
// families without an IP (AF_UNSPEC, AF_UNIX). It returns ErrBadProxyLine for
// a short or truncated header, a wrong signature, a version other than 2, or
// an unknown command. The returned IP never aliases line.
func parseProxyLineV2(line []byte) (ip net.IP, err error) {
	const (
		headerLen = 16
		v2Sig     = "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a"
	)
	// the fixed header must be present and begin with the v2 signature
	// (this string comparison doesn't allocate)
	if len(line) < headerLen || string(line[:len(v2Sig)]) != v2Sig {
		return nil, ErrBadProxyLine
	}

	// 13th byte: high nibble is the version (always 2), low nibble the command
	verCmd := line[12]
	if verCmd>>4 != 2 {
		return nil, ErrBadProxyLine
	}
	switch verCmd & 0x0f {
	case 0x0:
		// LOCAL command: no proxied address to report
		return nil, nil
	case 0x1:
		// PROXY command: extract the address below
	default:
		// "Receivers must drop connections presenting unexpected values here"
		return nil, ErrBadProxyLine
	}

	// 14th byte: the high nibble is the address family
	var familyLen int
	switch line[13] >> 4 {
	case 0x1: // AF_INET
		familyLen = net.IPv4len
	case 0x2: // AF_INET6
		familyLen = net.IPv6len
	default:
		// AF_UNSPEC or AF_UNIX, either way there's no IP address
		return nil, nil
	}

	// after the header: source address, destination address, two 16-bit ports
	if len(line) < headerLen+2*familyLen+4 {
		return nil, ErrBadProxyLine
	}

	// the source address comes first, right after the fixed header
	src := line[headerLen : headerLen+familyLen]
	if familyLen == net.IPv4len {
		// To16 builds a new 16-byte slice, so the result does not alias line
		return net.IP(src).To16(), nil
	}
	ip = make(net.IP, net.IPv6len)
	copy(ip, src)
	return ip, nil
}

// WrappedConn is a net.Conn with some additional data stapled to it;
// the proxied IP, if one was read via the PROXY protocol, and the listener
// configuration.
//
// ReloadableListener.Accept fills in the embedded Conn (a *tls.Conn when the
// listener has a TLSConfig), ProxiedIP and Config. ProxiedIP is nil when no
// PROXY header was required, or when the header carried no IP address.
// Accept does not set Secure; other code is expected to fill it in.
type WrappedConn struct {
	net.Conn
	ProxiedIP net.IP
	Config    ListenerConfig
	// Secure indicates whether we believe the connection between us and the client
	// was secure against interception and modification (including all proxies):
	Secure bool
}

// ReloadableListener is a wrapper for net.Listener that allows reloading
// of config data for postprocessing connections (TLS, PROXY protocol, etc.)
// It implements net.Listener through its Accept, Close and Addr methods.
type ReloadableListener struct {
	// TODO: make this lock-free
	// The embedded Mutex guards config and isClosed. Note that embedding it
	// also makes Lock and Unlock part of ReloadableListener's method set.
	sync.Mutex
	// realListener is only assigned by NewReloadableListener (within this
	// file) and is read without holding the lock.
	realListener net.Listener
	config       ListenerConfig
	isClosed     bool
}

// NewReloadableListener wraps realListener. Connections it accepts are
// post-processed according to config until a later Reload replaces it.
func NewReloadableListener(realListener net.Listener, config ListenerConfig) *ReloadableListener {
	rl := new(ReloadableListener)
	rl.realListener = realListener
	rl.config = config
	return rl
}

// Reload swaps in a new configuration for later Accept calls. Connections
// already returned by Accept are unaffected: each WrappedConn holds its own
// copy of the config that was current when it was set up.
func (rl *ReloadableListener) Reload(config ListenerConfig) {
	rl.Lock()
	defer rl.Unlock()
	rl.config = config
}

// Accept waits for the next connection on the underlying listener and prepares
// it according to the configuration current at that moment:
//   - if RequireProxy is set, a PROXY header is read and parsed synchronously,
//     within ProxyDeadline;
//   - if TLSConfig is non-nil, the connection is then wrapped with tls.Server.
//
// On success the result is always a *WrappedConn carrying the proxied IP (nil
// if none was read) and the config snapshot. After Close has been called,
// Accept returns net.ErrClosed and closes any connection that raced with it.
// If the PROXY header cannot be read or parsed, the connection is closed and
// the error returned is always a temporary net.Error. Accept loops such as
// (*http.Server).Serve therefore keep running.
func (rl *ReloadableListener) Accept() (conn net.Conn, err error) {
	conn, err = rl.realListener.Accept()

	// snapshot the mutable state once, under the lock
	rl.Lock()
	config := rl.config
	isClosed := rl.isClosed
	rl.Unlock()

	if isClosed {
		if err == nil {
			conn.Close()
		}
		err = net.ErrClosed
	}
	if err != nil {
		return nil, err
	}

	var proxiedIP net.IP
	if config.RequireProxy {
		// this will occur synchronously on the goroutine calling Accept(),
		// but that's OK because this listener *requires* a PROXY line,
		// therefore it must be used with proxies that always send the line
		// and we won't get slowloris'ed waiting for the client response
		var proxyLine []byte
		proxyLine, err = readRawProxyLine(conn, config.ProxyDeadline)
		if err == nil {
			proxiedIP, err = ParseProxyLine(proxyLine)
		}
		if err != nil {
			conn.Close()
			// (*http.Server).Serve exits on any Accept error that is not a
			// temporary net.Error. A peer that disconnects before sending a
			// complete header (e.g. a TCP health check) yields io.EOF or
			// io.ErrUnexpectedEOF, which would take down the whole listener,
			// so report such failures as ErrBadProxyLine. Errors that are
			// already temporary (timeouts, ErrBadProxyLine) pass through.
			if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() {
				err = ErrBadProxyLine
			}
			return nil, err
		}
	}

	if config.TLSConfig != nil {
		conn = tls.Server(conn, config.TLSConfig)
	}

	return &WrappedConn{
		Conn:      conn,
		ProxiedIP: proxiedIP,
		Config:    config,
	}, nil
}

// Close marks the listener as closed and then closes the underlying listener,
// returning that listener's error. The flag is set first so that an Accept
// racing with Close discards its connection and reports net.ErrClosed instead.
func (rl *ReloadableListener) Close() error {
	func() {
		rl.Lock()
		defer rl.Unlock()
		rl.isClosed = true
	}()
	return rl.realListener.Close()
}

// Addr reports the network address of the underlying listener.
func (rl *ReloadableListener) Addr() net.Addr {
	inner := rl.realListener
	return inner.Addr()
}