3
0
mirror of https://github.com/ergochat/ergo.git synced 2025-01-12 05:02:35 +01:00
ergo/irc/utils/proxy.go
Shivaram Lingamneni 3ceff6a8b1 make ReloadableListener lock-free
Also stop attaching the *tls.Config to the wrapped connection,
since this forces it to be retained beyond its natural lifetime.
2023-01-05 20:18:14 -05:00

299 lines
7.6 KiB
Go

// Copyright (c) 2020 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license
package utils
import (
"crypto/tls"
"encoding/binary"
"io"
"net"
"strings"
"sync/atomic"
"time"
)
// maxProxyLineLenV1 is the longest PROXY v1 header we are willing to read,
// including the terminating CRLF. Per
// https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt,
// "a 108-byte buffer is always enough to store all the line and a trailing zero
// for string processing", so the line itself never exceeds 107 bytes.
const maxProxyLineLenV1 = 107
// proxyLineError is the concrete type behind ErrBadProxyLine. It implements
// net.Error and reports Temporary() == true; otherwise, returning
// ErrBadProxyLine from Accept() would cause (*http.Server).Serve() to exit.
type proxyLineError struct{}

// Error implements the error interface.
func (e *proxyLineError) Error() string { return "invalid PROXY line" }

// Temporary reports true, so that accept loops retry instead of giving up.
func (e *proxyLineError) Temporary() bool { return true }

// Timeout reports false: a bad PROXY header is not a timeout.
func (e *proxyLineError) Timeout() bool { return false }

