From e8b18700675448b49fbe543855341ad2b22ac73e Mon Sep 17 00:00:00 2001 From: Shivaram Lingamneni Date: Thu, 28 Sep 2017 01:30:53 -0400 Subject: [PATCH] refactor the rehash implementation --- irc/channel.go | 2 +- irc/config.go | 4 + irc/database.go | 26 +++ irc/help.go | 23 ++- irc/rest_api.go | 41 +++-- irc/server.go | 478 +++++++++++++++++++++++++----------------------- oragono.go | 2 +- 7 files changed, 324 insertions(+), 252 deletions(-) diff --git a/irc/channel.go b/irc/channel.go index 410eade0..7535095f 100644 --- a/irc/channel.go +++ b/irc/channel.go @@ -57,7 +57,7 @@ func NewChannel(s *Server, name string, addDefaultModes bool) *Channel { } if addDefaultModes { - for _, mode := range s.defaultChannelModes { + for _, mode := range s.GetDefaultChannelModes() { channel.flags[mode] = true } } diff --git a/irc/config.go b/irc/config.go index f88ebcc4..906df406 100644 --- a/irc/config.go +++ b/irc/config.go @@ -244,6 +244,8 @@ type Config struct { WhowasEntries uint `yaml:"whowas-entries"` LineLen LineLenConfig `yaml:"linelen"` } + + Filename string } // OperClass defines an assembled operator class. @@ -390,6 +392,8 @@ func LoadConfig(filename string) (config *Config, err error) { return nil, err } + config.Filename = filename + // we need this so PasswordBytes returns the correct info if config.Server.Password != "" { config.Server.PassConfig.Password = config.Server.Password diff --git a/irc/database.go b/irc/database.go index 8f2522e9..305e0667 100644 --- a/irc/database.go +++ b/irc/database.go @@ -53,6 +53,32 @@ func InitDB(path string) { } } +// open an existing database, performing a schema version check +func OpenDatabase(path string) (*buntdb.DB, error) { + // open data store + db, err := buntdb.Open(path) + if err != nil { + return nil, err + } + + // check db version + err = db.View(func(tx *buntdb.Tx) error { + version, _ := tx.Get(keySchemaVersion) + if version != latestDbSchema { + return fmt.Errorf("Database must be updated. Expected schema v%s, got v%s.", latestDbSchema, version) + } + return nil + }) + + if err != nil { + // close the db + db.Close() + return nil, err + } + + return db, nil +} + // UpgradeDB upgrades the datastore to the latest schema. func UpgradeDB(path string) { store, err := buntdb.Open(path) diff --git a/irc/help.go b/irc/help.go index cf0751df..6145bb8a 100644 --- a/irc/help.go +++ b/irc/help.go @@ -530,10 +530,10 @@ Oragono supports the following channel membership prefixes: } // HelpIndex contains the list of all help topics for regular users. -var HelpIndex = "list of all help topics for regular users" +var HelpIndex string // HelpIndexOpers contains the list of all help topics for opers. -var HelpIndexOpers = "list of all help topics for opers" +var HelpIndexOpers string // GenerateHelpIndex is used to generate HelpIndex. func GenerateHelpIndex(forOpers bool) string { @@ -582,6 +582,25 @@ Information: return newHelpIndex } +func GenerateHelpIndices() error { + if HelpIndex != "" && HelpIndexOpers != "" { + return nil + } + + // startup check that we have HELP entries for every command + for name := range Commands { + _, exists := Help[strings.ToLower(name)] + if !exists { + return fmt.Errorf("Help entry does not exist for command %s", name) + } + } + + // generate help indexes + HelpIndex = GenerateHelpIndex(false) + HelpIndexOpers = GenerateHelpIndex(true) + return nil +} + // sendHelp sends the client help of the given string. func (client *Client) sendHelp(name string, text string) { splitName := strings.Split(name, " ") diff --git a/irc/rest_api.go b/irc/rest_api.go index 012b4109..817406ed 100644 --- a/irc/rest_api.go +++ b/irc/rest_api.go @@ -19,9 +19,9 @@ import ( const restErr = "{\"error\":\"An unknown error occurred\"}" -// restAPIServer is used to keep a link to the current running server since this is the best +// ircServer is used to keep a link to the current running server since this is the best // way to do it, given how HTTP handlers dispatch and work. -var restAPIServer *Server +var ircServer *Server type restInfoResp struct { ServerName string `json:"server-name"` @@ -60,8 +60,8 @@ type restRehashResp struct { func restInfo(w http.ResponseWriter, r *http.Request) { rs := restInfoResp{ Version: SemVer, - ServerName: restAPIServer.name, - NetworkName: restAPIServer.networkName, + ServerName: ircServer.name, + NetworkName: ircServer.networkName, } b, err := json.Marshal(rs) if err != nil { @@ -73,9 +73,9 @@ func restInfo(w http.ResponseWriter, r *http.Request) { func restStatus(w http.ResponseWriter, r *http.Request) { rs := restStatusResp{ - Clients: restAPIServer.clients.Count(), - Opers: len(restAPIServer.operators), - Channels: restAPIServer.channels.Len(), + Clients: ircServer.clients.Count(), + Opers: len(ircServer.operators), + Channels: ircServer.channels.Len(), } b, err := json.Marshal(rs) if err != nil { @@ -87,8 +87,8 @@ func restStatus(w http.ResponseWriter, r *http.Request) { func restGetXLines(w http.ResponseWriter, r *http.Request) { rs := restXLinesResp{ - DLines: restAPIServer.dlines.AllBans(), - KLines: restAPIServer.klines.AllBans(), + DLines: ircServer.dlines.AllBans(), + KLines: ircServer.klines.AllBans(), } b, err := json.Marshal(rs) if err != nil { @@ -104,7 +104,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) { } // get accounts - err := restAPIServer.store.View(func(tx *buntdb.Tx) error { + err := ircServer.store.View(func(tx *buntdb.Tx) error { tx.AscendKeys("account.exists *", func(key, value string) bool { key = key[len("account.exists "):] _, err := tx.Get(fmt.Sprintf(keyAccountVerified, key)) @@ -118,7 +118,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) { regTime := time.Unix(regTimeInt, 0) var clients int - acct := restAPIServer.accounts[key] + acct := ircServer.accounts[key] if acct != nil { clients = len(acct.Clients) } @@ -148,7 +148,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) { } func restRehash(w http.ResponseWriter, r *http.Request) { - err := restAPIServer.rehash() + err := ircServer.rehash() rs := restRehashResp{ Successful: err == nil, @@ -166,9 +166,9 @@ func restRehash(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) startRestAPI() { +func StartRestAPI(s *Server, listenAddr string) (*http.Server, error) { // so handlers can ref it later - restAPIServer = s + ircServer = s // start router r := mux.NewRouter() @@ -185,5 +185,16 @@ func (s *Server) startRestAPI() { rp.HandleFunc("/rehash", restRehash) // start api - go http.ListenAndServe(s.restAPI.Listen, r) + httpserver := http.Server{ + Addr: listenAddr, + Handler: r, + } + + go func() { + if err := httpserver.ListenAndServe(); err != nil { + s.logger.Error("listeners", fmt.Sprintf("Rest API listenAndServe error: %s", err)) + } + }() + + return &httpserver, nil } diff --git a/irc/server.go b/irc/server.go index 037fcc65..7294a3df 100644 --- a/irc/server.go +++ b/irc/server.go @@ -7,9 +7,9 @@ package irc import ( "bufio" + "context" "crypto/tls" "encoding/base64" - "errors" "fmt" "log" "math/rand" @@ -35,8 +35,14 @@ var ( tooManyClientsMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Too many clients from your network")}[0]).Line() couldNotParseIPMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Unable to parse your IP address")}[0]).Line() bannedFromServerMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "You are banned from this server (%s)")}[0]).Line() +) - errDbOutOfDate = errors.New("Database schema is old") +const ( + // when shutting down the REST and websocket servers, wait this long + // before killing active non-WS connections. TODO: this might not be + // necessary at all? but it seems prudent to avoid potential resource + // leaks + httpShutdownTimeout = time.Second ) // Limits holds the maximum limits for various things such as topic lengths. @@ -80,6 +86,7 @@ type Server struct { clients *ClientLookupSet commands chan Command configFilename string + configurableStateMutex sync.RWMutex // generic protection for server state modified by rehash() connectionLimits *ConnectionLimits connectionLimitsMutex sync.Mutex // used when affecting the connection limiter, to make sure rehashing doesn't make things go out-of-whack connectionThrottle *ConnectionThrottle @@ -109,13 +116,15 @@ type Server struct { registeredChannelsMutex sync.RWMutex rehashMutex sync.Mutex rehashSignal chan os.Signal - restAPI *RestAPIConfig + restAPI RestAPIConfig + restAPIServer *http.Server proxyAllowedFrom []string signals chan os.Signal snomasks *SnoManager store *buntdb.DB stsEnabled bool whoWas *WhoWasList + wsServer *http.Server } var ( @@ -133,216 +142,38 @@ type clientConn struct { } // NewServer returns a new Oragono server. -func NewServer(configFilename string, config *Config, logger *logger.Manager) (*Server, error) { - casefoldedName, err := Casefold(config.Server.Name) - if err != nil { - return nil, fmt.Errorf("Server name isn't valid [%s]: %s", config.Server.Name, err.Error()) - } - - // startup check that we have HELP entries for every command - for name := range Commands { - _, exists := Help[strings.ToLower(name)] - if !exists { - return nil, fmt.Errorf("Help entry does not exist for command %s", name) - } - } - // generate help indexes - HelpIndex = GenerateHelpIndex(false) - HelpIndexOpers = GenerateHelpIndex(true) - - if config.Accounts.AuthenticationEnabled { - SupportedCapabilities[SASL] = true - } - - if config.Server.STS.Enabled { - SupportedCapabilities[STS] = true - CapValues[STS] = config.Server.STS.Value() - } - - if config.Limits.LineLen.Tags > 512 || config.Limits.LineLen.Rest > 512 { - SupportedCapabilities[MaxLine] = true - CapValues[MaxLine] = fmt.Sprintf("%d,%d", config.Limits.LineLen.Tags, config.Limits.LineLen.Rest) - } - - operClasses, err := config.OperatorClasses() - if err != nil { - return nil, fmt.Errorf("Error loading oper classes: %s", err.Error()) - } - opers, err := config.Operators(operClasses) - if err != nil { - return nil, fmt.Errorf("Error loading operators: %s", err.Error()) - } - - connectionLimits, err := NewConnectionLimits(config.Server.ConnectionLimits) - if err != nil { - return nil, fmt.Errorf("Error loading connection limits: %s", err.Error()) - } - connectionThrottle, err := NewConnectionThrottle(config.Server.ConnectionThrottle) - if err != nil { - return nil, fmt.Errorf("Error loading connection throttler: %s", err.Error()) +func NewServer(config *Config, logger *logger.Manager) (*Server, error) { + // TODO move this to main? + if err := GenerateHelpIndices(); err != nil { + return nil, err } + // initialize data structures server := &Server{ - accountAuthenticationEnabled: config.Accounts.AuthenticationEnabled, - accounts: make(map[string]*ClientAccount), - channelRegistrationEnabled: config.Channels.Registration.Enabled, - channels: *NewChannelNameMap(), - checkIdent: config.Server.CheckIdent, - clients: NewClientLookupSet(), - commands: make(chan Command), - configFilename: configFilename, - connectionLimits: connectionLimits, - connectionThrottle: connectionThrottle, - ctime: time.Now(), - currentOpers: make(map[*Client]bool), - defaultChannelModes: ParseDefaultChannelModes(config), - limits: Limits{ - AwayLen: int(config.Limits.AwayLen), - ChannelLen: int(config.Limits.ChannelLen), - KickLen: int(config.Limits.KickLen), - MonitorEntries: int(config.Limits.MonitorEntries), - NickLen: int(config.Limits.NickLen), - TopicLen: int(config.Limits.TopicLen), - ChanListModes: int(config.Limits.ChanListModes), - LineLen: LineLenLimits{ - Tags: config.Limits.LineLen.Tags, - Rest: config.Limits.LineLen.Rest, - }, - }, + accounts: make(map[string]*ClientAccount), + channels: *NewChannelNameMap(), + clients: NewClientLookupSet(), + commands: make(chan Command), + currentOpers: make(map[*Client]bool), listeners: make(map[string]*ListenerWrapper), logger: logger, - MaxSendQBytes: config.Server.MaxSendQBytes, monitoring: make(map[string][]*Client), - name: config.Server.Name, - nameCasefolded: casefoldedName, - networkName: config.Network.Name, newConns: make(chan clientConn), - operators: opers, - operclasses: *operClasses, - proxyAllowedFrom: config.Server.ProxyAllowedFrom, registeredChannels: make(map[string]*RegisteredChannel), rehashSignal: make(chan os.Signal, 1), - restAPI: &config.Server.RestAPI, signals: make(chan os.Signal, len(ServerExitSignals)), snomasks: NewSnoManager(), - stsEnabled: config.Server.STS.Enabled, whoWas: NewWhoWasList(config.Limits.WhowasEntries), } - // open data store - server.logger.Debug("startup", "Opening datastore") - db, err := buntdb.Open(config.Datastore.Path) - if err != nil { - return nil, fmt.Errorf("Failed to open datastore: %s", err.Error()) + if err := server.applyConfig(config, true); err != nil { + return nil, err } - server.store = db - - // check db version - err = server.store.View(func(tx *buntdb.Tx) error { - version, _ := tx.Get(keySchemaVersion) - if version != latestDbSchema { - logger.Error("startup", "server", fmt.Sprintf("Database must be updated. Expected schema v%s, got v%s.", latestDbSchema, version)) - return errDbOutOfDate - } - return nil - }) - if err != nil { - // close the db - db.Close() - return nil, errDbOutOfDate - } - - // load *lines - server.logger.Debug("startup", "Loading D/Klines") - server.loadDLines() - server.loadKLines() - - // load password manager - server.logger.Debug("startup", "Loading passwords") - err = server.store.View(func(tx *buntdb.Tx) error { - saltString, err := tx.Get(keySalt) - if err != nil { - return fmt.Errorf("Could not retrieve salt string: %s", err.Error()) - } - - salt, err := base64.StdEncoding.DecodeString(saltString) - if err != nil { - return err - } - - pwm := NewPasswordManager(salt) - server.passwords = &pwm - return nil - }) - if err != nil { - return nil, fmt.Errorf("Could not load salt: %s", err.Error()) - } - - server.logger.Debug("startup", "Loading MOTD") - if config.Server.MOTD != "" { - file, err := os.Open(config.Server.MOTD) - if err == nil { - defer file.Close() - - reader := bufio.NewReader(file) - for { - line, err := reader.ReadString('\n') - if err != nil { - break - } - line = strings.TrimRight(line, "\r\n") - // "- " is the required prefix for MOTD, we just add it here to make - // bursting it out to clients easier - line = fmt.Sprintf("- %s", line) - - server.motdLines = append(server.motdLines, line) - } - } - } - - if config.Server.Password != "" { - server.password = config.Server.PasswordBytes() - } - - tlsListeners := config.TLSListeners() - for _, addr := range config.Server.Listen { - server.listeners[addr] = server.createListener(addr, tlsListeners[addr]) - } - - if len(tlsListeners) == 0 { - server.logger.Warning("startup", "You are not exposing an SSL/TLS listening port. You should expose at least one port (typically 6697) to accept TLS connections") - } - var usesStandardTLSPort bool - for addr := range config.TLSListeners() { - if strings.Contains(addr, "6697") { - usesStandardTLSPort = true - break - } - } - if 0 < len(tlsListeners) && !usesStandardTLSPort { - server.logger.Warning("startup", "Port 6697 is the standard TLS port for IRC. You should (also) expose port 6697 as a TLS port to ensure clients can connect securely") - } - - if config.Server.Wslisten != "" { - server.wslisten(config.Server.Wslisten, config.Server.TLSListeners) - } - - // registration - accountReg := NewAccountRegistration(config.Accounts.Registration) - server.accountRegistration = &accountReg // Attempt to clean up when receiving these signals. signal.Notify(server.signals, ServerExitSignals...) signal.Notify(server.rehashSignal, syscall.SIGHUP) - server.setISupport() - - // start API if enabled - if server.restAPI.Enabled { - logger.Info("startup", "server", fmt.Sprintf("%s rest API started on %s.", server.name, server.restAPI.Listen)) - server.startRestAPI() - } - return server, nil } @@ -514,13 +345,6 @@ func (server *Server) createListener(addr string, tlsConfig *tls.Config) *Listen stopEvent: make(chan bool, 1), } - // TODO(slingamn) move all logging of listener status to rehash() - tlsString := "plaintext" - if tlsConfig != nil { - tlsString = "TLS" - } - server.logger.Info("listeners", fmt.Sprintf("listening on %s using %s.", addr, tlsString)) - var shouldStop bool // setup accept goroutine @@ -562,8 +386,29 @@ func (server *Server) createListener(addr string, tlsConfig *tls.Config) *Listen // websocket listen goroutine // -func (server *Server) wslisten(addr string, tlsMap map[string]*TLSListenConfig) { - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { +func (server *Server) setupWSListener(config *Config) { + // unconditionally shut down the old listener because we can't tell + // whether we need to reload the TLS certificate + if server.wsServer != nil { + ctx, _ := context.WithTimeout(context.Background(), httpShutdownTimeout) + server.wsServer.Shutdown(ctx) + server.wsServer.Close() + } + + if config.Server.Wslisten == "" { + server.wsServer = nil + return + } + + addr := config.Server.Wslisten + tlsConfig := config.Server.TLSListeners[addr] + handler := http.NewServeMux() + wsServer := http.Server{ + Addr: addr, + Handler: handler, + } + + handler.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { server.logger.Error("ws", addr, fmt.Sprintf("%s method not allowed", r.Method)) return @@ -584,29 +429,26 @@ func (server *Server) wslisten(addr string, tlsMap map[string]*TLSListenConfig) newConn := clientConn{ Conn: WSContainer{ws}, - IsTLS: false, //TODO(dan): track TLS or not here properly + IsTLS: tlsConfig != nil, } server.newConns <- newConn }) + go func() { - config, listenTLS := tlsMap[addr] - - tlsString := "plaintext" var err error - if listenTLS { - tlsString = "TLS" - } - server.logger.Info("listeners", fmt.Sprintf("websocket listening on %s using %s.", addr, tlsString)) + server.logger.Info("listeners", fmt.Sprintf("websocket listening on %s, tls=%t.", addr, tlsConfig != nil)) - if listenTLS { - err = http.ListenAndServeTLS(addr, config.Cert, config.Key, nil) + if tlsConfig != nil { + err = wsServer.ListenAndServeTLS(tlsConfig.Cert, tlsConfig.Key) } else { - err = http.ListenAndServe(addr, nil) + err = wsServer.ListenAndServe() } if err != nil { - server.logger.Error("listeners", fmt.Sprintf("listenAndServe error [%s]: %s", tlsString, err)) + server.logger.Error("listeners", fmt.Sprintf("websocket ListenAndServe error: %s", err)) } }() + + server.wsServer = &wsServer } // generateMessageID returns a network-unique message ID. @@ -660,6 +502,9 @@ func (server *Server) tryRegister(c *Client) { // MOTD serves the Message of the Day. func (server *Server) MOTD(client *Client) { + server.configurableStateMutex.RLock() + defer server.configurableStateMutex.RUnlock() + if len(server.motdLines) < 1 { client.Send(nil, server.name, ERR_NOMOTD, client.nick, "MOTD File is missing") return @@ -1415,14 +1260,34 @@ func (server *Server) rehash() error { server.logger.Debug("rehash", "Got rehash lock") config, err := LoadConfig(server.configFilename) - if err != nil { - return fmt.Errorf("Error rehashing config file config: %s", err.Error()) + return fmt.Errorf("Error loading config file config: %s", err.Error()) } - // line lengths cannot be changed after launching the server - if server.limits.LineLen.Tags != config.Limits.LineLen.Tags || server.limits.LineLen.Rest != config.Limits.LineLen.Rest { - return fmt.Errorf("Maximum line length (linelen) cannot be changed after launching the server, rehash aborted") + err = server.applyConfig(config, false) + if err != nil { + return fmt.Errorf("Error applying config changes: %s", err.Error()) + } + + return nil +} + +func (server *Server) applyConfig(config *Config, initial bool) error { + if initial { + server.ctime = time.Now() + server.configFilename = config.Filename + } + + casefoldedName, err := Casefold(config.Server.Name) + if err != nil { + return fmt.Errorf("Server name isn't valid [%s]: %s", config.Server.Name, err.Error()) + } + + if !initial { + // line lengths cannot be changed after launching the server + if server.limits.LineLen.Tags != config.Limits.LineLen.Tags || server.limits.LineLen.Rest != config.Limits.LineLen.Rest { + return fmt.Errorf("Maximum line length (linelen) cannot be changed after launching the server, rehash aborted") + } } // confirm connectionLimits are fine @@ -1453,6 +1318,18 @@ func (server *Server) rehash() error { } } + // sanity checks complete, start modifying server state + + server.name = config.Server.Name + server.nameCasefolded = casefoldedName + server.networkName = config.Network.Name + + if config.Server.Password != "" { + server.password = config.Server.PasswordBytes() + } else { + server.password = nil + } + // apply new connectionlimits server.connectionLimitsMutex.Lock() server.connectionLimits = connectionLimits @@ -1578,7 +1455,9 @@ func (server *Server) rehash() error { server.accountRegistration = &accountReg server.channelRegistrationEnabled = config.Channels.Registration.Enabled + server.configurableStateMutex.Lock() server.defaultChannelModes = ParseDefaultChannelModes(config) + server.configurableStateMutex.Unlock() // set new sendqueue size if config.Server.MaxSendQBytes != server.MaxSendQBytes { @@ -1595,18 +1474,108 @@ func (server *Server) rehash() error { // set RPL_ISUPPORT oldISupportList := server.isupport server.setISupport() - newISupportReplies := oldISupportList.GetDifference(server.isupport) + if oldISupportList != nil { + newISupportReplies := oldISupportList.GetDifference(server.isupport) + // push new info to all of our clients + server.clients.ByNickMutex.RLock() + for _, sClient := range server.clients.ByNick { + for _, tokenline := range newISupportReplies { + // ugly trickery ahead + sClient.Send(nil, server.name, RPL_ISUPPORT, append([]string{sClient.nick}, tokenline...)...) + } + } + server.clients.ByNickMutex.RUnlock() + } - // push new info to all of our clients - server.clients.ByNickMutex.RLock() - for _, sClient := range server.clients.ByNick { - for _, tokenline := range newISupportReplies { - // ugly trickery ahead - sClient.Send(nil, server.name, RPL_ISUPPORT, append([]string{sClient.nick}, tokenline...)...) + server.loadMOTD(config.Server.MOTD) + + if initial { + if err := server.loadDatastore(config.Datastore.Path); err != nil { + return err } } - server.clients.ByNickMutex.RUnlock() + // we are now open for business + server.setupListeners(config) + server.setupWSListener(config) + server.setupRestAPI(config) + + return nil +} + +func (server *Server) loadMOTD(motdPath string) error { + server.logger.Debug("rehash", "Loading MOTD") + motdLines := make([]string, 0) + if motdPath != "" { + file, err := os.Open(motdPath) + if err == nil { + defer file.Close() + + reader := bufio.NewReader(file) + for { + line, err := reader.ReadString('\n') + if err != nil { + break + } + line = strings.TrimRight(line, "\r\n") + // "- " is the required prefix for MOTD, we just add it here to make + // bursting it out to clients easier + line = fmt.Sprintf("- %s", line) + + motdLines = append(motdLines, line) + } + } else { + return err + } + } + + server.configurableStateMutex.Lock() + defer server.configurableStateMutex.Unlock() + server.motdLines = motdLines + return nil +} + +func (server *Server) loadDatastore(datastorePath string) error { + // open the datastore and load server state for which it (rather than config) + // is the source of truth + + server.logger.Debug("startup", "Opening datastore") + db, err := OpenDatabase(datastorePath) + if err == nil { + server.store = db + } else { + return fmt.Errorf("Failed to open datastore: %s", err.Error()) + } + + // load *lines (from the datastores) + server.logger.Debug("startup", "Loading D/Klines") + server.loadDLines() + server.loadKLines() + + // load password manager + server.logger.Debug("startup", "Loading passwords") + err = server.store.View(func(tx *buntdb.Tx) error { + saltString, err := tx.Get(keySalt) + if err != nil { + return fmt.Errorf("Could not retrieve salt string: %s", err.Error()) + } + + salt, err := base64.StdEncoding.DecodeString(saltString) + if err != nil { + return err + } + + pwm := NewPasswordManager(salt) + server.passwords = &pwm + return nil + }) + if err != nil { + return fmt.Errorf("Could not load salt: %s", err.Error()) + } + return nil +} + +func (server *Server) setupListeners(config *Config) { // update or destroy all existing listeners tlsListeners := config.TLSListeners() for addr := range server.listeners { @@ -1629,18 +1598,16 @@ func (server *Server) rehash() error { currentListener.configMutex.Unlock() if stillConfigured { - server.logger.Info("rehash", + server.logger.Info("listeners", fmt.Sprintf("now listening on %s, tls=%t.", addr, (currentListener.tlsConfig != nil)), ) } else { // tell the listener it should stop by interrupting its Accept() call: currentListener.listener.Close() - // XXX there is no guarantee from the API when the address will actually - // free for bind(2) again; this solution "seems to work". See here: - // https://github.com/golang/go/issues/21833 + // TODO(golang1.10) delete stopEvent once issue #21856 is released <-currentListener.stopEvent delete(server.listeners, addr) - server.logger.Info("rehash", fmt.Sprintf("stopped listening on %s.", addr)) + server.logger.Info("listeners", fmt.Sprintf("stopped listening on %s.", addr)) } } @@ -1653,7 +1620,52 @@ func (server *Server) rehash() error { } } - return nil + if len(tlsListeners) == 0 { + server.logger.Warning("startup", "You are not exposing an SSL/TLS listening port. You should expose at least one port (typically 6697) to accept TLS connections") + } + + var usesStandardTLSPort bool + for addr := range config.TLSListeners() { + if strings.Contains(addr, "6697") { + usesStandardTLSPort = true + break + } + } + if 0 < len(tlsListeners) && !usesStandardTLSPort { + server.logger.Warning("startup", "Port 6697 is the standard TLS port for IRC. You should (also) expose port 6697 as a TLS port to ensure clients can connect securely") + } +} + +func (server *Server) setupRestAPI(config *Config) { + restAPIEnabled := config.Server.RestAPI.Enabled + restAPIStarted := server.restAPIServer != nil + restAPIListenAddrChanged := server.restAPI.Listen != config.Server.RestAPI.Listen + + // stop an existing REST server if it's been disabled or the addr changed + if restAPIStarted && (!restAPIEnabled || restAPIListenAddrChanged) { + ctx, _ := context.WithTimeout(context.Background(), httpShutdownTimeout) + server.restAPIServer.Shutdown(ctx) + server.restAPIServer.Close() + server.logger.Info("rehash", "server", fmt.Sprintf("%s rest API stopped on %s.", server.name, server.restAPI.Listen)) + server.restAPIServer = nil + } + + // start a new one if it's enabled or the addr changed + if restAPIEnabled && (!restAPIStarted || restAPIListenAddrChanged) { + server.restAPIServer, _ = StartRestAPI(server, config.Server.RestAPI.Listen) + server.logger.Info( + "rehash", "server", + fmt.Sprintf("%s rest API started on %s.", server.name, config.Server.RestAPI.Listen)) + } + + // save the config information + server.restAPI = config.Server.RestAPI +} + +func (server *Server) GetDefaultChannelModes() Modes { + server.configurableStateMutex.RLock() + defer server.configurableStateMutex.RUnlock() + return server.defaultChannelModes } // REHASH diff --git a/oragono.go b/oragono.go index 822b40dc..dac3399e 100644 --- a/oragono.go +++ b/oragono.go @@ -132,7 +132,7 @@ Options: logger.Warning("startup", "You are currently running an unreleased beta version of Oragono that may be unstable and could corrupt your database.\nIf you are running a production network, please download the latest build from https://oragono.io/downloads.html and run that instead.") } - server, err := irc.NewServer(configfile, config, logger) + server, err := irc.NewServer(config, logger) if err != nil { logger.Error("startup", fmt.Sprintf("Could not load server: %s", err.Error())) return