3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-26 13:59:44 +01:00

refactor the rehash implementation

This commit is contained in:
Shivaram Lingamneni 2017-09-28 01:30:53 -04:00
parent eae04e8c51
commit e8b1870067
7 changed files with 324 additions and 252 deletions

View File

@ -57,7 +57,7 @@ func NewChannel(s *Server, name string, addDefaultModes bool) *Channel {
} }
if addDefaultModes { if addDefaultModes {
for _, mode := range s.defaultChannelModes { for _, mode := range s.GetDefaultChannelModes() {
channel.flags[mode] = true channel.flags[mode] = true
} }
} }

View File

@ -244,6 +244,8 @@ type Config struct {
WhowasEntries uint `yaml:"whowas-entries"` WhowasEntries uint `yaml:"whowas-entries"`
LineLen LineLenConfig `yaml:"linelen"` LineLen LineLenConfig `yaml:"linelen"`
} }
Filename string
} }
// OperClass defines an assembled operator class. // OperClass defines an assembled operator class.
@ -390,6 +392,8 @@ func LoadConfig(filename string) (config *Config, err error) {
return nil, err return nil, err
} }
config.Filename = filename
// we need this so PasswordBytes returns the correct info // we need this so PasswordBytes returns the correct info
if config.Server.Password != "" { if config.Server.Password != "" {
config.Server.PassConfig.Password = config.Server.Password config.Server.PassConfig.Password = config.Server.Password

View File

@ -53,6 +53,32 @@ func InitDB(path string) {
} }
} }
// open an existing database, performing a schema version check
func OpenDatabase(path string) (*buntdb.DB, error) {
// open data store
db, err := buntdb.Open(path)
if err != nil {
return nil, err
}
// check db version
err = db.View(func(tx *buntdb.Tx) error {
version, _ := tx.Get(keySchemaVersion)
if version != latestDbSchema {
return fmt.Errorf("Database must be updated. Expected schema v%s, got v%s.", latestDbSchema, version)
}
return nil
})
if err != nil {
// close the db
db.Close()
return nil, err
}
return db, nil
}
// UpgradeDB upgrades the datastore to the latest schema. // UpgradeDB upgrades the datastore to the latest schema.
func UpgradeDB(path string) { func UpgradeDB(path string) {
store, err := buntdb.Open(path) store, err := buntdb.Open(path)

View File

@ -530,10 +530,10 @@ Oragono supports the following channel membership prefixes:
} }
// HelpIndex contains the list of all help topics for regular users. // HelpIndex contains the list of all help topics for regular users.
var HelpIndex = "list of all help topics for regular users" var HelpIndex string
// HelpIndexOpers contains the list of all help topics for opers. // HelpIndexOpers contains the list of all help topics for opers.
var HelpIndexOpers = "list of all help topics for opers" var HelpIndexOpers string
// GenerateHelpIndex is used to generate HelpIndex. // GenerateHelpIndex is used to generate HelpIndex.
func GenerateHelpIndex(forOpers bool) string { func GenerateHelpIndex(forOpers bool) string {
@ -582,6 +582,25 @@ Information:
return newHelpIndex return newHelpIndex
} }
func GenerateHelpIndices() error {
if HelpIndex != "" && HelpIndexOpers != "" {
return nil
}
// startup check that we have HELP entries for every command
for name := range Commands {
_, exists := Help[strings.ToLower(name)]
if !exists {
return fmt.Errorf("Help entry does not exist for command %s", name)
}
}
// generate help indexes
HelpIndex = GenerateHelpIndex(false)
HelpIndexOpers = GenerateHelpIndex(true)
return nil
}
// sendHelp sends the client help of the given string. // sendHelp sends the client help of the given string.
func (client *Client) sendHelp(name string, text string) { func (client *Client) sendHelp(name string, text string) {
splitName := strings.Split(name, " ") splitName := strings.Split(name, " ")

View File

@ -19,9 +19,9 @@ import (
const restErr = "{\"error\":\"An unknown error occurred\"}" const restErr = "{\"error\":\"An unknown error occurred\"}"
// restAPIServer is used to keep a link to the current running server since this is the best // ircServer is used to keep a link to the current running server since this is the best
// way to do it, given how HTTP handlers dispatch and work. // way to do it, given how HTTP handlers dispatch and work.
var restAPIServer *Server var ircServer *Server
type restInfoResp struct { type restInfoResp struct {
ServerName string `json:"server-name"` ServerName string `json:"server-name"`
@ -60,8 +60,8 @@ type restRehashResp struct {
func restInfo(w http.ResponseWriter, r *http.Request) { func restInfo(w http.ResponseWriter, r *http.Request) {
rs := restInfoResp{ rs := restInfoResp{
Version: SemVer, Version: SemVer,
ServerName: restAPIServer.name, ServerName: ircServer.name,
NetworkName: restAPIServer.networkName, NetworkName: ircServer.networkName,
} }
b, err := json.Marshal(rs) b, err := json.Marshal(rs)
if err != nil { if err != nil {
@ -73,9 +73,9 @@ func restInfo(w http.ResponseWriter, r *http.Request) {
func restStatus(w http.ResponseWriter, r *http.Request) { func restStatus(w http.ResponseWriter, r *http.Request) {
rs := restStatusResp{ rs := restStatusResp{
Clients: restAPIServer.clients.Count(), Clients: ircServer.clients.Count(),
Opers: len(restAPIServer.operators), Opers: len(ircServer.operators),
Channels: restAPIServer.channels.Len(), Channels: ircServer.channels.Len(),
} }
b, err := json.Marshal(rs) b, err := json.Marshal(rs)
if err != nil { if err != nil {
@ -87,8 +87,8 @@ func restStatus(w http.ResponseWriter, r *http.Request) {
func restGetXLines(w http.ResponseWriter, r *http.Request) { func restGetXLines(w http.ResponseWriter, r *http.Request) {
rs := restXLinesResp{ rs := restXLinesResp{
DLines: restAPIServer.dlines.AllBans(), DLines: ircServer.dlines.AllBans(),
KLines: restAPIServer.klines.AllBans(), KLines: ircServer.klines.AllBans(),
} }
b, err := json.Marshal(rs) b, err := json.Marshal(rs)
if err != nil { if err != nil {
@ -104,7 +104,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
} }
// get accounts // get accounts
err := restAPIServer.store.View(func(tx *buntdb.Tx) error { err := ircServer.store.View(func(tx *buntdb.Tx) error {
tx.AscendKeys("account.exists *", func(key, value string) bool { tx.AscendKeys("account.exists *", func(key, value string) bool {
key = key[len("account.exists "):] key = key[len("account.exists "):]
_, err := tx.Get(fmt.Sprintf(keyAccountVerified, key)) _, err := tx.Get(fmt.Sprintf(keyAccountVerified, key))
@ -118,7 +118,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
regTime := time.Unix(regTimeInt, 0) regTime := time.Unix(regTimeInt, 0)
var clients int var clients int
acct := restAPIServer.accounts[key] acct := ircServer.accounts[key]
if acct != nil { if acct != nil {
clients = len(acct.Clients) clients = len(acct.Clients)
} }
@ -148,7 +148,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
} }
func restRehash(w http.ResponseWriter, r *http.Request) { func restRehash(w http.ResponseWriter, r *http.Request) {
err := restAPIServer.rehash() err := ircServer.rehash()
rs := restRehashResp{ rs := restRehashResp{
Successful: err == nil, Successful: err == nil,
@ -166,9 +166,9 @@ func restRehash(w http.ResponseWriter, r *http.Request) {
} }
} }
func (s *Server) startRestAPI() { func StartRestAPI(s *Server, listenAddr string) (*http.Server, error) {
// so handlers can ref it later // so handlers can ref it later
restAPIServer = s ircServer = s
// start router // start router
r := mux.NewRouter() r := mux.NewRouter()
@ -185,5 +185,16 @@ func (s *Server) startRestAPI() {
rp.HandleFunc("/rehash", restRehash) rp.HandleFunc("/rehash", restRehash)
// start api // start api
go http.ListenAndServe(s.restAPI.Listen, r) httpserver := http.Server{
Addr: listenAddr,
Handler: r,
}
go func() {
if err := httpserver.ListenAndServe(); err != nil {
s.logger.Error("listeners", fmt.Sprintf("Rest API listenAndServe error: %s", err))
}
}()
return &httpserver, nil
} }

View File

@ -7,9 +7,9 @@ package irc
import ( import (
"bufio" "bufio"
"context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"log" "log"
"math/rand" "math/rand"
@ -35,8 +35,14 @@ var (
tooManyClientsMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Too many clients from your network")}[0]).Line() tooManyClientsMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Too many clients from your network")}[0]).Line()
couldNotParseIPMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Unable to parse your IP address")}[0]).Line() couldNotParseIPMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Unable to parse your IP address")}[0]).Line()
bannedFromServerMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "You are banned from this server (%s)")}[0]).Line() bannedFromServerMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "You are banned from this server (%s)")}[0]).Line()
)
errDbOutOfDate = errors.New("Database schema is old") const (
// when shutting down the REST and websocket servers, wait this long
// before killing active non-WS connections. TODO: this might not be
// necessary at all? but it seems prudent to avoid potential resource
// leaks
httpShutdownTimeout = time.Second
) )
// Limits holds the maximum limits for various things such as topic lengths. // Limits holds the maximum limits for various things such as topic lengths.
@ -80,6 +86,7 @@ type Server struct {
clients *ClientLookupSet clients *ClientLookupSet
commands chan Command commands chan Command
configFilename string configFilename string
configurableStateMutex sync.RWMutex // generic protection for server state modified by rehash()
connectionLimits *ConnectionLimits connectionLimits *ConnectionLimits
connectionLimitsMutex sync.Mutex // used when affecting the connection limiter, to make sure rehashing doesn't make things go out-of-whack connectionLimitsMutex sync.Mutex // used when affecting the connection limiter, to make sure rehashing doesn't make things go out-of-whack
connectionThrottle *ConnectionThrottle connectionThrottle *ConnectionThrottle
@ -109,13 +116,15 @@ type Server struct {
registeredChannelsMutex sync.RWMutex registeredChannelsMutex sync.RWMutex
rehashMutex sync.Mutex rehashMutex sync.Mutex
rehashSignal chan os.Signal rehashSignal chan os.Signal
restAPI *RestAPIConfig restAPI RestAPIConfig
restAPIServer *http.Server
proxyAllowedFrom []string proxyAllowedFrom []string
signals chan os.Signal signals chan os.Signal
snomasks *SnoManager snomasks *SnoManager
store *buntdb.DB store *buntdb.DB
stsEnabled bool stsEnabled bool
whoWas *WhoWasList whoWas *WhoWasList
wsServer *http.Server
} }
var ( var (
@ -133,216 +142,38 @@ type clientConn struct {
} }
// NewServer returns a new Oragono server. // NewServer returns a new Oragono server.
func NewServer(configFilename string, config *Config, logger *logger.Manager) (*Server, error) { func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
casefoldedName, err := Casefold(config.Server.Name) // TODO move this to main?
if err != nil { if err := GenerateHelpIndices(); err != nil {
return nil, fmt.Errorf("Server name isn't valid [%s]: %s", config.Server.Name, err.Error()) return nil, err
}
// startup check that we have HELP entries for every command
for name := range Commands {
_, exists := Help[strings.ToLower(name)]
if !exists {
return nil, fmt.Errorf("Help entry does not exist for command %s", name)
}
}
// generate help indexes
HelpIndex = GenerateHelpIndex(false)
HelpIndexOpers = GenerateHelpIndex(true)
if config.Accounts.AuthenticationEnabled {
SupportedCapabilities[SASL] = true
}
if config.Server.STS.Enabled {
SupportedCapabilities[STS] = true
CapValues[STS] = config.Server.STS.Value()
}
if config.Limits.LineLen.Tags > 512 || config.Limits.LineLen.Rest > 512 {
SupportedCapabilities[MaxLine] = true
CapValues[MaxLine] = fmt.Sprintf("%d,%d", config.Limits.LineLen.Tags, config.Limits.LineLen.Rest)
}
operClasses, err := config.OperatorClasses()
if err != nil {
return nil, fmt.Errorf("Error loading oper classes: %s", err.Error())
}
opers, err := config.Operators(operClasses)
if err != nil {
return nil, fmt.Errorf("Error loading operators: %s", err.Error())
}
connectionLimits, err := NewConnectionLimits(config.Server.ConnectionLimits)
if err != nil {
return nil, fmt.Errorf("Error loading connection limits: %s", err.Error())
}
connectionThrottle, err := NewConnectionThrottle(config.Server.ConnectionThrottle)
if err != nil {
return nil, fmt.Errorf("Error loading connection throttler: %s", err.Error())
} }
// initialize data structures
server := &Server{ server := &Server{
accountAuthenticationEnabled: config.Accounts.AuthenticationEnabled,
accounts: make(map[string]*ClientAccount), accounts: make(map[string]*ClientAccount),
channelRegistrationEnabled: config.Channels.Registration.Enabled,
channels: *NewChannelNameMap(), channels: *NewChannelNameMap(),
checkIdent: config.Server.CheckIdent,
clients: NewClientLookupSet(), clients: NewClientLookupSet(),
commands: make(chan Command), commands: make(chan Command),
configFilename: configFilename,
connectionLimits: connectionLimits,
connectionThrottle: connectionThrottle,
ctime: time.Now(),
currentOpers: make(map[*Client]bool), currentOpers: make(map[*Client]bool),
defaultChannelModes: ParseDefaultChannelModes(config),
limits: Limits{
AwayLen: int(config.Limits.AwayLen),
ChannelLen: int(config.Limits.ChannelLen),
KickLen: int(config.Limits.KickLen),
MonitorEntries: int(config.Limits.MonitorEntries),
NickLen: int(config.Limits.NickLen),
TopicLen: int(config.Limits.TopicLen),
ChanListModes: int(config.Limits.ChanListModes),
LineLen: LineLenLimits{
Tags: config.Limits.LineLen.Tags,
Rest: config.Limits.LineLen.Rest,
},
},
listeners: make(map[string]*ListenerWrapper), listeners: make(map[string]*ListenerWrapper),
logger: logger, logger: logger,
MaxSendQBytes: config.Server.MaxSendQBytes,
monitoring: make(map[string][]*Client), monitoring: make(map[string][]*Client),
name: config.Server.Name,
nameCasefolded: casefoldedName,
networkName: config.Network.Name,
newConns: make(chan clientConn), newConns: make(chan clientConn),
operators: opers,
operclasses: *operClasses,
proxyAllowedFrom: config.Server.ProxyAllowedFrom,
registeredChannels: make(map[string]*RegisteredChannel), registeredChannels: make(map[string]*RegisteredChannel),
rehashSignal: make(chan os.Signal, 1), rehashSignal: make(chan os.Signal, 1),
restAPI: &config.Server.RestAPI,
signals: make(chan os.Signal, len(ServerExitSignals)), signals: make(chan os.Signal, len(ServerExitSignals)),
snomasks: NewSnoManager(), snomasks: NewSnoManager(),
stsEnabled: config.Server.STS.Enabled,
whoWas: NewWhoWasList(config.Limits.WhowasEntries), whoWas: NewWhoWasList(config.Limits.WhowasEntries),
} }
// open data store if err := server.applyConfig(config, true); err != nil {
server.logger.Debug("startup", "Opening datastore") return nil, err
db, err := buntdb.Open(config.Datastore.Path)
if err != nil {
return nil, fmt.Errorf("Failed to open datastore: %s", err.Error())
} }
server.store = db
// check db version
err = server.store.View(func(tx *buntdb.Tx) error {
version, _ := tx.Get(keySchemaVersion)
if version != latestDbSchema {
logger.Error("startup", "server", fmt.Sprintf("Database must be updated. Expected schema v%s, got v%s.", latestDbSchema, version))
return errDbOutOfDate
}
return nil
})
if err != nil {
// close the db
db.Close()
return nil, errDbOutOfDate
}
// load *lines
server.logger.Debug("startup", "Loading D/Klines")
server.loadDLines()
server.loadKLines()
// load password manager
server.logger.Debug("startup", "Loading passwords")
err = server.store.View(func(tx *buntdb.Tx) error {
saltString, err := tx.Get(keySalt)
if err != nil {
return fmt.Errorf("Could not retrieve salt string: %s", err.Error())
}
salt, err := base64.StdEncoding.DecodeString(saltString)
if err != nil {
return err
}
pwm := NewPasswordManager(salt)
server.passwords = &pwm
return nil
})
if err != nil {
return nil, fmt.Errorf("Could not load salt: %s", err.Error())
}
server.logger.Debug("startup", "Loading MOTD")
if config.Server.MOTD != "" {
file, err := os.Open(config.Server.MOTD)
if err == nil {
defer file.Close()
reader := bufio.NewReader(file)
for {
line, err := reader.ReadString('\n')
if err != nil {
break
}
line = strings.TrimRight(line, "\r\n")
// "- " is the required prefix for MOTD, we just add it here to make
// bursting it out to clients easier
line = fmt.Sprintf("- %s", line)
server.motdLines = append(server.motdLines, line)
}
}
}
if config.Server.Password != "" {
server.password = config.Server.PasswordBytes()
}
tlsListeners := config.TLSListeners()
for _, addr := range config.Server.Listen {
server.listeners[addr] = server.createListener(addr, tlsListeners[addr])
}
if len(tlsListeners) == 0 {
server.logger.Warning("startup", "You are not exposing an SSL/TLS listening port. You should expose at least one port (typically 6697) to accept TLS connections")
}
var usesStandardTLSPort bool
for addr := range config.TLSListeners() {
if strings.Contains(addr, "6697") {
usesStandardTLSPort = true
break
}
}
if 0 < len(tlsListeners) && !usesStandardTLSPort {
server.logger.Warning("startup", "Port 6697 is the standard TLS port for IRC. You should (also) expose port 6697 as a TLS port to ensure clients can connect securely")
}
if config.Server.Wslisten != "" {
server.wslisten(config.Server.Wslisten, config.Server.TLSListeners)
}
// registration
accountReg := NewAccountRegistration(config.Accounts.Registration)
server.accountRegistration = &accountReg
// Attempt to clean up when receiving these signals. // Attempt to clean up when receiving these signals.
signal.Notify(server.signals, ServerExitSignals...) signal.Notify(server.signals, ServerExitSignals...)
signal.Notify(server.rehashSignal, syscall.SIGHUP) signal.Notify(server.rehashSignal, syscall.SIGHUP)
server.setISupport()
// start API if enabled
if server.restAPI.Enabled {
logger.Info("startup", "server", fmt.Sprintf("%s rest API started on %s.", server.name, server.restAPI.Listen))
server.startRestAPI()
}
return server, nil return server, nil
} }
@ -514,13 +345,6 @@ func (server *Server) createListener(addr string, tlsConfig *tls.Config) *Listen
stopEvent: make(chan bool, 1), stopEvent: make(chan bool, 1),
} }
// TODO(slingamn) move all logging of listener status to rehash()
tlsString := "plaintext"
if tlsConfig != nil {
tlsString = "TLS"
}
server.logger.Info("listeners", fmt.Sprintf("listening on %s using %s.", addr, tlsString))
var shouldStop bool var shouldStop bool
// setup accept goroutine // setup accept goroutine
@ -562,8 +386,29 @@ func (server *Server) createListener(addr string, tlsConfig *tls.Config) *Listen
// websocket listen goroutine // websocket listen goroutine
// //
func (server *Server) wslisten(addr string, tlsMap map[string]*TLSListenConfig) { func (server *Server) setupWSListener(config *Config) {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { // unconditionally shut down the old listener because we can't tell
// whether we need to reload the TLS certificate
if server.wsServer != nil {
ctx, _ := context.WithTimeout(context.Background(), httpShutdownTimeout)
server.wsServer.Shutdown(ctx)
server.wsServer.Close()
}
if config.Server.Wslisten == "" {
server.wsServer = nil
return
}
addr := config.Server.Wslisten
tlsConfig := config.Server.TLSListeners[addr]
handler := http.NewServeMux()
wsServer := http.Server{
Addr: addr,
Handler: handler,
}
handler.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" { if r.Method != "GET" {
server.logger.Error("ws", addr, fmt.Sprintf("%s method not allowed", r.Method)) server.logger.Error("ws", addr, fmt.Sprintf("%s method not allowed", r.Method))
return return
@ -584,29 +429,26 @@ func (server *Server) wslisten(addr string, tlsMap map[string]*TLSListenConfig)
newConn := clientConn{ newConn := clientConn{
Conn: WSContainer{ws}, Conn: WSContainer{ws},
IsTLS: false, //TODO(dan): track TLS or not here properly IsTLS: tlsConfig != nil,
} }
server.newConns <- newConn server.newConns <- newConn
}) })
go func() { go func() {
config, listenTLS := tlsMap[addr]
tlsString := "plaintext"
var err error var err error
if listenTLS { server.logger.Info("listeners", fmt.Sprintf("websocket listening on %s, tls=%t.", addr, tlsConfig != nil))
tlsString = "TLS"
}
server.logger.Info("listeners", fmt.Sprintf("websocket listening on %s using %s.", addr, tlsString))
if listenTLS { if tlsConfig != nil {
err = http.ListenAndServeTLS(addr, config.Cert, config.Key, nil) err = wsServer.ListenAndServeTLS(tlsConfig.Cert, tlsConfig.Key)
} else { } else {
err = http.ListenAndServe(addr, nil) err = wsServer.ListenAndServe()
} }
if err != nil { if err != nil {
server.logger.Error("listeners", fmt.Sprintf("listenAndServe error [%s]: %s", tlsString, err)) server.logger.Error("listeners", fmt.Sprintf("websocket ListenAndServe error: %s", err))
} }
}() }()
server.wsServer = &wsServer
} }
// generateMessageID returns a network-unique message ID. // generateMessageID returns a network-unique message ID.
@ -660,6 +502,9 @@ func (server *Server) tryRegister(c *Client) {
// MOTD serves the Message of the Day. // MOTD serves the Message of the Day.
func (server *Server) MOTD(client *Client) { func (server *Server) MOTD(client *Client) {
server.configurableStateMutex.RLock()
defer server.configurableStateMutex.RUnlock()
if len(server.motdLines) < 1 { if len(server.motdLines) < 1 {
client.Send(nil, server.name, ERR_NOMOTD, client.nick, "MOTD File is missing") client.Send(nil, server.name, ERR_NOMOTD, client.nick, "MOTD File is missing")
return return
@ -1415,15 +1260,35 @@ func (server *Server) rehash() error {
server.logger.Debug("rehash", "Got rehash lock") server.logger.Debug("rehash", "Got rehash lock")
config, err := LoadConfig(server.configFilename) config, err := LoadConfig(server.configFilename)
if err != nil { if err != nil {
return fmt.Errorf("Error rehashing config file config: %s", err.Error()) return fmt.Errorf("Error loading config file config: %s", err.Error())
} }
err = server.applyConfig(config, false)
if err != nil {
return fmt.Errorf("Error applying config changes: %s", err.Error())
}
return nil
}
func (server *Server) applyConfig(config *Config, initial bool) error {
if initial {
server.ctime = time.Now()
server.configFilename = config.Filename
}
casefoldedName, err := Casefold(config.Server.Name)
if err != nil {
return fmt.Errorf("Server name isn't valid [%s]: %s", config.Server.Name, err.Error())
}
if !initial {
// line lengths cannot be changed after launching the server // line lengths cannot be changed after launching the server
if server.limits.LineLen.Tags != config.Limits.LineLen.Tags || server.limits.LineLen.Rest != config.Limits.LineLen.Rest { if server.limits.LineLen.Tags != config.Limits.LineLen.Tags || server.limits.LineLen.Rest != config.Limits.LineLen.Rest {
return fmt.Errorf("Maximum line length (linelen) cannot be changed after launching the server, rehash aborted") return fmt.Errorf("Maximum line length (linelen) cannot be changed after launching the server, rehash aborted")
} }
}
// confirm connectionLimits are fine // confirm connectionLimits are fine
connectionLimits, err := NewConnectionLimits(config.Server.ConnectionLimits) connectionLimits, err := NewConnectionLimits(config.Server.ConnectionLimits)
@ -1453,6 +1318,18 @@ func (server *Server) rehash() error {
} }
} }
// sanity checks complete, start modifying server state
server.name = config.Server.Name
server.nameCasefolded = casefoldedName
server.networkName = config.Network.Name
if config.Server.Password != "" {
server.password = config.Server.PasswordBytes()
} else {
server.password = nil
}
// apply new connectionlimits // apply new connectionlimits
server.connectionLimitsMutex.Lock() server.connectionLimitsMutex.Lock()
server.connectionLimits = connectionLimits server.connectionLimits = connectionLimits
@ -1578,7 +1455,9 @@ func (server *Server) rehash() error {
server.accountRegistration = &accountReg server.accountRegistration = &accountReg
server.channelRegistrationEnabled = config.Channels.Registration.Enabled server.channelRegistrationEnabled = config.Channels.Registration.Enabled
server.configurableStateMutex.Lock()
server.defaultChannelModes = ParseDefaultChannelModes(config) server.defaultChannelModes = ParseDefaultChannelModes(config)
server.configurableStateMutex.Unlock()
// set new sendqueue size // set new sendqueue size
if config.Server.MaxSendQBytes != server.MaxSendQBytes { if config.Server.MaxSendQBytes != server.MaxSendQBytes {
@ -1595,8 +1474,8 @@ func (server *Server) rehash() error {
// set RPL_ISUPPORT // set RPL_ISUPPORT
oldISupportList := server.isupport oldISupportList := server.isupport
server.setISupport() server.setISupport()
if oldISupportList != nil {
newISupportReplies := oldISupportList.GetDifference(server.isupport) newISupportReplies := oldISupportList.GetDifference(server.isupport)
// push new info to all of our clients // push new info to all of our clients
server.clients.ByNickMutex.RLock() server.clients.ByNickMutex.RLock()
for _, sClient := range server.clients.ByNick { for _, sClient := range server.clients.ByNick {
@ -1606,7 +1485,97 @@ func (server *Server) rehash() error {
} }
} }
server.clients.ByNickMutex.RUnlock() server.clients.ByNickMutex.RUnlock()
}
server.loadMOTD(config.Server.MOTD)
if initial {
if err := server.loadDatastore(config.Datastore.Path); err != nil {
return err
}
}
// we are now open for business
server.setupListeners(config)
server.setupWSListener(config)
server.setupRestAPI(config)
return nil
}
func (server *Server) loadMOTD(motdPath string) error {
server.logger.Debug("rehash", "Loading MOTD")
motdLines := make([]string, 0)
if motdPath != "" {
file, err := os.Open(motdPath)
if err == nil {
defer file.Close()
reader := bufio.NewReader(file)
for {
line, err := reader.ReadString('\n')
if err != nil {
break
}
line = strings.TrimRight(line, "\r\n")
// "- " is the required prefix for MOTD, we just add it here to make
// bursting it out to clients easier
line = fmt.Sprintf("- %s", line)
motdLines = append(motdLines, line)
}
} else {
return err
}
}
server.configurableStateMutex.Lock()
defer server.configurableStateMutex.Unlock()
server.motdLines = motdLines
return nil
}
func (server *Server) loadDatastore(datastorePath string) error {
// open the datastore and load server state for which it (rather than config)
// is the source of truth
server.logger.Debug("startup", "Opening datastore")
db, err := OpenDatabase(datastorePath)
if err == nil {
server.store = db
} else {
return fmt.Errorf("Failed to open datastore: %s", err.Error())
}
// load *lines (from the datastores)
server.logger.Debug("startup", "Loading D/Klines")
server.loadDLines()
server.loadKLines()
// load password manager
server.logger.Debug("startup", "Loading passwords")
err = server.store.View(func(tx *buntdb.Tx) error {
saltString, err := tx.Get(keySalt)
if err != nil {
return fmt.Errorf("Could not retrieve salt string: %s", err.Error())
}
salt, err := base64.StdEncoding.DecodeString(saltString)
if err != nil {
return err
}
pwm := NewPasswordManager(salt)
server.passwords = &pwm
return nil
})
if err != nil {
return fmt.Errorf("Could not load salt: %s", err.Error())
}
return nil
}
func (server *Server) setupListeners(config *Config) {
// update or destroy all existing listeners // update or destroy all existing listeners
tlsListeners := config.TLSListeners() tlsListeners := config.TLSListeners()
for addr := range server.listeners { for addr := range server.listeners {
@ -1629,18 +1598,16 @@ func (server *Server) rehash() error {
currentListener.configMutex.Unlock() currentListener.configMutex.Unlock()
if stillConfigured { if stillConfigured {
server.logger.Info("rehash", server.logger.Info("listeners",
fmt.Sprintf("now listening on %s, tls=%t.", addr, (currentListener.tlsConfig != nil)), fmt.Sprintf("now listening on %s, tls=%t.", addr, (currentListener.tlsConfig != nil)),
) )
} else { } else {
// tell the listener it should stop by interrupting its Accept() call: // tell the listener it should stop by interrupting its Accept() call:
currentListener.listener.Close() currentListener.listener.Close()
// XXX there is no guarantee from the API when the address will actually // TODO(golang1.10) delete stopEvent once issue #21856 is released
// free for bind(2) again; this solution "seems to work". See here:
// https://github.com/golang/go/issues/21833
<-currentListener.stopEvent <-currentListener.stopEvent
delete(server.listeners, addr) delete(server.listeners, addr)
server.logger.Info("rehash", fmt.Sprintf("stopped listening on %s.", addr)) server.logger.Info("listeners", fmt.Sprintf("stopped listening on %s.", addr))
} }
} }
@ -1653,7 +1620,52 @@ func (server *Server) rehash() error {
} }
} }
return nil if len(tlsListeners) == 0 {
server.logger.Warning("startup", "You are not exposing an SSL/TLS listening port. You should expose at least one port (typically 6697) to accept TLS connections")
}
var usesStandardTLSPort bool
for addr := range config.TLSListeners() {
if strings.Contains(addr, "6697") {
usesStandardTLSPort = true
break
}
}
if 0 < len(tlsListeners) && !usesStandardTLSPort {
server.logger.Warning("startup", "Port 6697 is the standard TLS port for IRC. You should (also) expose port 6697 as a TLS port to ensure clients can connect securely")
}
}
func (server *Server) setupRestAPI(config *Config) {
restAPIEnabled := config.Server.RestAPI.Enabled
restAPIStarted := server.restAPIServer != nil
restAPIListenAddrChanged := server.restAPI.Listen != config.Server.RestAPI.Listen
// stop an existing REST server if it's been disabled or the addr changed
if restAPIStarted && (!restAPIEnabled || restAPIListenAddrChanged) {
ctx, _ := context.WithTimeout(context.Background(), httpShutdownTimeout)
server.restAPIServer.Shutdown(ctx)
server.restAPIServer.Close()
server.logger.Info("rehash", "server", fmt.Sprintf("%s rest API stopped on %s.", server.name, server.restAPI.Listen))
server.restAPIServer = nil
}
// start a new one if it's enabled or the addr changed
if restAPIEnabled && (!restAPIStarted || restAPIListenAddrChanged) {
server.restAPIServer, _ = StartRestAPI(server, config.Server.RestAPI.Listen)
server.logger.Info(
"rehash", "server",
fmt.Sprintf("%s rest API started on %s.", server.name, config.Server.RestAPI.Listen))
}
// save the config information
server.restAPI = config.Server.RestAPI
}
func (server *Server) GetDefaultChannelModes() Modes {
server.configurableStateMutex.RLock()
defer server.configurableStateMutex.RUnlock()
return server.defaultChannelModes
} }
// REHASH // REHASH

View File

@ -132,7 +132,7 @@ Options:
logger.Warning("startup", "You are currently running an unreleased beta version of Oragono that may be unstable and could corrupt your database.\nIf you are running a production network, please download the latest build from https://oragono.io/downloads.html and run that instead.") logger.Warning("startup", "You are currently running an unreleased beta version of Oragono that may be unstable and could corrupt your database.\nIf you are running a production network, please download the latest build from https://oragono.io/downloads.html and run that instead.")
} }
server, err := irc.NewServer(configfile, config, logger) server, err := irc.NewServer(config, logger)
if err != nil { if err != nil {
logger.Error("startup", fmt.Sprintf("Could not load server: %s", err.Error())) logger.Error("startup", fmt.Sprintf("Could not load server: %s", err.Error()))
return return