// ErrBadProxyLine is returned when a PROXY protocol header cannot be
// recognized or parsed.
var ErrBadProxyLine error = &proxyLineError{}
// ListenerConfig is all the information about how to process
// incoming IRC connections on a listener.
type ListenerConfig struct {
	// TLSConfig, if non-nil, makes ReloadableListener.Accept wrap each
	// accepted connection with tls.Server using this config.
	TLSConfig *tls.Config
	// ProxyDeadline bounds how long Accept waits to read the PROXY header
	// when RequireProxy is set.
	ProxyDeadline time.Duration
	// RequireProxy makes Accept read and parse a PROXY protocol (v1 or v2)
	// header at the start of every connection; connections whose header
	// cannot be read or parsed are closed.
	RequireProxy bool
	// these are just metadata for easier tracking; ReloadableListener does
	// not act on them, it only copies them onto each WrappedConn:
	Tor       bool
	STSOnly   bool
	WebSocket bool
	HideSTS   bool
}
// readRawProxyLine reads a PROXY protocol header (v1 or v2) from conn and
// returns its raw bytes, taking care not to consume anything after the
// header: an over-read would be lost to whatever reads conn next (for
// example, it would break the TLS handshake). A deadline of `deadline` from
// now is applied to conn for the duration of the call and cleared on return.
//
// NOTE(review): the first read is a fixed 16 bytes. The v1 spec appears to
// permit a line as short as "PROXY UNKNOWN\r\n" (15 bytes), which would be
// over-read; ParseProxyLineV1 rejects such lines anyway (it needs six fields),
// so this is currently harmless — TODO confirm if UNKNOWN support is added.
func readRawProxyLine(conn net.Conn, deadline time.Duration) (result []byte, err error) {
	// normally this is covered by ping timeouts, but we're doing this outside
	// of the normal client goroutine:
	conn.SetDeadline(time.Now().Add(deadline))
	defer conn.SetDeadline(time.Time{})

	// 16 bytes is the fixed-size prefix of a v2 header; capacity for a
	// maximum-length v1 line is reserved so the v1 reader can grow in place
	prefix := make([]byte, 16, maxProxyLineLenV1)
	if _, err = io.ReadFull(conn, prefix); err != nil {
		return nil, err
	}

	// the leading byte tells the two protocol versions apart
	if prefix[0] == 'P' {
		// v1 is a text line starting with "PROXY"
		return readRawProxyLineV1(conn, prefix)
	}
	if prefix[0] == '\r' {
		// v2 is binary, starting with "\r\n\r\n"
		return readRawProxyLineV2(conn, prefix)
	}
	return nil, ErrBadProxyLine
}
// readRawProxyLineV1 finishes reading a PROXY v1 text line whose first bytes
// are already in buf. It reads one byte at a time, so nothing past the
// terminating '\n' is consumed, and returns the whole line including CRLF.
// If no '\n' arrives within maxProxyLineLenV1 bytes it returns
// ErrBadProxyLine. buf must have capacity of at least maxProxyLineLenV1.
func readRawProxyLineV1(conn net.Conn, buf []byte) (result []byte, err error) {
	for n := len(buf); n < maxProxyLineLenV1; n++ {
		// extend by one byte of free space and fill it
		buf = buf[:n+1]
		if _, err = io.ReadFull(conn, buf[n:]); err != nil {
			return nil, err
		}
		if buf[n] == '\n' {
			return buf, nil
		}
	}
	// did not find \r\n within the maximum length, fail
	return nil, ErrBadProxyLine
}
func readRawProxyLineV2(conn net.Conn, buf []byte) (result []byte, err error) {
// "The 15th and 16th bytes is the address length in bytes in network endian order."
addrLen := int(binary.BigEndian.Uint16(buf[14:16]))
if addrLen == 0 {
return buf[0:16], nil
} else if addrLen <= cap(buf)-16 {
buf = buf[0 : 16+addrLen]
} else {
// proxy source is unix domain, we don't really handle this
buf2 := make([]byte, 16+addrLen)
copy(buf2[0:16], buf[0:16])
buf = buf2
}
_, err = io.ReadFull(conn, buf[16:16+addrLen])
if err != nil {
return
}
return buf[0 : 16+addrLen], nil
}
// ParseProxyLine parses a PROXY protocol (v1 or v2) line and returns the remote IP.
// The first byte selects the version ('P' for v1, '\r' for v2); an empty line
// or any other leading byte yields ErrBadProxyLine. Note that a nil IP with a
// nil error is possible for v2 headers that carry no source address.
func ParseProxyLine(line []byte) (ip net.IP, err error) {
	if len(line) > 0 {
		if line[0] == 'P' {
			return ParseProxyLineV1(string(line))
		}
		if line[0] == '\r' {
			return parseProxyLineV2(line)
		}
	}
	return nil, ErrBadProxyLine
}
// ParseProxyLineV1 parses a PROXY protocol (v1) line and returns the remote IP.
// The line must split into exactly six whitespace-separated fields, the first
// being "PROXY"; the third field (the source address) must parse as an IP,
// which is returned in 16-byte form. The other fields are not validated.
// Anything else yields ErrBadProxyLine.
//
// NOTE(review): because six fields are required, "PROXY UNKNOWN" lines are
// rejected even though the v1 spec seems to ask receivers to accept them;
// callers presumably rely on a non-nil IP on success — confirm before changing.
func ParseProxyLineV1(line string) (ip net.IP, err error) {
	fields := strings.Fields(line)
	if len(fields) == 6 && fields[0] == "PROXY" {
		if source := net.ParseIP(fields[2]); source != nil {
			return source.To16(), nil
		}
	}
	return nil, ErrBadProxyLine
}
// parseProxyLineV2 parses a binary PROXY protocol v2 header and returns the
// source IP address it carries, in 16-byte form. The returned IP never
// aliases line.
//
// It returns (nil, nil) when the header is valid but supplies no IP to use:
// the LOCAL command, or the AF_UNSPEC / AF_UNIX address families.
// ErrBadProxyLine is returned for a truncated or malformed header, including
// version, command, transport-protocol, or address-family values that the
// specification does not define (the spec requires receivers to reject them).
func parseProxyLineV2(line []byte) (ip net.IP, err error) {
	if len(line) < 16 {
		return nil, ErrBadProxyLine
	}
	// the 12-byte v2 signature (this comparison doesn't allocate)
	if string(line[:12]) != "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a" {
		return nil, ErrBadProxyLine
	}

	// "The next byte (the 13th one) is the protocol version and command."
	versionCmd := line[12]
	// "The highest four bits contains the version [....] it must always be sent as \x2"
	if (versionCmd >> 4) != 2 {
		return nil, ErrBadProxyLine
	}
	// "The lowest four bits represents the command"
	switch versionCmd & 0x0f {
	case 0:
		return nil, nil // LOCAL command
	case 1:
		// PROXY command, continue below
	default:
		// "Receivers must drop connections presenting unexpected values here"
		return nil, ErrBadProxyLine
	}

	// "The 14th byte contains the transport protocol and address family."
	protoAddr := line[13]

	// The lowest 4 bits are the transport protocol: 0x0 UNSPEC, 0x1 STREAM,
	// 0x2 DGRAM. Other values are undefined and must be rejected.
	if protoAddr&0x0f > 2 {
		return nil, ErrBadProxyLine
	}

	// "The highest 4 bits contain the address family"
	var addrLen int
	switch protoAddr >> 4 {
	case 0, 3:
		return nil, nil // AF_UNSPEC or AF_UNIX, either way there's no IP address
	case 1:
		addrLen = 4 // AF_INET
	case 2:
		addrLen = 16 // AF_INET6
	default:
		// families above 0x3 are undefined; they must be rejected rather
		// than treated like AF_UNSPEC
		return nil, ErrBadProxyLine
	}

	// header, source and destination address, two 16-bit port numbers:
	expectedLen := 16 + 2*addrLen + 4
	if len(line) < expectedLen {
		return nil, ErrBadProxyLine
	}

	// "Starting from the 17th byte, addresses are presented in network byte order.
	// The address order is always the same :
	// - source layer 3 address in network byte order [...]"
	// Both branches produce a fresh slice, so the result doesn't alias line.
	src := line[16 : 16+addrLen]
	if addrLen == 4 {
		ip = net.IP(src).To16()
	} else {
		ip = make(net.IP, addrLen)
		copy(ip, src)
	}
	return ip, nil
}
// WrappedConn is a net.Conn with some additional data stapled to it:
// the proxied IP, if one was read via the PROXY protocol, and flags copied
// from the listener configuration.
type WrappedConn struct {
	net.Conn
	// ProxiedIP is the source IP taken from the PROXY header; it is nil when
	// the listener did not require a PROXY header, or when the header
	// carried no usable address.
	ProxiedIP net.IP
	// TLS records whether ReloadableListener.Accept wrapped Conn with tls.Server.
	TLS bool
	// Tor, STSOnly, WebSocket and HideSTS are copied verbatim from the
	// listener's ListenerConfig at accept time.
	Tor       bool
	STSOnly   bool
	WebSocket bool
	HideSTS   bool
	// Secure indicates whether we believe the connection between us and the client
	// was secure against interception and modification (including all proxies).
	// ReloadableListener.Accept leaves it false; it is set later by client code.
	Secure bool
}
// ReloadableListener is a wrapper for net.Listener that allows reloading
// of config data for postprocessing connections (TLS, PROXY protocol, etc.)
// The config is held in an atomic.Pointer rather than behind a lock, so
// Reload() and Close() can swap it while Accept() is blocked; each Accept()
// reads the config only after the underlying listener returns a connection.
type ReloadableListener struct {
	// realListener is the wrapped listener that actually accepts connections.
	realListener net.Listener
	// nil means the listener is closed:
	config atomic.Pointer[ListenerConfig]
}
// NewReloadableListener wraps realListener so that the connections it accepts
// are post-processed according to config, which can later be replaced with
// Reload.
func NewReloadableListener(realListener net.Listener, config ListenerConfig) *ReloadableListener {
	rl := &ReloadableListener{realListener: realListener}
	// storing the config goes through Reload so there is a single code path
	// for publishing a config (the stored copy escapes to the heap)
	rl.Reload(config)
	return rl
}
// Reload atomically publishes a new configuration, which takes effect for
// connections that Accept() processes after this call. config is received by
// value, so the listener keeps its own copy of the struct (pointer fields
// such as TLSConfig are still shared with the caller).
func (rl *ReloadableListener) Reload(config ListenerConfig) {
	snapshot := config
	rl.config.Store(&snapshot)
}
// Accept waits for the next connection on the underlying listener and
// post-processes it according to the current ListenerConfig:
//   - if RequireProxy is set, a PROXY protocol header (v1 or v2) is read,
//     bounded by ProxyDeadline, and parsed for the proxied source IP;
//   - if TLSConfig is non-nil, the connection is wrapped with tls.Server
//     (the TLS handshake itself is not performed here).
//
// On success the returned conn is always a *WrappedConn. Once Close() has
// been called, Accept returns net.ErrClosed, closing any connection that the
// underlying listener handed back concurrently with Close().
//
// A failure to read or parse the PROXY header affects only that single
// connection: the connection is closed and the returned error always
// implements net.Error with Temporary() == true, so accept loops such as
// (*http.Server).Serve back off and retry instead of exiting. Errors that
// already satisfy this (e.g. ErrBadProxyLine) are returned unchanged; others
// (e.g. io.EOF when the peer hangs up mid-header) are wrapped, and remain
// reachable through errors.Is / errors.As.
func (rl *ReloadableListener) Accept() (conn net.Conn, err error) {
	conn, err = rl.realListener.Accept()

	// the config is loaded after the (blocking) underlying Accept returns,
	// so a Reload() or Close() that happened while we waited is observed here
	config := rl.config.Load()
	if config == nil {
		// Close() was called
		if err == nil {
			conn.Close()
		}
		err = net.ErrClosed
	}
	if err != nil {
		return nil, err
	}

	var proxiedIP net.IP
	if config.RequireProxy {
		// this will occur synchronously on the goroutine calling Accept(),
		// but that's OK because this listener *requires* a PROXY line,
		// therefore it must be used with proxies that always send the line
		// and we won't get slowloris'ed waiting for the client response
		var proxyLine []byte
		proxyLine, err = readRawProxyLine(conn, config.ProxyDeadline)
		if err == nil {
			proxiedIP, err = ParseProxyLine(proxyLine)
		}
		if err != nil {
			conn.Close()
			// the failure is specific to this connection; it must not look
			// like a fatal listener error to the caller's accept loop
			return nil, temporaryAcceptError(err)
		}
	}

	if config.TLSConfig != nil {
		conn = tls.Server(conn, config.TLSConfig)
	}

	return &WrappedConn{
		Conn:      conn,
		ProxiedIP: proxiedIP,
		TLS:       config.TLSConfig != nil,
		Tor:       config.Tor,
		STSOnly:   config.STSOnly,
		WebSocket: config.WebSocket,
		HideSTS:   config.HideSTS,
		// Secure will be set later by client code
	}, nil
}

