3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-26 05:49:25 +01:00

fix various data races, including 2 introduced by #139

This commit is contained in:
Shivaram Lingamneni 2017-10-02 04:42:50 -04:00
parent 58fb997e77
commit 23a66fa502
5 changed files with 82 additions and 43 deletions

View File

@ -94,7 +94,7 @@ func NewClient(server *Server, conn net.Conn, isTLS bool) *Client {
go socket.RunSocketWriter() go socket.RunSocketWriter()
client := &Client{ client := &Client{
atime: now, atime: now,
authorized: server.password == nil, authorized: server.getPassword() == nil,
capabilities: caps.NewSet(), capabilities: caps.NewSet(),
capState: CapNone, capState: CapNone,
capVersion: caps.Cap301, capVersion: caps.Cap301,
@ -182,10 +182,11 @@ func (client *Client) maxlens() (int, int) {
maxlenTags = 4096 maxlenTags = 4096
} }
if client.capabilities.Has(caps.MaxLine) { if client.capabilities.Has(caps.MaxLine) {
if client.server.limits.LineLen.Tags > maxlenTags { limits := client.server.getLimits()
maxlenTags = client.server.limits.LineLen.Tags if limits.LineLen.Tags > maxlenTags {
maxlenTags = limits.LineLen.Tags
} }
maxlenRest = client.server.limits.LineLen.Rest maxlenRest = limits.LineLen.Rest
} }
return maxlenTags, maxlenRest return maxlenTags, maxlenRest
} }
@ -679,7 +680,7 @@ func (client *Client) Send(tags *map[string]ircmsg.TagValue, prefix string, comm
func (client *Client) Notice(text string) { func (client *Client) Notice(text string) {
limit := 400 limit := 400
if client.capabilities.Has(caps.MaxLine) { if client.capabilities.Has(caps.MaxLine) {
limit = client.server.limits.LineLen.Rest - 110 limit = client.server.getLimits().LineLen.Rest - 110
} }
lines := wordWrap(text, limit) lines := wordWrap(text, limit)

22
irc/getters.go Normal file
View File

@ -0,0 +1,22 @@
// Copyright (c) 2017 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license
package irc
func (server *Server) getISupport() *ISupportList {
server.configurableStateMutex.RLock()
defer server.configurableStateMutex.RUnlock()
return server.isupport
}
func (server *Server) getLimits() Limits {
server.configurableStateMutex.RLock()
defer server.configurableStateMutex.RUnlock()
return server.limits
}
func (server *Server) getPassword() []byte {
server.configurableStateMutex.RLock()
defer server.configurableStateMutex.RUnlock()
return server.password
}

View File

@ -142,7 +142,7 @@ func (il *ISupportList) RegenerateCachedReply() {
// RplISupport outputs our ISUPPORT lines to the client. This is used on connection and in VERSION responses. // RplISupport outputs our ISUPPORT lines to the client. This is used on connection and in VERSION responses.
func (client *Client) RplISupport() { func (client *Client) RplISupport() {
for _, tokenline := range client.server.isupport.CachedReply { for _, tokenline := range client.server.getISupport().CachedReply {
// ugly trickery ahead // ugly trickery ahead
client.Send(nil, client.server.name, RPL_ISUPPORT, append([]string{client.nick}, tokenline...)...) client.Send(nil, client.server.name, RPL_ISUPPORT, append([]string{client.nick}, tokenline...)...)
} }

View File

@ -34,7 +34,7 @@ func nickHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
return false return false
} }
if err != nil || len(nicknameRaw) > server.limits.NickLen || restrictedNicknames[nickname] { if err != nil || len(nicknameRaw) > server.getLimits().NickLen || restrictedNicknames[nickname] {
client.Send(nil, server.name, ERR_ERRONEUSNICKNAME, client.nick, nicknameRaw, "Erroneous nickname") client.Send(nil, server.name, ERR_ERRONEUSNICKNAME, client.nick, nicknameRaw, "Erroneous nickname")
return false return false
} }

View File

@ -182,29 +182,31 @@ func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
func (server *Server) setISupport() { func (server *Server) setISupport() {
maxTargetsString := strconv.Itoa(maxTargets) maxTargetsString := strconv.Itoa(maxTargets)
server.configurableStateMutex.RLock()
// add RPL_ISUPPORT tokens // add RPL_ISUPPORT tokens
server.isupport = NewISupportList() isupport := NewISupportList()
server.isupport.Add("AWAYLEN", strconv.Itoa(server.limits.AwayLen)) isupport.Add("AWAYLEN", strconv.Itoa(server.limits.AwayLen))
server.isupport.Add("CASEMAPPING", casemappingName) isupport.Add("CASEMAPPING", casemappingName)
server.isupport.Add("CHANMODES", strings.Join([]string{Modes{BanMask, ExceptMask, InviteMask}.String(), "", Modes{UserLimit, Key}.String(), Modes{InviteOnly, Moderated, NoOutside, OpOnlyTopic, ChanRoleplaying, Secret}.String()}, ",")) isupport.Add("CHANMODES", strings.Join([]string{Modes{BanMask, ExceptMask, InviteMask}.String(), "", Modes{UserLimit, Key}.String(), Modes{InviteOnly, Moderated, NoOutside, OpOnlyTopic, ChanRoleplaying, Secret}.String()}, ","))
server.isupport.Add("CHANNELLEN", strconv.Itoa(server.limits.ChannelLen)) isupport.Add("CHANNELLEN", strconv.Itoa(server.limits.ChannelLen))
server.isupport.Add("CHANTYPES", "#") isupport.Add("CHANTYPES", "#")
server.isupport.Add("ELIST", "U") isupport.Add("ELIST", "U")
server.isupport.Add("EXCEPTS", "") isupport.Add("EXCEPTS", "")
server.isupport.Add("INVEX", "") isupport.Add("INVEX", "")
server.isupport.Add("KICKLEN", strconv.Itoa(server.limits.KickLen)) isupport.Add("KICKLEN", strconv.Itoa(server.limits.KickLen))
server.isupport.Add("MAXLIST", fmt.Sprintf("beI:%s", strconv.Itoa(server.limits.ChanListModes))) isupport.Add("MAXLIST", fmt.Sprintf("beI:%s", strconv.Itoa(server.limits.ChanListModes)))
server.isupport.Add("MAXTARGETS", maxTargetsString) isupport.Add("MAXTARGETS", maxTargetsString)
server.isupport.Add("MODES", "") isupport.Add("MODES", "")
server.isupport.Add("MONITOR", strconv.Itoa(server.limits.MonitorEntries)) isupport.Add("MONITOR", strconv.Itoa(server.limits.MonitorEntries))
server.isupport.Add("NETWORK", server.networkName) isupport.Add("NETWORK", server.networkName)
server.isupport.Add("NICKLEN", strconv.Itoa(server.limits.NickLen)) isupport.Add("NICKLEN", strconv.Itoa(server.limits.NickLen))
server.isupport.Add("PREFIX", "(qaohv)~&@%+") isupport.Add("PREFIX", "(qaohv)~&@%+")
server.isupport.Add("RPCHAN", "E") isupport.Add("RPCHAN", "E")
server.isupport.Add("RPUSER", "E") isupport.Add("RPUSER", "E")
server.isupport.Add("STATUSMSG", "~&@%+") isupport.Add("STATUSMSG", "~&@%+")
server.isupport.Add("TARGMAX", fmt.Sprintf("NAMES:1,LIST:1,KICK:1,WHOIS:1,USERHOST:10,PRIVMSG:%s,TAGMSG:%s,NOTICE:%s,MONITOR:", maxTargetsString, maxTargetsString, maxTargetsString)) isupport.Add("TARGMAX", fmt.Sprintf("NAMES:1,LIST:1,KICK:1,WHOIS:1,USERHOST:10,PRIVMSG:%s,TAGMSG:%s,NOTICE:%s,MONITOR:", maxTargetsString, maxTargetsString, maxTargetsString))
server.isupport.Add("TOPICLEN", strconv.Itoa(server.limits.TopicLen)) isupport.Add("TOPICLEN", strconv.Itoa(server.limits.TopicLen))
// account registration // account registration
if server.accountRegistration.Enabled { if server.accountRegistration.Enabled {
@ -216,12 +218,18 @@ func (server *Server) setISupport() {
} }
} }
server.isupport.Add("REGCOMMANDS", "CREATE,VERIFY") isupport.Add("REGCOMMANDS", "CREATE,VERIFY")
server.isupport.Add("REGCALLBACKS", strings.Join(enabledCallbacks, ",")) isupport.Add("REGCALLBACKS", strings.Join(enabledCallbacks, ","))
server.isupport.Add("REGCREDTYPES", "passphrase,certfp") isupport.Add("REGCREDTYPES", "passphrase,certfp")
} }
server.isupport.RegenerateCachedReply() server.configurableStateMutex.RUnlock()
isupport.RegenerateCachedReply()
server.configurableStateMutex.Lock()
server.isupport = isupport
server.configurableStateMutex.Unlock()
} }
func loadChannelList(channel *Channel, list string, maskMode Mode) { func loadChannelList(channel *Channel, list string, maskMode Mode) {
@ -440,15 +448,16 @@ func (server *Server) tryRegister(c *Client) {
// MOTD serves the Message of the Day. // MOTD serves the Message of the Day.
func (server *Server) MOTD(client *Client) { func (server *Server) MOTD(client *Client) {
server.configurableStateMutex.RLock() server.configurableStateMutex.RLock()
defer server.configurableStateMutex.RUnlock() motdLines := server.motdLines
server.configurableStateMutex.RUnlock()
if len(server.motdLines) < 1 { if len(motdLines) < 1 {
client.Send(nil, server.name, ERR_NOMOTD, client.nick, "MOTD File is missing") client.Send(nil, server.name, ERR_NOMOTD, client.nick, "MOTD File is missing")
return return
} }
client.Send(nil, server.name, RPL_MOTDSTART, client.nick, fmt.Sprintf("- %s Message of the day - ", server.name)) client.Send(nil, server.name, RPL_MOTDSTART, client.nick, fmt.Sprintf("- %s Message of the day - ", server.name))
for _, line := range server.motdLines { for _, line := range motdLines {
client.Send(nil, server.name, RPL_MOTD, client.nick, line) client.Send(nil, server.name, RPL_MOTD, client.nick, line)
} }
client.Send(nil, server.name, RPL_ENDOFMOTD, client.nick, "End of MOTD command") client.Send(nil, server.name, RPL_ENDOFMOTD, client.nick, "End of MOTD command")
@ -691,7 +700,7 @@ func joinHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
channel := server.channels.Get(casefoldedName) channel := server.channels.Get(casefoldedName)
if channel == nil { if channel == nil {
if len(casefoldedName) > server.limits.ChannelLen { if len(casefoldedName) > server.getLimits().ChannelLen {
client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, name, "No such channel") client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, name, "No such channel")
continue continue
} }
@ -1257,15 +1266,19 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
// sanity checks complete, start modifying server state // sanity checks complete, start modifying server state
server.name = config.Server.Name if initial {
server.nameCasefolded = casefoldedName server.name = config.Server.Name
server.nameCasefolded = casefoldedName
}
server.networkName = config.Network.Name server.networkName = config.Network.Name
server.configurableStateMutex.Lock()
if config.Server.Password != "" { if config.Server.Password != "" {
server.password = config.Server.PasswordBytes() server.password = config.Server.PasswordBytes()
} else { } else {
server.password = nil server.password = nil
} }
server.configurableStateMutex.Unlock()
// apply new PROXY command restrictions // apply new PROXY command restrictions
server.proxyAllowedFrom = config.Server.ProxyAllowedFrom server.proxyAllowedFrom = config.Server.ProxyAllowedFrom
@ -1372,6 +1385,7 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
} }
// set server options // set server options
server.configurableStateMutex.Lock()
lineLenConfig := LineLenLimits{ lineLenConfig := LineLenLimits{
Tags: config.Limits.LineLen.Tags, Tags: config.Limits.LineLen.Tags,
Rest: config.Limits.LineLen.Rest, Rest: config.Limits.LineLen.Rest,
@ -1395,13 +1409,14 @@ func (server *Server) applyConfig(config *Config, initial bool) error {
server.accountRegistration = &accountReg server.accountRegistration = &accountReg
server.channelRegistrationEnabled = config.Channels.Registration.Enabled server.channelRegistrationEnabled = config.Channels.Registration.Enabled
server.configurableStateMutex.Lock()
server.defaultChannelModes = ParseDefaultChannelModes(config) server.defaultChannelModes = ParseDefaultChannelModes(config)
server.configurableStateMutex.Unlock() server.configurableStateMutex.Unlock()
// set new sendqueue size // set new sendqueue size
if config.Server.MaxSendQBytes != server.MaxSendQBytes { if config.Server.MaxSendQBytes != server.MaxSendQBytes {
server.configurableStateMutex.Lock()
server.MaxSendQBytes = config.Server.MaxSendQBytes server.MaxSendQBytes = config.Server.MaxSendQBytes
server.configurableStateMutex.Unlock()
// update on all clients // update on all clients
server.clients.ByNickMutex.RLock() server.clients.ByNickMutex.RLock()
@ -1469,8 +1484,8 @@ func (server *Server) loadMOTD(motdPath string) error {
} }
server.configurableStateMutex.Lock() server.configurableStateMutex.Lock()
defer server.configurableStateMutex.Unlock()
server.motdLines = motdLines server.motdLines = motdLines
server.configurableStateMutex.Unlock()
return nil return nil
} }
@ -1628,8 +1643,9 @@ func awayHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
if len(msg.Params) > 0 { if len(msg.Params) > 0 {
isAway = true isAway = true
text = msg.Params[0] text = msg.Params[0]
if len(text) > server.limits.AwayLen { awayLen := server.getLimits().AwayLen
text = text[:server.limits.AwayLen] if len(text) > awayLen {
text = text[:awayLen]
} }
} }