mirror of
https://github.com/ergochat/ergo.git
synced 2024-11-10 22:19:31 +01:00
refactor the rehash implementation
This commit is contained in:
parent
eae04e8c51
commit
e8b1870067
@ -57,7 +57,7 @@ func NewChannel(s *Server, name string, addDefaultModes bool) *Channel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if addDefaultModes {
|
if addDefaultModes {
|
||||||
for _, mode := range s.defaultChannelModes {
|
for _, mode := range s.GetDefaultChannelModes() {
|
||||||
channel.flags[mode] = true
|
channel.flags[mode] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -244,6 +244,8 @@ type Config struct {
|
|||||||
WhowasEntries uint `yaml:"whowas-entries"`
|
WhowasEntries uint `yaml:"whowas-entries"`
|
||||||
LineLen LineLenConfig `yaml:"linelen"`
|
LineLen LineLenConfig `yaml:"linelen"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Filename string
|
||||||
}
|
}
|
||||||
|
|
||||||
// OperClass defines an assembled operator class.
|
// OperClass defines an assembled operator class.
|
||||||
@ -390,6 +392,8 @@ func LoadConfig(filename string) (config *Config, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config.Filename = filename
|
||||||
|
|
||||||
// we need this so PasswordBytes returns the correct info
|
// we need this so PasswordBytes returns the correct info
|
||||||
if config.Server.Password != "" {
|
if config.Server.Password != "" {
|
||||||
config.Server.PassConfig.Password = config.Server.Password
|
config.Server.PassConfig.Password = config.Server.Password
|
||||||
|
@ -53,6 +53,32 @@ func InitDB(path string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// open an existing database, performing a schema version check
|
||||||
|
func OpenDatabase(path string) (*buntdb.DB, error) {
|
||||||
|
// open data store
|
||||||
|
db, err := buntdb.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// check db version
|
||||||
|
err = db.View(func(tx *buntdb.Tx) error {
|
||||||
|
version, _ := tx.Get(keySchemaVersion)
|
||||||
|
if version != latestDbSchema {
|
||||||
|
return fmt.Errorf("Database must be updated. Expected schema v%s, got v%s.", latestDbSchema, version)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// close the db
|
||||||
|
db.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
// UpgradeDB upgrades the datastore to the latest schema.
|
// UpgradeDB upgrades the datastore to the latest schema.
|
||||||
func UpgradeDB(path string) {
|
func UpgradeDB(path string) {
|
||||||
store, err := buntdb.Open(path)
|
store, err := buntdb.Open(path)
|
||||||
|
23
irc/help.go
23
irc/help.go
@ -530,10 +530,10 @@ Oragono supports the following channel membership prefixes:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HelpIndex contains the list of all help topics for regular users.
|
// HelpIndex contains the list of all help topics for regular users.
|
||||||
var HelpIndex = "list of all help topics for regular users"
|
var HelpIndex string
|
||||||
|
|
||||||
// HelpIndexOpers contains the list of all help topics for opers.
|
// HelpIndexOpers contains the list of all help topics for opers.
|
||||||
var HelpIndexOpers = "list of all help topics for opers"
|
var HelpIndexOpers string
|
||||||
|
|
||||||
// GenerateHelpIndex is used to generate HelpIndex.
|
// GenerateHelpIndex is used to generate HelpIndex.
|
||||||
func GenerateHelpIndex(forOpers bool) string {
|
func GenerateHelpIndex(forOpers bool) string {
|
||||||
@ -582,6 +582,25 @@ Information:
|
|||||||
return newHelpIndex
|
return newHelpIndex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenerateHelpIndices() error {
|
||||||
|
if HelpIndex != "" && HelpIndexOpers != "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// startup check that we have HELP entries for every command
|
||||||
|
for name := range Commands {
|
||||||
|
_, exists := Help[strings.ToLower(name)]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("Help entry does not exist for command %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate help indexes
|
||||||
|
HelpIndex = GenerateHelpIndex(false)
|
||||||
|
HelpIndexOpers = GenerateHelpIndex(true)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// sendHelp sends the client help of the given string.
|
// sendHelp sends the client help of the given string.
|
||||||
func (client *Client) sendHelp(name string, text string) {
|
func (client *Client) sendHelp(name string, text string) {
|
||||||
splitName := strings.Split(name, " ")
|
splitName := strings.Split(name, " ")
|
||||||
|
@ -19,9 +19,9 @@ import (
|
|||||||
|
|
||||||
const restErr = "{\"error\":\"An unknown error occurred\"}"
|
const restErr = "{\"error\":\"An unknown error occurred\"}"
|
||||||
|
|
||||||
// restAPIServer is used to keep a link to the current running server since this is the best
|
// ircServer is used to keep a link to the current running server since this is the best
|
||||||
// way to do it, given how HTTP handlers dispatch and work.
|
// way to do it, given how HTTP handlers dispatch and work.
|
||||||
var restAPIServer *Server
|
var ircServer *Server
|
||||||
|
|
||||||
type restInfoResp struct {
|
type restInfoResp struct {
|
||||||
ServerName string `json:"server-name"`
|
ServerName string `json:"server-name"`
|
||||||
@ -60,8 +60,8 @@ type restRehashResp struct {
|
|||||||
func restInfo(w http.ResponseWriter, r *http.Request) {
|
func restInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
rs := restInfoResp{
|
rs := restInfoResp{
|
||||||
Version: SemVer,
|
Version: SemVer,
|
||||||
ServerName: restAPIServer.name,
|
ServerName: ircServer.name,
|
||||||
NetworkName: restAPIServer.networkName,
|
NetworkName: ircServer.networkName,
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(rs)
|
b, err := json.Marshal(rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -73,9 +73,9 @@ func restInfo(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
func restStatus(w http.ResponseWriter, r *http.Request) {
|
func restStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
rs := restStatusResp{
|
rs := restStatusResp{
|
||||||
Clients: restAPIServer.clients.Count(),
|
Clients: ircServer.clients.Count(),
|
||||||
Opers: len(restAPIServer.operators),
|
Opers: len(ircServer.operators),
|
||||||
Channels: restAPIServer.channels.Len(),
|
Channels: ircServer.channels.Len(),
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(rs)
|
b, err := json.Marshal(rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -87,8 +87,8 @@ func restStatus(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
func restGetXLines(w http.ResponseWriter, r *http.Request) {
|
func restGetXLines(w http.ResponseWriter, r *http.Request) {
|
||||||
rs := restXLinesResp{
|
rs := restXLinesResp{
|
||||||
DLines: restAPIServer.dlines.AllBans(),
|
DLines: ircServer.dlines.AllBans(),
|
||||||
KLines: restAPIServer.klines.AllBans(),
|
KLines: ircServer.klines.AllBans(),
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(rs)
|
b, err := json.Marshal(rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -104,7 +104,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get accounts
|
// get accounts
|
||||||
err := restAPIServer.store.View(func(tx *buntdb.Tx) error {
|
err := ircServer.store.View(func(tx *buntdb.Tx) error {
|
||||||
tx.AscendKeys("account.exists *", func(key, value string) bool {
|
tx.AscendKeys("account.exists *", func(key, value string) bool {
|
||||||
key = key[len("account.exists "):]
|
key = key[len("account.exists "):]
|
||||||
_, err := tx.Get(fmt.Sprintf(keyAccountVerified, key))
|
_, err := tx.Get(fmt.Sprintf(keyAccountVerified, key))
|
||||||
@ -118,7 +118,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
|
|||||||
regTime := time.Unix(regTimeInt, 0)
|
regTime := time.Unix(regTimeInt, 0)
|
||||||
|
|
||||||
var clients int
|
var clients int
|
||||||
acct := restAPIServer.accounts[key]
|
acct := ircServer.accounts[key]
|
||||||
if acct != nil {
|
if acct != nil {
|
||||||
clients = len(acct.Clients)
|
clients = len(acct.Clients)
|
||||||
}
|
}
|
||||||
@ -148,7 +148,7 @@ func restGetAccounts(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func restRehash(w http.ResponseWriter, r *http.Request) {
|
func restRehash(w http.ResponseWriter, r *http.Request) {
|
||||||
err := restAPIServer.rehash()
|
err := ircServer.rehash()
|
||||||
|
|
||||||
rs := restRehashResp{
|
rs := restRehashResp{
|
||||||
Successful: err == nil,
|
Successful: err == nil,
|
||||||
@ -166,9 +166,9 @@ func restRehash(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) startRestAPI() {
|
func StartRestAPI(s *Server, listenAddr string) (*http.Server, error) {
|
||||||
// so handlers can ref it later
|
// so handlers can ref it later
|
||||||
restAPIServer = s
|
ircServer = s
|
||||||
|
|
||||||
// start router
|
// start router
|
||||||
r := mux.NewRouter()
|
r := mux.NewRouter()
|
||||||
@ -185,5 +185,16 @@ func (s *Server) startRestAPI() {
|
|||||||
rp.HandleFunc("/rehash", restRehash)
|
rp.HandleFunc("/rehash", restRehash)
|
||||||
|
|
||||||
// start api
|
// start api
|
||||||
go http.ListenAndServe(s.restAPI.Listen, r)
|
httpserver := http.Server{
|
||||||
|
Addr: listenAddr,
|
||||||
|
Handler: r,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := httpserver.ListenAndServe(); err != nil {
|
||||||
|
s.logger.Error("listeners", fmt.Sprintf("Rest API listenAndServe error: %s", err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &httpserver, nil
|
||||||
}
|
}
|
||||||
|
448
irc/server.go
448
irc/server.go
@ -7,9 +7,9 @@ package irc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@ -35,8 +35,14 @@ var (
|
|||||||
tooManyClientsMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Too many clients from your network")}[0]).Line()
|
tooManyClientsMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Too many clients from your network")}[0]).Line()
|
||||||
couldNotParseIPMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Unable to parse your IP address")}[0]).Line()
|
couldNotParseIPMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Unable to parse your IP address")}[0]).Line()
|
||||||
bannedFromServerMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "You are banned from this server (%s)")}[0]).Line()
|
bannedFromServerMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "You are banned from this server (%s)")}[0]).Line()
|
||||||
|
)
|
||||||
|
|
||||||
errDbOutOfDate = errors.New("Database schema is old")
|
const (
|
||||||
|
// when shutting down the REST and websocket servers, wait this long
|
||||||
|
// before killing active non-WS connections. TODO: this might not be
|
||||||
|
// necessary at all? but it seems prudent to avoid potential resource
|
||||||
|
// leaks
|
||||||
|
httpShutdownTimeout = time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// Limits holds the maximum limits for various things such as topic lengths.
|
// Limits holds the maximum limits for various things such as topic lengths.
|
||||||
@ -80,6 +86,7 @@ type Server struct {
|
|||||||
clients *ClientLookupSet
|
clients *ClientLookupSet
|
||||||
commands chan Command
|
commands chan Command
|
||||||
configFilename string
|
configFilename string
|
||||||
|
configurableStateMutex sync.RWMutex // generic protection for server state modified by rehash()
|
||||||
connectionLimits *ConnectionLimits
|
connectionLimits *ConnectionLimits
|
||||||
connectionLimitsMutex sync.Mutex // used when affecting the connection limiter, to make sure rehashing doesn't make things go out-of-whack
|
connectionLimitsMutex sync.Mutex // used when affecting the connection limiter, to make sure rehashing doesn't make things go out-of-whack
|
||||||
connectionThrottle *ConnectionThrottle
|
connectionThrottle *ConnectionThrottle
|
||||||
@ -109,13 +116,15 @@ type Server struct {
|
|||||||
registeredChannelsMutex sync.RWMutex
|
registeredChannelsMutex sync.RWMutex
|
||||||
rehashMutex sync.Mutex
|
rehashMutex sync.Mutex
|
||||||
rehashSignal chan os.Signal
|
rehashSignal chan os.Signal
|
||||||
restAPI *RestAPIConfig
|
restAPI RestAPIConfig
|
||||||
|
restAPIServer *http.Server
|
||||||
proxyAllowedFrom []string
|
proxyAllowedFrom []string
|
||||||
signals chan os.Signal
|
signals chan os.Signal
|
||||||
snomasks *SnoManager
|
snomasks *SnoManager
|
||||||
store *buntdb.DB
|
store *buntdb.DB
|
||||||
stsEnabled bool
|
stsEnabled bool
|
||||||
whoWas *WhoWasList
|
whoWas *WhoWasList
|
||||||
|
wsServer *http.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -133,216 +142,38 @@ type clientConn struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewServer returns a new Oragono server.
|
// NewServer returns a new Oragono server.
|
||||||
func NewServer(configFilename string, config *Config, logger *logger.Manager) (*Server, error) {
|
func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
|
||||||
casefoldedName, err := Casefold(config.Server.Name)
|
// TODO move this to main?
|
||||||
if err != nil {
|
if err := GenerateHelpIndices(); err != nil {
|
||||||
return nil, fmt.Errorf("Server name isn't valid [%s]: %s", config.Server.Name, err.Error())
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
// startup check that we have HELP entries for every command
|
|
||||||
for name := range Commands {
|
|
||||||
_, exists := Help[strings.ToLower(name)]
|
|
||||||
if !exists {
|
|
||||||
return nil, fmt.Errorf("Help entry does not exist for command %s", name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// generate help indexes
|
|
||||||
HelpIndex = GenerateHelpIndex(false)
|
|
||||||
HelpIndexOpers = GenerateHelpIndex(true)
|
|
||||||
|
|
||||||
if config.Accounts.AuthenticationEnabled {
|
|
||||||
SupportedCapabilities[SASL] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Server.STS.Enabled {
|
|
||||||
SupportedCapabilities[STS] = true
|
|
||||||
CapValues[STS] = config.Server.STS.Value()
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Limits.LineLen.Tags > 512 || config.Limits.LineLen.Rest > 512 {
|
|
||||||
SupportedCapabilities[MaxLine] = true
|
|
||||||
CapValues[MaxLine] = fmt.Sprintf("%d,%d", config.Limits.LineLen.Tags, config.Limits.LineLen.Rest)
|
|
||||||
}
|
|
||||||
|
|
||||||
operClasses, err := config.OperatorClasses()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Error loading oper classes: %s", err.Error())
|
|
||||||
}
|
|
||||||
opers, err := config.Operators(operClasses)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Error loading operators: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
connectionLimits, err := NewConnectionLimits(config.Server.ConnectionLimits)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Error loading connection limits: %s", err.Error())
|
|
||||||
}
|
|
||||||
connectionThrottle, err := NewConnectionThrottle(config.Server.ConnectionThrottle)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Error loading connection throttler: %s", err.Error())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initialize data structures
|
||||||
server := &Server{
|
server := &Server{
|
||||||
accountAuthenticationEnabled: config.Accounts.AuthenticationEnabled,
|
|
||||||
accounts: make(map[string]*ClientAccount),
|
accounts: make(map[string]*ClientAccount),
|
||||||
channelRegistrationEnabled: config.Channels.Registration.Enabled,
|
|
||||||
channels: *NewChannelNameMap(),
|
channels: *NewChannelNameMap(),
|
||||||
checkIdent: config.Server.CheckIdent,
|
|
||||||
clients: NewClientLookupSet(),
|
clients: NewClientLookupSet(),
|
||||||
commands: make(chan Command),
|
commands: make(chan Command),
|
||||||
configFilename: configFilename,
|
|
||||||
connectionLimits: connectionLimits,
|
|
||||||
connectionThrottle: connectionThrottle,
|
|
||||||
ctime: time.Now(),
|
|
||||||
currentOpers: make(map[*Client]bool),
|
currentOpers: make(map[*Client]bool),
|
||||||
defaultChannelModes: ParseDefaultChannelModes(config),
|
|
||||||
limits: Limits{
|
|
||||||
AwayLen: int(config.Limits.AwayLen),
|
|
||||||
ChannelLen: int(config.Limits.ChannelLen),
|
|
||||||
KickLen: int(config.Limits.KickLen),
|
|
||||||
MonitorEntries: int(config.Limits.MonitorEntries),
|
|
||||||
NickLen: int(config.Limits.NickLen),
|
|
||||||
TopicLen: int(config.Limits.TopicLen),
|
|
||||||
ChanListModes: int(config.Limits.ChanListModes),
|
|
||||||
LineLen: LineLenLimits{
|
|
||||||
Tags: config.Limits.LineLen.Tags,
|
|
||||||
Rest: config.Limits.LineLen.Rest,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
listeners: make(map[string]*ListenerWrapper),
|
listeners: make(map[string]*ListenerWrapper),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
MaxSendQBytes: config.Server.MaxSendQBytes,
|
|
||||||
monitoring: make(map[string][]*Client),
|
monitoring: make(map[string][]*Client),
|
||||||
name: config.Server.Name,
|
|
||||||
nameCasefolded: casefoldedName,
|
|
||||||
networkName: config.Network.Name,
|
|
||||||
newConns: make(chan clientConn),
|
newConns: make(chan clientConn),
|
||||||
operators: opers,
|
|
||||||
operclasses: *operClasses,
|
|
||||||
proxyAllowedFrom: config.Server.ProxyAllowedFrom,
|
|
||||||
registeredChannels: make(map[string]*RegisteredChannel),
|
registeredChannels: make(map[string]*RegisteredChannel),
|
||||||
rehashSignal: make(chan os.Signal, 1),
|
rehashSignal: make(chan os.Signal, 1),
|
||||||
restAPI: &config.Server.RestAPI,
|
|
||||||
signals: make(chan os.Signal, len(ServerExitSignals)),
|
signals: make(chan os.Signal, len(ServerExitSignals)),
|
||||||
snomasks: NewSnoManager(),
|
snomasks: NewSnoManager(),
|
||||||
stsEnabled: config.Server.STS.Enabled,
|
|
||||||
whoWas: NewWhoWasList(config.Limits.WhowasEntries),
|
whoWas: NewWhoWasList(config.Limits.WhowasEntries),
|
||||||
}
|
}
|
||||||
|
|
||||||
// open data store
|
if err := server.applyConfig(config, true); err != nil {
|
||||||
server.logger.Debug("startup", "Opening datastore")
|
return nil, err
|
||||||
db, err := buntdb.Open(config.Datastore.Path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Failed to open datastore: %s", err.Error())
|
|
||||||
}
|
}
|
||||||
server.store = db
|
|
||||||
|
|
||||||
// check db version
|
|
||||||
err = server.store.View(func(tx *buntdb.Tx) error {
|
|
||||||
version, _ := tx.Get(keySchemaVersion)
|
|
||||||
if version != latestDbSchema {
|
|
||||||
logger.Error("startup", "server", fmt.Sprintf("Database must be updated. Expected schema v%s, got v%s.", latestDbSchema, version))
|
|
||||||
return errDbOutOfDate
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
// close the db
|
|
||||||
db.Close()
|
|
||||||
return nil, errDbOutOfDate
|
|
||||||
}
|
|
||||||
|
|
||||||
// load *lines
|
|
||||||
server.logger.Debug("startup", "Loading D/Klines")
|
|
||||||
server.loadDLines()
|
|
||||||
server.loadKLines()
|
|
||||||
|
|
||||||
// load password manager
|
|
||||||
server.logger.Debug("startup", "Loading passwords")
|
|
||||||
err = server.store.View(func(tx *buntdb.Tx) error {
|
|
||||||
saltString, err := tx.Get(keySalt)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Could not retrieve salt string: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
salt, err := base64.StdEncoding.DecodeString(saltString)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
pwm := NewPasswordManager(salt)
|
|
||||||
server.passwords = &pwm
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Could not load salt: %s", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
server.logger.Debug("startup", "Loading MOTD")
|
|
||||||
if config.Server.MOTD != "" {
|
|
||||||
file, err := os.Open(config.Server.MOTD)
|
|
||||||
if err == nil {
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
reader := bufio.NewReader(file)
|
|
||||||
for {
|
|
||||||
line, err := reader.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
line = strings.TrimRight(line, "\r\n")
|
|
||||||
// "- " is the required prefix for MOTD, we just add it here to make
|
|
||||||
// bursting it out to clients easier
|
|
||||||
line = fmt.Sprintf("- %s", line)
|
|
||||||
|
|
||||||
server.motdLines = append(server.motdLines, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Server.Password != "" {
|
|
||||||
server.password = config.Server.PasswordBytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsListeners := config.TLSListeners()
|
|
||||||
for _, addr := range config.Server.Listen {
|
|
||||||
server.listeners[addr] = server.createListener(addr, tlsListeners[addr])
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tlsListeners) == 0 {
|
|
||||||
server.logger.Warning("startup", "You are not exposing an SSL/TLS listening port. You should expose at least one port (typically 6697) to accept TLS connections")
|
|
||||||
}
|
|
||||||
var usesStandardTLSPort bool
|
|
||||||
for addr := range config.TLSListeners() {
|
|
||||||
if strings.Contains(addr, "6697") {
|
|
||||||
usesStandardTLSPort = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if 0 < len(tlsListeners) && !usesStandardTLSPort {
|
|
||||||
server.logger.Warning("startup", "Port 6697 is the standard TLS port for IRC. You should (also) expose port 6697 as a TLS port to ensure clients can connect securely")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Server.Wslisten != "" {
|
|
||||||
server.wslisten(config.Server.Wslisten, config.Server.TLSListeners)
|
|
||||||
}
|
|
||||||
|
|
||||||
// registration
|
|
||||||
accountReg := NewAccountRegistration(config.Accounts.Registration)
|
|
||||||
server.accountRegistration = &accountReg
|
|
||||||
|
|
||||||
// Attempt to clean up when receiving these signals.
|
// Attempt to clean up when receiving these signals.
|
||||||
signal.Notify(server.signals, ServerExitSignals...)
|
signal.Notify(server.signals, ServerExitSignals...)
|
||||||
signal.Notify(server.rehashSignal, syscall.SIGHUP)
|
signal.Notify(server.rehashSignal, syscall.SIGHUP)
|
||||||
|
|
||||||
server.setISupport()
|
|
||||||
|
|
||||||
// start API if enabled
|
|
||||||
if server.restAPI.Enabled {
|
|
||||||
logger.Info("startup", "server", fmt.Sprintf("%s rest API started on %s.", server.name, server.restAPI.Listen))
|
|
||||||
server.startRestAPI()
|
|
||||||
}
|
|
||||||
|
|
||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -514,13 +345,6 @@ func (server *Server) createListener(addr string, tlsConfig *tls.Config) *Listen
|
|||||||
stopEvent: make(chan bool, 1),
|
stopEvent: make(chan bool, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(slingamn) move all logging of listener status to rehash()
|
|
||||||
tlsString := "plaintext"
|
|
||||||
if tlsConfig != nil {
|
|
||||||
tlsString = "TLS"
|
|
||||||
}
|
|
||||||
server.logger.Info("listeners", fmt.Sprintf("listening on %s using %s.", addr, tlsString))
|
|
||||||
|
|
||||||
var shouldStop bool
|
var shouldStop bool
|
||||||
|
|
||||||
// setup accept goroutine
|
// setup accept goroutine
|
||||||
@ -562,8 +386,29 @@ func (server *Server) createListener(addr string, tlsConfig *tls.Config) *Listen
|
|||||||
// websocket listen goroutine
|
// websocket listen goroutine
|
||||||
//
|
//
|
||||||
|
|
||||||
func (server *Server) wslisten(addr string, tlsMap map[string]*TLSListenConfig) {
|
func (server *Server) setupWSListener(config *Config) {
|
||||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
// unconditionally shut down the old listener because we can't tell
|
||||||
|
// whether we need to reload the TLS certificate
|
||||||
|
if server.wsServer != nil {
|
||||||
|
ctx, _ := context.WithTimeout(context.Background(), httpShutdownTimeout)
|
||||||
|
server.wsServer.Shutdown(ctx)
|
||||||
|
server.wsServer.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.Server.Wslisten == "" {
|
||||||
|
server.wsServer = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := config.Server.Wslisten
|
||||||
|
tlsConfig := config.Server.TLSListeners[addr]
|
||||||
|
handler := http.NewServeMux()
|
||||||
|
wsServer := http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != "GET" {
|
if r.Method != "GET" {
|
||||||
server.logger.Error("ws", addr, fmt.Sprintf("%s method not allowed", r.Method))
|
server.logger.Error("ws", addr, fmt.Sprintf("%s method not allowed", r.Method))
|
||||||
return
|
return
|
||||||
@ -584,29 +429,26 @@ func (server *Server) wslisten(addr string, tlsMap map[string]*TLSListenConfig)
|
|||||||
|
|
||||||
newConn := clientConn{
|
newConn := clientConn{
|
||||||
Conn: WSContainer{ws},
|
Conn: WSContainer{ws},
|
||||||
IsTLS: false, //TODO(dan): track TLS or not here properly
|
IsTLS: tlsConfig != nil,
|
||||||
}
|
}
|
||||||
server.newConns <- newConn
|
server.newConns <- newConn
|
||||||
})
|
})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
config, listenTLS := tlsMap[addr]
|
|
||||||
|
|
||||||
tlsString := "plaintext"
|
|
||||||
var err error
|
var err error
|
||||||
if listenTLS {
|
server.logger.Info("listeners", fmt.Sprintf("websocket listening on %s, tls=%t.", addr, tlsConfig != nil))
|
||||||
tlsString = "TLS"
|
|
||||||
}
|
|
||||||
server.logger.Info("listeners", fmt.Sprintf("websocket listening on %s using %s.", addr, tlsString))
|
|
||||||
|
|
||||||
if listenTLS {
|
if tlsConfig != nil {
|
||||||
err = http.ListenAndServeTLS(addr, config.Cert, config.Key, nil)
|
err = wsServer.ListenAndServeTLS(tlsConfig.Cert, tlsConfig.Key)
|
||||||
} else {
|
} else {
|
||||||
err = http.ListenAndServe(addr, nil)
|
err = wsServer.ListenAndServe()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
server.logger.Error("listeners", fmt.Sprintf("listenAndServe error [%s]: %s", tlsString, err))
|
server.logger.Error("listeners", fmt.Sprintf("websocket ListenAndServe error: %s", err))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
server.wsServer = &wsServer
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateMessageID returns a network-unique message ID.
|
// generateMessageID returns a network-unique message ID.
|
||||||
@ -660,6 +502,9 @@ func (server *Server) tryRegister(c *Client) {
|
|||||||
|
|
||||||
// MOTD serves the Message of the Day.
|
// MOTD serves the Message of the Day.
|
||||||
func (server *Server) MOTD(client *Client) {
|
func (server *Server) MOTD(client *Client) {
|
||||||
|
server.configurableStateMutex.RLock()
|
||||||
|
defer server.configurableStateMutex.RUnlock()
|
||||||
|
|
||||||
if len(server.motdLines) < 1 {
|
if len(server.motdLines) < 1 {
|
||||||
client.Send(nil, server.name, ERR_NOMOTD, client.nick, "MOTD File is missing")
|
client.Send(nil, server.name, ERR_NOMOTD, client.nick, "MOTD File is missing")
|
||||||
return
|
return
|
||||||
@ -1415,15 +1260,35 @@ func (server *Server) rehash() error {
|
|||||||
server.logger.Debug("rehash", "Got rehash lock")
|
server.logger.Debug("rehash", "Got rehash lock")
|
||||||
|
|
||||||
config, err := LoadConfig(server.configFilename)
|
config, err := LoadConfig(server.configFilename)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error rehashing config file config: %s", err.Error())
|
return fmt.Errorf("Error loading config file config: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = server.applyConfig(config, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error applying config changes: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) applyConfig(config *Config, initial bool) error {
|
||||||
|
if initial {
|
||||||
|
server.ctime = time.Now()
|
||||||
|
server.configFilename = config.Filename
|
||||||
|
}
|
||||||
|
|
||||||
|
casefoldedName, err := Casefold(config.Server.Name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Server name isn't valid [%s]: %s", config.Server.Name, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !initial {
|
||||||
// line lengths cannot be changed after launching the server
|
// line lengths cannot be changed after launching the server
|
||||||
if server.limits.LineLen.Tags != config.Limits.LineLen.Tags || server.limits.LineLen.Rest != config.Limits.LineLen.Rest {
|
if server.limits.LineLen.Tags != config.Limits.LineLen.Tags || server.limits.LineLen.Rest != config.Limits.LineLen.Rest {
|
||||||
return fmt.Errorf("Maximum line length (linelen) cannot be changed after launching the server, rehash aborted")
|
return fmt.Errorf("Maximum line length (linelen) cannot be changed after launching the server, rehash aborted")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// confirm connectionLimits are fine
|
// confirm connectionLimits are fine
|
||||||
connectionLimits, err := NewConnectionLimits(config.Server.ConnectionLimits)
|
connectionLimits, err := NewConnectionLimits(config.Server.ConnectionLimits)
|
||||||
@ -1453,6 +1318,18 @@ func (server *Server) rehash() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sanity checks complete, start modifying server state
|
||||||
|
|
||||||
|
server.name = config.Server.Name
|
||||||
|
server.nameCasefolded = casefoldedName
|
||||||
|
server.networkName = config.Network.Name
|
||||||
|
|
||||||
|
if config.Server.Password != "" {
|
||||||
|
server.password = config.Server.PasswordBytes()
|
||||||
|
} else {
|
||||||
|
server.password = nil
|
||||||
|
}
|
||||||
|
|
||||||
// apply new connectionlimits
|
// apply new connectionlimits
|
||||||
server.connectionLimitsMutex.Lock()
|
server.connectionLimitsMutex.Lock()
|
||||||
server.connectionLimits = connectionLimits
|
server.connectionLimits = connectionLimits
|
||||||
@ -1578,7 +1455,9 @@ func (server *Server) rehash() error {
|
|||||||
server.accountRegistration = &accountReg
|
server.accountRegistration = &accountReg
|
||||||
server.channelRegistrationEnabled = config.Channels.Registration.Enabled
|
server.channelRegistrationEnabled = config.Channels.Registration.Enabled
|
||||||
|
|
||||||
|
server.configurableStateMutex.Lock()
|
||||||
server.defaultChannelModes = ParseDefaultChannelModes(config)
|
server.defaultChannelModes = ParseDefaultChannelModes(config)
|
||||||
|
server.configurableStateMutex.Unlock()
|
||||||
|
|
||||||
// set new sendqueue size
|
// set new sendqueue size
|
||||||
if config.Server.MaxSendQBytes != server.MaxSendQBytes {
|
if config.Server.MaxSendQBytes != server.MaxSendQBytes {
|
||||||
@ -1595,8 +1474,8 @@ func (server *Server) rehash() error {
|
|||||||
// set RPL_ISUPPORT
|
// set RPL_ISUPPORT
|
||||||
oldISupportList := server.isupport
|
oldISupportList := server.isupport
|
||||||
server.setISupport()
|
server.setISupport()
|
||||||
|
if oldISupportList != nil {
|
||||||
newISupportReplies := oldISupportList.GetDifference(server.isupport)
|
newISupportReplies := oldISupportList.GetDifference(server.isupport)
|
||||||
|
|
||||||
// push new info to all of our clients
|
// push new info to all of our clients
|
||||||
server.clients.ByNickMutex.RLock()
|
server.clients.ByNickMutex.RLock()
|
||||||
for _, sClient := range server.clients.ByNick {
|
for _, sClient := range server.clients.ByNick {
|
||||||
@ -1606,7 +1485,97 @@ func (server *Server) rehash() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
server.clients.ByNickMutex.RUnlock()
|
server.clients.ByNickMutex.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
server.loadMOTD(config.Server.MOTD)
|
||||||
|
|
||||||
|
if initial {
|
||||||
|
if err := server.loadDatastore(config.Datastore.Path); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// we are now open for business
|
||||||
|
server.setupListeners(config)
|
||||||
|
server.setupWSListener(config)
|
||||||
|
server.setupRestAPI(config)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) loadMOTD(motdPath string) error {
|
||||||
|
server.logger.Debug("rehash", "Loading MOTD")
|
||||||
|
motdLines := make([]string, 0)
|
||||||
|
if motdPath != "" {
|
||||||
|
file, err := os.Open(motdPath)
|
||||||
|
if err == nil {
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
reader := bufio.NewReader(file)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
line = strings.TrimRight(line, "\r\n")
|
||||||
|
// "- " is the required prefix for MOTD, we just add it here to make
|
||||||
|
// bursting it out to clients easier
|
||||||
|
line = fmt.Sprintf("- %s", line)
|
||||||
|
|
||||||
|
motdLines = append(motdLines, line)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
server.configurableStateMutex.Lock()
|
||||||
|
defer server.configurableStateMutex.Unlock()
|
||||||
|
server.motdLines = motdLines
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) loadDatastore(datastorePath string) error {
|
||||||
|
// open the datastore and load server state for which it (rather than config)
|
||||||
|
// is the source of truth
|
||||||
|
|
||||||
|
server.logger.Debug("startup", "Opening datastore")
|
||||||
|
db, err := OpenDatabase(datastorePath)
|
||||||
|
if err == nil {
|
||||||
|
server.store = db
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("Failed to open datastore: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// load *lines (from the datastores)
|
||||||
|
server.logger.Debug("startup", "Loading D/Klines")
|
||||||
|
server.loadDLines()
|
||||||
|
server.loadKLines()
|
||||||
|
|
||||||
|
// load password manager
|
||||||
|
server.logger.Debug("startup", "Loading passwords")
|
||||||
|
err = server.store.View(func(tx *buntdb.Tx) error {
|
||||||
|
saltString, err := tx.Get(keySalt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not retrieve salt string: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
salt, err := base64.StdEncoding.DecodeString(saltString)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pwm := NewPasswordManager(salt)
|
||||||
|
server.passwords = &pwm
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not load salt: %s", err.Error())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) setupListeners(config *Config) {
|
||||||
// update or destroy all existing listeners
|
// update or destroy all existing listeners
|
||||||
tlsListeners := config.TLSListeners()
|
tlsListeners := config.TLSListeners()
|
||||||
for addr := range server.listeners {
|
for addr := range server.listeners {
|
||||||
@ -1629,18 +1598,16 @@ func (server *Server) rehash() error {
|
|||||||
currentListener.configMutex.Unlock()
|
currentListener.configMutex.Unlock()
|
||||||
|
|
||||||
if stillConfigured {
|
if stillConfigured {
|
||||||
server.logger.Info("rehash",
|
server.logger.Info("listeners",
|
||||||
fmt.Sprintf("now listening on %s, tls=%t.", addr, (currentListener.tlsConfig != nil)),
|
fmt.Sprintf("now listening on %s, tls=%t.", addr, (currentListener.tlsConfig != nil)),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
// tell the listener it should stop by interrupting its Accept() call:
|
// tell the listener it should stop by interrupting its Accept() call:
|
||||||
currentListener.listener.Close()
|
currentListener.listener.Close()
|
||||||
// XXX there is no guarantee from the API when the address will actually
|
// TODO(golang1.10) delete stopEvent once issue #21856 is released
|
||||||
// free for bind(2) again; this solution "seems to work". See here:
|
|
||||||
// https://github.com/golang/go/issues/21833
|
|
||||||
<-currentListener.stopEvent
|
<-currentListener.stopEvent
|
||||||
delete(server.listeners, addr)
|
delete(server.listeners, addr)
|
||||||
server.logger.Info("rehash", fmt.Sprintf("stopped listening on %s.", addr))
|
server.logger.Info("listeners", fmt.Sprintf("stopped listening on %s.", addr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1653,7 +1620,52 @@ func (server *Server) rehash() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
if len(tlsListeners) == 0 {
|
||||||
|
server.logger.Warning("startup", "You are not exposing an SSL/TLS listening port. You should expose at least one port (typically 6697) to accept TLS connections")
|
||||||
|
}
|
||||||
|
|
||||||
|
var usesStandardTLSPort bool
|
||||||
|
for addr := range config.TLSListeners() {
|
||||||
|
if strings.Contains(addr, "6697") {
|
||||||
|
usesStandardTLSPort = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if 0 < len(tlsListeners) && !usesStandardTLSPort {
|
||||||
|
server.logger.Warning("startup", "Port 6697 is the standard TLS port for IRC. You should (also) expose port 6697 as a TLS port to ensure clients can connect securely")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) setupRestAPI(config *Config) {
|
||||||
|
restAPIEnabled := config.Server.RestAPI.Enabled
|
||||||
|
restAPIStarted := server.restAPIServer != nil
|
||||||
|
restAPIListenAddrChanged := server.restAPI.Listen != config.Server.RestAPI.Listen
|
||||||
|
|
||||||
|
// stop an existing REST server if it's been disabled or the addr changed
|
||||||
|
if restAPIStarted && (!restAPIEnabled || restAPIListenAddrChanged) {
|
||||||
|
ctx, _ := context.WithTimeout(context.Background(), httpShutdownTimeout)
|
||||||
|
server.restAPIServer.Shutdown(ctx)
|
||||||
|
server.restAPIServer.Close()
|
||||||
|
server.logger.Info("rehash", "server", fmt.Sprintf("%s rest API stopped on %s.", server.name, server.restAPI.Listen))
|
||||||
|
server.restAPIServer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// start a new one if it's enabled or the addr changed
|
||||||
|
if restAPIEnabled && (!restAPIStarted || restAPIListenAddrChanged) {
|
||||||
|
server.restAPIServer, _ = StartRestAPI(server, config.Server.RestAPI.Listen)
|
||||||
|
server.logger.Info(
|
||||||
|
"rehash", "server",
|
||||||
|
fmt.Sprintf("%s rest API started on %s.", server.name, config.Server.RestAPI.Listen))
|
||||||
|
}
|
||||||
|
|
||||||
|
// save the config information
|
||||||
|
server.restAPI = config.Server.RestAPI
|
||||||
|
}
|
||||||
|
|
||||||
|
func (server *Server) GetDefaultChannelModes() Modes {
|
||||||
|
server.configurableStateMutex.RLock()
|
||||||
|
defer server.configurableStateMutex.RUnlock()
|
||||||
|
return server.defaultChannelModes
|
||||||
}
|
}
|
||||||
|
|
||||||
// REHASH
|
// REHASH
|
||||||
|
@ -132,7 +132,7 @@ Options:
|
|||||||
logger.Warning("startup", "You are currently running an unreleased beta version of Oragono that may be unstable and could corrupt your database.\nIf you are running a production network, please download the latest build from https://oragono.io/downloads.html and run that instead.")
|
logger.Warning("startup", "You are currently running an unreleased beta version of Oragono that may be unstable and could corrupt your database.\nIf you are running a production network, please download the latest build from https://oragono.io/downloads.html and run that instead.")
|
||||||
}
|
}
|
||||||
|
|
||||||
server, err := irc.NewServer(configfile, config, logger)
|
server, err := irc.NewServer(config, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("startup", fmt.Sprintf("Could not load server: %s", err.Error()))
|
logger.Error("startup", fmt.Sprintf("Could not load server: %s", err.Error()))
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user