// proxyAcceptError wraps an error hit while reading the PROXY header of one
// accepted connection (for example io.EOF if the peer hung up), marking it
// as temporary so it is not mistaken for a failure of the listener itself.
type proxyAcceptError struct {
	err error
}

// Error returns the wrapped error's message unchanged.
func (e *proxyAcceptError) Error() string { return e.err.Error() }

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (e *proxyAcceptError) Unwrap() error { return e.err }

// Timeout defers to the wrapped error when it is a net.Error.
func (e *proxyAcceptError) Timeout() bool {
	ne, ok := e.err.(net.Error)
	return ok && ne.Timeout()
}

// Temporary always reports true: the listener remains usable.
func (e *proxyAcceptError) Temporary() bool { return true }

// temporaryAcceptError returns err unchanged if it is already a net.Error
// reporting Temporary() == true (as ErrBadProxyLine does, which keeps
// `err == ErrBadProxyLine` comparisons working); otherwise it wraps err in a
// *proxyAcceptError. err must be non-nil.
func temporaryAcceptError(err error) error {
	// Temporary() is deprecated, but (*http.Server).Serve still consults it.
	//lint:ignore SA1019 (*http.Server).Serve still relies on Temporary()
	if ne, ok := err.(net.Error); ok && ne.Temporary() {
		return err
	}
	return &proxyAcceptError{err: err}
}
// Close marks the listener as closed and then closes the underlying listener,
// returning that Close's error. The config is cleared first, so an Accept()
// woken up by the underlying close sees the nil config and reports
// net.ErrClosed.
func (rl *ReloadableListener) Close() error {
	rl.config.Store(nil)
	closeErr := rl.realListener.Close()
	return closeErr
}
// Addr returns the network address of the underlying listener.
func (rl *ReloadableListener) Addr() net.Addr {
	underlying := rl.realListener
	return underlying.Addr()
}