mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-25 21:39:25 +01:00
refactor the rehash implementation
This commit is contained in:
parent
eae04e8c51
commit
e8b1870067
@ -57,7 +57,7 @@ func NewChannel(s *Server, name string, addDefaultModes bool) *Channel {
|
||||
}
|
||||
|
||||
if addDefaultModes {
|
||||
for _, mode := range s.defaultChannelModes {
|
||||
for _, mode := range s.GetDefaultChannelModes() {
|
||||
channel.flags[mode] = true
|
||||
}
|
||||
}
|
||||
|
@ -244,6 +244,8 @@ type Config struct {
|
||||
WhowasEntries uint `yaml:"whowas-entries"`
|
||||
LineLen LineLenConfig `yaml:"linelen"`
|
||||
}
|
||||
|
||||
Filename string
|
||||
}
|
||||
|
||||
// OperClass defines an assembled operator class.
|
||||
@ -390,6 +392,8 @@ func LoadConfig(filename string) (config *Config, err error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.Filename = filename
|
||||
|
||||
// we need this so PasswordBytes returns the correct info
|
||||
if config.Server.Password != "" {
|
||||
config.Server.PassConfig.Password = config.Server.Password
|
||||
|
@ -53,6 +53,32 @@ func InitDB(path string) {
|
||||
}
|
||||
}
|
||||
|
||||
// open an existing database, performing a schema version check
|
||||
func OpenDatabase(path string) (*buntdb.DB, error) {
|
||||
// open data store
|
||||
db, err := buntdb.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check db version
|
||||
err = db.View(func(tx *buntdb.Tx) error {
|
||||
version, _ := tx.Get(keySchemaVersion)
|
||||
if version != latestDbSchema {
|
||||
return fmt.Errorf("Database must be updated. Expected schema v%s, got v%s.", latestDbSchema, version)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// close the db
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// UpgradeDB upgrades the datastore to the latest schema.
|
||||
func UpgradeDB(path string) {
|
||||
store, err := buntdb.Open(path)
|
||||
|
23
irc/help.go
23
irc/help.go
@ -530,10 +530,10 @@ Oragono supports the following channel membership prefixes:
|
||||
}
|
||||
|
||||
// HelpIndex contains the list of all help topics for regular users.
|
||||
var HelpIndex = "list of all help topics for regular users"
|
||||
var HelpIndex string
|
||||
|
||||
// HelpIndexOpers contains the list of all help topics for opers.
|
||||
var HelpIndexOpers = "list of all help topics for opers"
|
||||
var HelpIndexOpers string
|
||||
|
||||
// GenerateHelpIndex is used to generate HelpIndex.
|
||||
func GenerateHelpIndex(forOpers bool) string {
|
||||
@ -582,6 +582,25 @@ Information:
|
||||
return newHelpIndex
|
||||
}
|
||||
|
||||
func GenerateHelpIndices() error {
|
||||
if HelpIndex != "" && HelpIndexOpers != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// startup check that we have HELP entries for every command
|
||||
for name := range Commands {
|
||||
_, exists := Help[strings.ToLower(name)]
|
||||
if !exists {
|
||||
return fmt.Errorf("Help entry does not exist for command %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// generate help indexes
|
||||
HelpIndex = GenerateHelpIndex(false)
|
||||
HelpIndexOpers = GenerateHelpIndex(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendHelp sends the client help of the given string.
|
||||
func (client *Client) sendHelp(name string, text string) {
|
||||
splitName := strings.Split(name, " ")
|
||||
|
@ -19,9 +19,9 @@ import (
|
||||
|
||||
const restErr = "{\"error\":\"An unknown error occurred\"}"
|
||||
|
||||
// restAPIServer is used to keep a link to the current running server since this is the best
|
||||
// ircServer is used to keep a link to the current running server since this is the best
|
||||
// way to do it, given how HTTP handlers dispatch and work.
|
||||
var restAPIServer *Server
|
||||
var ircServer *Server
|
||||
|
||||
type restInfoResp struct {
|
||||
ServerName string `json:"server-name"`
|
||||
@ -60,8 +60,8 @@ type restRehashResp struct {
|
||||
func restInfo(w http.ResponseWriter, r *http.Request) {
|
||||
rs := restInfoResp{
|
||||
Version: SemVer,
|
||||
ServerName: restAPIServer.name,
|
||||
NetworkName: restAPIServer.networkName,
|
||||
ServerName: ircServer.name,
|
||||
NetworkName: ircServer.networkName,
|
||||
}
|
||||
b, err := json.Marshal(rs)
|
||||
if err != nil {
|
||||
@ -73,9 +73,9 @@ func restInfo(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func restStatus(w http.ResponseWriter, r *http.Request) {
|
||||
rs := restStatusResp{
|
||||
Clients: restAPIServer.clients.Count(),
|
||||
Opers: len(restAPIServer.operators),
|
||||
Channels: restAPIServer.channels.Len(),
|
||||
Clients: ircServer.clients.Count(),
|
||||
Opers: len(ircServer.operators),
|
||||
Channels: ircServer.channels.Len(),
|
||||
}
|
||||
b, err := json.Marshal(rs)
|
||||
if err != nil {
|
||||
@ -87,8 +87,8 @@ func restStatus(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func restGetXLines(w http.ResponseWriter, r *http.Request) {
|
||||
rs := restXLinesResp{
|
||||
DLines: restAPIServer.dlines.AllBans(),
|
||||
KLines: restAPIServer.klines.AllBans(),
|
||||
DLines: ircServer.dlines.AllBans(),
|
||||
KLines: ircServer.klines.AllBans(),
|
||||
}
|
||||
b, err := json.Marshal(rs)
|
||||
if err != nil {
|
||||
@ -104,7 +104,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// get accounts
|
||||
err := restAPIServer.store.View(func(tx *buntdb.Tx) error {
|
||||
err := ircServer.store.View(func(tx *buntdb.Tx) error {
|
||||
tx.AscendKeys("account.exists *", func(key, value string) bool {
|
||||
key = key[len("account.exists "):]
|
||||
_, err := tx.Get(fmt.Sprintf(keyAccountVerified, key))
|
||||
@ -118,7 +118,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
regTime := time.Unix(regTimeInt, 0)
|
||||
|
||||
var clients int
|
||||
acct := restAPIServer.accounts[key]
|
||||
acct := ircServer.accounts[key]
|
||||
if acct != nil {
|
||||
clients = len(acct.Clients)
|
||||
}
|
||||
@ -148,7 +148,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func restRehash(w http.ResponseWriter, r *http.Request) {
|
||||
err := restAPIServer.rehash()
|
||||
err := ircServer.rehash()
|
||||
|
||||
rs := restRehashResp{
|
||||
Successful: err == nil,
|
||||
@ -166,9 +166,9 @@ func restRehash(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) startRestAPI() {
|
||||
func StartRestAPI(s *Server, listenAddr string) (*http.Server, error) {
|
||||
// so handlers can ref it later
|
||||
restAPIServer = s
|
||||
ircServer = s
|
||||
|
||||
// start router
|
||||
r := mux.NewRouter()
|
||||
@ -185,5 +185,16 @@ func (s *Server) startRestAPI() {
|
||||
rp.HandleFunc("/rehash", restRehash)
|
||||
|
||||
// start api
|
||||
go http.ListenAndServe(s.restAPI.Listen, r)
|
||||
httpserver := http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: r,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := httpserver.ListenAndServe(); err != nil {
|
||||
s.logger.Error("listeners", fmt.Sprintf("Rest API listenAndServe error: %s", err))
|
||||
}
|
||||
}()
|
||||
|
||||
return &httpserver, nil
|
||||
}
|
||||
|
448
irc/server.go
448
irc/server.go
@ -7,9 +7,9 @@ package irc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
@ -35,8 +35,14 @@ var (
|
||||
tooManyClientsMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Too many clients from your network")}[0]).Line()
|
||||
couldNotParseIPMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Unable to parse your IP address")}[0]).Line()
|
||||
bannedFromServerMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "You are banned from this server (%s)")}[0]).Line()
|
||||
)
|
||||
|
||||
errDbOutOfDate = errors.New("Database schema is old")
|
||||
const (
|
||||
// when shutting down the REST and websocket servers, wait this long
|
||||
// before killing active non-WS connections. TODO: this might not be
|
||||
// necessary at all? but it seems prudent to avoid potential resource
|
||||
// leaks
|
||||
httpShutdownTimeout = time.Second
|
||||
)
|
||||
|
||||
// Limits holds the maximum limits for various things such as topic lengths.
|
||||
@ -80,6 +86,7 @@ type Server struct {
|
||||
clients *ClientLookupSet
|
||||
commands chan Command
|
||||
configFilename string
|
||||
configurableStateMutex sync.RWMutex // generic protection for server state modified by rehash()
|
||||
connectionLimits *ConnectionLimits
|
||||
connectionLimitsMutex sync.Mutex // used when affecting the connection limiter, to make sure rehashing doesn't make things go out-of-whack
|
||||
connectionThrottle *ConnectionThrottle
|
||||
@ -109,13 +116,15 @@ type Server struct {
|
||||
registeredChannelsMutex sync.RWMutex
|
||||
rehashMutex sync.Mutex
|
||||
rehashSignal chan os.Signal
|
||||
restAPI *RestAPIConfig
|
||||
restAPI RestAPIConfig
|
||||
restAPIServer *http.Server
|
||||
proxyAllowedFrom []string
|
||||
signals chan os.Signal
|
||||
snomasks *SnoManager
|
||||
store *buntdb.DB
|
||||
stsEnabled bool
|
||||
whoWas *WhoWasList
|
||||
wsServer *http.Server
|
||||
}
|
||||
|
||||
var (
|
||||
@ -133,216 +142,38 @@ type clientConn struct {
|
||||
}
|
||||
|
||||
// NewServer returns a new Oragono server.
|
||||
func NewServer(configFilename string, config *Config, logger *logger.Manager) (*Server, error) {
|
||||
casefoldedName, err := Casefold(config.Server.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Server name isn't valid [%s]: %s", config.Server.Name, err.Error())
|
||||
}
|
||||
|
||||
// startup check that we have HELP entries for every command
|
||||
for name := range Commands {
|
||||
_, exists := Help[strings.ToLower(name)]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("Help entry does not exist for command %s", name)
|
||||
}
|
||||
}
|
||||
// generate help indexes
|
||||
HelpIndex = GenerateHelpIndex(false)
|
||||
HelpIndexOpers = GenerateHelpIndex(true)
|
||||
|
||||
if config.Accounts.AuthenticationEnabled {
|
||||
SupportedCapabilities[SASL] = true
|
||||
}
|
||||
|
||||
if config.Server.STS.Enabled {
|
||||
SupportedCapabilities[STS] = true
|
||||
CapValues[STS] = config.Server.STS.Value()
|
||||
}
|
||||
|
||||
if config.Limits.LineLen.Tags > 512 || config.Limits.LineLen.Rest > 512 {
|
||||
SupportedCapabilities[MaxLine] = true
|
||||
CapValues[MaxLine] = fmt.Sprintf("%d,%d", config.Limits.LineLen.Tags, config.Limits.LineLen.Rest)
|
||||
}
|
||||
|
||||
operClasses, err := config.OperatorClasses()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error loading oper classes: %s", err.Error())
|
||||
}
|
||||
opers, err := config.Operators(operClasses)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error loading operators: %s", err.Error())
|
||||
}
|
||||
|
||||
connectionLimits, err := NewConnectionLimits(config.Server.ConnectionLimits)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error loading connection limits: %s", err.Error())
|
||||
}
|
||||
connectionThrottle, err := NewConnectionThrottle(config.Server.ConnectionThrottle)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error loading connection throttler: %s", err.Error())
|
||||
func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
|
||||
// TODO move this to main?
|
||||
if err := GenerateHelpIndices(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// initialize data structures
|
||||
server := &Server{
|
||||
accountAuthenticationEnabled: config.Accounts.AuthenticationEnabled,
|
||||
accounts: make(map[string]*ClientAccount),
|
||||
channelRegistrationEnabled: config.Channels.Registration.Enabled,
|
||||
channels: *NewChannelNameMap(),
|
||||
checkIdent: config.Server.CheckIdent,
|
||||
clients: NewClientLookupSet(),
|
||||
commands: make(chan Command),
|
||||
configFilename: configFilename,
|
||||
connectionLimits: connectionLimits,
|
||||
connectionThrottle: connectionThrottle,
|
||||
ctime: time.Now(),
|
||||
currentOpers: make(map[*Client]bool),
|
||||
defaultChannelModes: ParseDefaultChannelModes(config),
|
||||
limits: Limits{
|
||||
AwayLen: int(config.Limits.AwayLen),
|
||||
ChannelLen: int(config.Limits.ChannelLen),
|
||||
KickLen: int(config.Limits.KickLen),
|
||||
MonitorEntries: int(config.Limits.MonitorEntries),
|
||||
NickLen: int(config.Limits.NickLen),
|
||||
TopicLen: int(config.Limits.TopicLen),
|
||||
ChanListModes: int(config.Limits.ChanListModes),
|
||||
LineLen: LineLenLimits{
|
||||
Tags: config.Limits.LineLen.Tags,
|
||||
Rest: config.Limits.LineLen.Rest,
|
||||
},
|
||||
},
|
||||
listeners: make(map[string]*ListenerWrapper),
|
||||
logger: logger,
|
||||
MaxSendQBytes: config.Server.MaxSendQBytes,
|
||||
monitoring: make(map[string][]*Client),
|
||||
name: config.Server.Name,
|
||||
nameCasefolded: casefoldedName,
|
||||
networkName: config.Network.Name,
|
||||
newConns: make(chan clientConn),
|
||||
operators: opers,
|
||||
operclasses: *operClasses,
|
||||
proxyAllowedFrom: config.Server.ProxyAllowedFrom,
|
||||
registeredChannels: make(map[string]*RegisteredChannel),
|
||||
rehashSignal: make(chan os.Signal, 1),
|
||||
restAPI: &config.Server.RestAPI,
|
||||
signals: make(chan os.Signal, len(ServerExitSignals)),
|
||||
snomasks: NewSnoManager(),
|
||||
stsEnabled: config.Server.STS.Enabled,
|
||||
whoWas: NewWhoWasList(config.Limits.WhowasEntries),
|
||||
}
|
||||
|
||||
// open data store
|
||||
server.logger.Debug("startup", "Opening datastore")
|
||||
db, err := buntdb.Open(config.Datastore.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to open datastore: %s", err.Error())
|
||||
if err := server.applyConfig(config, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.store = db
|
||||
|
||||
// check db version
|
||||
err = server.store.View(func(tx *buntdb.Tx) error {
|
||||
version, _ := tx.Get(keySchemaVersion)
|
||||
if version != latestDbSchema {
|
||||
logger.Error("startup", "server", fmt.Sprintf("Database must be updated. Expected schema v%s, got v%s.", latestDbSchema, version))
|
||||
return errDbOutOfDate
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// close the db
|
||||
db.Close()
|
||||
return nil, errDbOutOfDate
|
||||
}
|
||||
|
||||
// load *lines
|
||||
server.logger.Debug("startup", "Loading D/Klines")
|
||||
server.loadDLines()
|
||||
server.loadKLines()
|
||||
|
||||
// load password manager
|
||||
server.logger.Debug("startup", "Loading passwords")
|
||||
err = server.store.View(func(tx *buntdb.Tx) error {
|
||||
saltString, err := tx.Get(keySalt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve salt string: %s", err.Error())
|
||||
}
|
||||
|
||||
salt, err := base64.StdEncoding.DecodeString(saltString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pwm := NewPasswordManager(salt)
|
||||
server.passwords = &pwm
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not load salt: %s", err.Error())
|
||||
}
|
||||
|
||||
server.logger.Debug("startup", "Loading MOTD")
|
||||
if config.Server.MOTD != "" {
|
||||
file, err := os.Open(config.Server.MOTD)
|
||||
if err == nil {
|
||||
defer file.Close()
|
||||
|
||||
reader := bufio.NewReader(file)
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
// "- " is the required prefix for MOTD, we just add it here to make
|
||||
// bursting it out to clients easier
|
||||
line = fmt.Sprintf("- %s", line)
|
||||
|
||||
server.motdLines = append(server.motdLines, line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.Server.Password != "" {
|
||||
server.password = config.Server.PasswordBytes()
|
||||
}
|
||||
|
||||
tlsListeners := config.TLSListeners()
|
||||
for _, addr := range config.Server.Listen {
|
||||
server.listeners[addr] = server.createListener(addr, tlsListeners[addr])
|
||||
}
|
||||
|
||||
if len(tlsListeners) == 0 {
|
||||
server.logger.Warning("startup", "You are not exposing an SSL/TLS listening port. You should expose at least one port (typically 6697) to accept TLS connections")
|
||||
}
|
||||
var usesStandardTLSPort bool
|
||||
for addr := range config.TLSListeners() {
|
||||
if strings.Contains(addr, "6697") {
|
||||
usesStandardTLSPort = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if 0 < len(tlsListeners) && !usesStandardTLSPort {
|
||||
server.logger.Warning("startup", "Port 6697 is the standard TLS port for IRC. You should (also) expose port 6697 as a TLS port to ensure clients can connect securely")
|
||||
}
|
||||
|
||||
if config.Server.Wslisten != "" {
|
||||
server.wslisten(config.Server.Wslisten, config.Server.TLSListeners)
|
||||
}
|
||||
|
||||
// registration
|
||||
accountReg := NewAccountRegistration(config.Accounts.Registration)
|
||||
server.accountRegistration = &accountReg
|
||||
|
||||
// Attempt to clean up when receiving these signals.
|
||||
signal.Notify(server.signals, ServerExitSignals...)
|
||||
signal.Notify(server.rehashSignal, syscall.SIGHUP)
|
||||
|
||||
server.setISupport()
|
||||
|
||||
// start API if enabled
|
||||
if server.restAPI.Enabled {
|
||||
logger.Info("startup", "server", fmt.Sprintf("%s rest API started on %s.", server.name, server.restAPI.Listen))
|
||||
server.startRestAPI()
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
@ -514,13 +345,6 @@ func (server *Server) createListener(addr string, tlsConfig *tls.Config) *Listen
|
||||
stopEvent: make(chan bool, 1),
|
||||
}
|
||||
|
||||
// TODO(slingamn) move all logging of listener status to rehash()
|
||||
tlsString := "plaintext"
|
||||
if tlsConfig != nil {
|
||||
tlsString = "TLS"
|
||||
}
|
||||
server.logger.Info("listeners", fmt.Sprintf("listening on %s using %s.", addr, tlsString))
|
||||
|
||||
var shouldStop bool
|
||||
|
||||
// setup accept goroutine
|
||||
@ -562,8 +386,29 @@ func (server *Server) createListener(addr string, tlsConfig *tls.Config) *Listen
|
||||
// websocket listen goroutine
|
||||
//
|
||||
|
||||
func (server *Server) wslisten(addr string, tlsMap map[string]*TLSListenConfig) {
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
func (server *Server) setupWSListener(config *Config) {
|
||||
// unconditionally shut down the old listener because we can't tell
|
||||
// whether we need to reload the TLS certificate
|
||||
if server.wsServer != nil {
|
||||
ctx, _ := context.WithTimeout(context.Background(), httpShutdownTimeout)
|
||||
server.wsServer.Shutdown(ctx)
|
||||
server.wsServer.Close()
|
||||
}
|
||||
|
||||
if config.Server.Wslisten == "" {
|
||||
server.wsServer = nil
|
||||
return
|
||||
}
|
||||
|
||||
addr := config.Server.Wslisten
|
||||
tlsConfig := config.Server.TLSListeners[addr]
|
||||
handler := http.NewServeMux()
|
||||
wsServer := http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
handler.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
server.logger.Error("ws", addr, fmt.Sprintf("%s method not allowed", r.Method))
|
||||
return
|
||||
@ -584,29 +429,26 @@ func (server *Server) wslisten(addr string, tlsMap map[string]*TLSListenConfig)
|
||||
|
||||
newConn := clientConn{
|
||||
Conn: WSContainer{ws},
|
||||
IsTLS: false, //TODO(dan): track TLS or not here properly
|
||||
IsTLS: tlsConfig != nil,
|
||||
}
|
||||
server.newConns <- newConn
|
||||
})
|
||||
|
||||
go func() {
|
||||
config, listenTLS := tlsMap[addr]
|
||||
|
||||
tlsString := "plaintext"
|
||||
var err error
|
||||
if listenTLS {
|
||||
tlsString = "TLS"
|
||||
}
|
||||
server.logger.Info("listeners", fmt.Sprintf("websocket listening on %s using %s.", addr, tlsString))
|
||||
server.logger.Info("listeners", fmt.Sprintf("websocket listening on %s, tls=%t.", addr, tlsConfig != nil))
|
||||
|
||||
if listenTLS {
|
||||
err = http.ListenAndServeTLS(addr, config.Cert, config.Key, nil)
|
||||
if tlsConfig != nil {
|
||||
err = wsServer.ListenAndServeTLS(tlsConfig.Cert, tlsConfig.Key)
|
||||
} else {
|
||||
err = http.ListenAndServe(addr, nil)
|
||||
err = wsServer.ListenAndServe()
|
||||
}
|
||||
if err != nil {
|
||||
server.logger.Error("listeners", fmt.Sprintf("listenAndServe error [%s]: %s", tlsString, err))
|
||||
server.logger.Error("listeners", fmt.Sprintf("websocket ListenAndServe error: %s", err))
|
||||
}
|
||||
}()
|
||||
|
||||
server.wsServer = &wsServer
|
||||
}
|
||||
|
||||
// generateMessageID returns a network-unique message ID.
|
||||
@ -660,6 +502,9 @@ func (server *Server) tryRegister(c *Client) {
|
||||
|
||||
// MOTD serves the Message of the Day.
|
||||
func (server *Server) MOTD(client *Client) {
|
||||
server.configurableStateMutex.RLock()
|
||||
defer server.configurableStateMutex.RUnlock()
|
||||
|
||||
if len(server.motdLines) < 1 {
|
||||
client.Send(nil, server.name, ERR_NOMOTD, client.nick, "MOTD File is missing")
|
||||
return
|
||||
@ -1415,15 +1260,35 @@ func (server *Server) rehash() error {
|
||||
server.logger.Debug("rehash", "Got rehash lock")
|
||||
|
||||
config, err := LoadConfig(server.configFilename)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error rehashing config file config: %s", err.Error())
|
||||
return fmt.Errorf("Error loading config file config: %s", err.Error())
|
||||
}
|
||||
|
||||
err = server.applyConfig(config, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error applying config changes: %s", err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) applyConfig(config *Config, initial bool) error {
|
||||
if initial {
|
||||
server.ctime = time.Now()
|
||||
server.configFilename = config.Filename
|
||||
}
|
||||
|
||||
casefoldedName, err := Casefold(config.Server.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Server name isn't valid [%s]: %s", config.Server.Name, err.Error())
|
||||
}
|
||||
|
||||
if !initial {
|
||||
// line lengths cannot be changed after launching the server
|
||||
if server.limits.LineLen.Tags != config.Limits.LineLen.Tags || server.limits.LineLen.Rest != config.Limits.LineLen.Rest {
|
||||
return fmt.Errorf("Maximum line length (linelen) cannot be changed after launching the server, rehash aborted")
|
||||
}
|
||||
}
|
||||
|
||||
// confirm connectionLimits are fine
|
||||
connectionLimits, err := NewConnectionLimits(config.Server.ConnectionLimits)
|
||||
@ -1453,6 +1318,18 @@ func (server *Server) rehash() error {
|
||||
}
|
||||
}
|
||||
|
||||
// sanity checks complete, start modifying server state
|
||||
|
||||
server.name = config.Server.Name
|
||||
server.nameCasefolded = casefoldedName
|
||||
server.networkName = config.Network.Name
|
||||
|
||||
if config.Server.Password != "" {
|
||||
server.password = config.Server.PasswordBytes()
|
||||
} else {
|
||||
server.password = nil
|
||||
}
|
||||
|
||||
// apply new connectionlimits
|
||||
server.connectionLimitsMutex.Lock()
|
||||
server.connectionLimits = connectionLimits
|
||||
@ -1578,7 +1455,9 @@ func (server *Server) rehash() error {
|
||||
server.accountRegistration = &accountReg
|
||||
server.channelRegistrationEnabled = config.Channels.Registration.Enabled
|
||||
|
||||
server.configurableStateMutex.Lock()
|
||||
server.defaultChannelModes = ParseDefaultChannelModes(config)
|
||||
server.configurableStateMutex.Unlock()
|
||||
|
||||
// set new sendqueue size
|
||||
if config.Server.MaxSendQBytes != server.MaxSendQBytes {
|
||||
@ -1595,8 +1474,8 @@ func (server *Server) rehash() error {
|
||||
// set RPL_ISUPPORT
|
||||
oldISupportList := server.isupport
|
||||
server.setISupport()
|
||||
if oldISupportList != nil {
|
||||
newISupportReplies := oldISupportList.GetDifference(server.isupport)
|
||||
|
||||
// push new info to all of our clients
|
||||
server.clients.ByNickMutex.RLock()
|
||||
for _, sClient := range server.clients.ByNick {
|
||||
@ -1606,7 +1485,97 @@ func (server *Server) rehash() error {
|
||||
}
|
||||
}
|
||||
server.clients.ByNickMutex.RUnlock()
|
||||
}
|
||||
|
||||
server.loadMOTD(config.Server.MOTD)
|
||||
|
||||
if initial {
|
||||
if err := server.loadDatastore(config.Datastore.Path); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// we are now open for business
|
||||
server.setupListeners(config)
|
||||
server.setupWSListener(config)
|
||||
server.setupRestAPI(config)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) loadMOTD(motdPath string) error {
|
||||
server.logger.Debug("rehash", "Loading MOTD")
|
||||
motdLines := make([]string, 0)
|
||||
if motdPath != "" {
|
||||
file, err := os.Open(motdPath)
|
||||
if err == nil {
|
||||
defer file.Close()
|
||||
|
||||
reader := bufio.NewReader(file)
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
line = strings.TrimRight(line, "\r\n")
|
||||
// "- " is the required prefix for MOTD, we just add it here to make
|
||||
// bursting it out to clients easier
|
||||
line = fmt.Sprintf("- %s", line)
|
||||
|
||||
motdLines = append(motdLines, line)
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
server.configurableStateMutex.Lock()
|
||||
defer server.configurableStateMutex.Unlock()
|
||||
server.motdLines = motdLines
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) loadDatastore(datastorePath string) error {
|
||||
// open the datastore and load server state for which it (rather than config)
|
||||
// is the source of truth
|
||||
|
||||
server.logger.Debug("startup", "Opening datastore")
|
||||
db, err := OpenDatabase(datastorePath)
|
||||
if err == nil {
|
||||
server.store = db
|
||||
} else {
|
||||
return fmt.Errorf("Failed to open datastore: %s", err.Error())
|
||||
}
|
||||
|
||||
// load *lines (from the datastores)
|
||||
server.logger.Debug("startup", "Loading D/Klines")
|
||||
server.loadDLines()
|
||||
server.loadKLines()
|
||||
|
||||
// load password manager
|
||||
server.logger.Debug("startup", "Loading passwords")
|
||||
err = server.store.View(func(tx *buntdb.Tx) error {
|
||||
saltString, err := tx.Get(keySalt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve salt string: %s", err.Error())
|
||||
}
|
||||
|
||||
salt, err := base64.StdEncoding.DecodeString(saltString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pwm := NewPasswordManager(salt)
|
||||
server.passwords = &pwm
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not load salt: %s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) setupListeners(config *Config) {
|
||||
// update or destroy all existing listeners
|
||||
tlsListeners := config.TLSListeners()
|
||||
for addr := range server.listeners {
|
||||
@ -1629,18 +1598,16 @@ func (server *Server) rehash() error {
|
||||
currentListener.configMutex.Unlock()
|
||||
|
||||
if stillConfigured {
|
||||
server.logger.Info("rehash",
|
||||
server.logger.Info("listeners",
|
||||
fmt.Sprintf("now listening on %s, tls=%t.", addr, (currentListener.tlsConfig != nil)),
|
||||
)
|
||||
} else {
|
||||
// tell the listener it should stop by interrupting its Accept() call:
|
||||
currentListener.listener.Close()
|
||||
// XXX there is no guarantee from the API when the address will actually
|
||||
// free for bind(2) again; this solution "seems to work". See here:
|
||||
// https://github.com/golang/go/issues/21833
|
||||
// TODO(golang1.10) delete stopEvent once issue #21856 is released
|
||||
<-currentListener.stopEvent
|
||||
delete(server.listeners, addr)
|
||||
server.logger.Info("rehash", fmt.Sprintf("stopped listening on %s.", addr))
|
||||
server.logger.Info("listeners", fmt.Sprintf("stopped listening on %s.", addr))
|
||||
}
|
||||
}
|
||||
|
||||
@ -1653,7 +1620,52 @@ func (server *Server) rehash() error {
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
if len(tlsListeners) == 0 {
|
||||
server.logger.Warning("startup", "You are not exposing an SSL/TLS listening port. You should expose at least one port (typically 6697) to accept TLS connections")
|
||||
}
|
||||
|
||||
var usesStandardTLSPort bool
|
||||
for addr := range config.TLSListeners() {
|
||||
if strings.Contains(addr, "6697") {
|
||||
usesStandardTLSPort = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if 0 < len(tlsListeners) && !usesStandardTLSPort {
|
||||
server.logger.Warning("startup", "Port 6697 is the standard TLS port for IRC. You should (also) expose port 6697 as a TLS port to ensure clients can connect securely")
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) setupRestAPI(config *Config) {
|
||||
restAPIEnabled := config.Server.RestAPI.Enabled
|
||||
restAPIStarted := server.restAPIServer != nil
|
||||
restAPIListenAddrChanged := server.restAPI.Listen != config.Server.RestAPI.Listen
|
||||
|
||||
// stop an existing REST server if it's been disabled or the addr changed
|
||||
if restAPIStarted && (!restAPIEnabled || restAPIListenAddrChanged) {
|
||||
ctx, _ := context.WithTimeout(context.Background(), httpShutdownTimeout)
|
||||
server.restAPIServer.Shutdown(ctx)
|
||||
server.restAPIServer.Close()
|
||||
server.logger.Info("rehash", "server", fmt.Sprintf("%s rest API stopped on %s.", server.name, server.restAPI.Listen))
|
||||
server.restAPIServer = nil
|
||||
}
|
||||
|
||||
// start a new one if it's enabled or the addr changed
|
||||
if restAPIEnabled && (!restAPIStarted || restAPIListenAddrChanged) {
|
||||
server.restAPIServer, _ = StartRestAPI(server, config.Server.RestAPI.Listen)
|
||||
server.logger.Info(
|
||||
"rehash", "server",
|
||||
fmt.Sprintf("%s rest API started on %s.", server.name, config.Server.RestAPI.Listen))
|
||||
}
|
||||
|
||||
// save the config information
|
||||
server.restAPI = config.Server.RestAPI
|
||||
}
|
||||
|
||||
func (server *Server) GetDefaultChannelModes() Modes {
|
||||
server.configurableStateMutex.RLock()
|
||||
defer server.configurableStateMutex.RUnlock()
|
||||
return server.defaultChannelModes
|
||||
}
|
||||
|
||||
// REHASH
|
||||
|
@ -132,7 +132,7 @@ Options:
|
||||
logger.Warning("startup", "You are currently running an unreleased beta version of Oragono that may be unstable and could corrupt your database.\nIf you are running a production network, please download the latest build from https://oragono.io/downloads.html and run that instead.")
|
||||
}
|
||||
|
||||
server, err := irc.NewServer(configfile, config, logger)
|
||||
server, err := irc.NewServer(config, logger)
|
||||
if err != nil {
|
||||
logger.Error("startup", fmt.Sprintf("Could not load server: %s", err.Error()))
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user