diff --git a/irc/rest_api.go b/irc/rest_api.go index 446a65ef..012b4109 100644 --- a/irc/rest_api.go +++ b/irc/rest_api.go @@ -75,7 +75,7 @@ func restStatus(w http.ResponseWriter, r *http.Request) { rs := restStatusResp{ Clients: restAPIServer.clients.Count(), Opers: len(restAPIServer.operators), - Channels: len(restAPIServer.channels), + Channels: restAPIServer.channels.Len(), } b, err := json.Marshal(rs) if err != nil { diff --git a/irc/server.go b/irc/server.go index f7f52980..0921433b 100644 --- a/irc/server.go +++ b/irc/server.go @@ -192,7 +192,7 @@ func NewServer(configFilename string, config *Config, logger *logger.Manager) (* accountAuthenticationEnabled: config.Accounts.AuthenticationEnabled, accounts: make(map[string]*ClientAccount), channelRegistrationEnabled: config.Channels.Registration.Enabled, - channels: make(ChannelNameMap), + channels: NewChannelNameMap(), checkIdent: config.Server.CheckIdent, clients: NewClientLookupSet(), commands: make(chan Command), @@ -1196,9 +1196,11 @@ func whoHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { //} if mask == "" { - for _, channel := range server.channels { + server.channels.ChansLock.RLock() + for _, channel := range server.channels.Chans { whoChannel(client, channel, friends) } + server.channels.ChansLock.RUnlock() } else if mask[0] == '#' { // TODO implement wildcard matching //TODO(dan): ^ only for opers @@ -1748,12 +1750,14 @@ func listHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { } if len(channels) == 0 { - for _, channel := range server.channels { + server.channels.ChansLock.RLock() + for _, channel := range server.channels.Chans { if !client.flags[Operator] && channel.flags[Secret] { continue } client.RplList(channel) } + server.channels.ChansLock.RUnlock() } else { // limit regular users to only listing one channel if !client.flags[Operator] { @@ -1807,9 +1811,11 @@ func namesHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { //} if len(channels) == 0 { - for _, channel := range server.channels { + server.channels.ChansLock.RLock() + for _, channel := range server.channels.Chans { channel.Names(client) } + server.channels.ChansLock.RUnlock() return false } @@ -1958,7 +1964,7 @@ func lusersHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool { } client.Send(nil, server.name, RPL_LUSERCLIENT, client.nick, fmt.Sprintf("There are %d users and %d invisible on %d server(s)", totalcount, invisiblecount, 1)) client.Send(nil, server.name, RPL_LUSEROP, client.nick, fmt.Sprintf("%d IRC Operators online", opercount)) - client.Send(nil, server.name, RPL_LUSERCHANNELS, client.nick, fmt.Sprintf("%d channels formed", len(server.channels))) + client.Send(nil, server.name, RPL_LUSERCHANNELS, client.nick, fmt.Sprintf("%d channels formed", server.channels.Len())) client.Send(nil, server.name, RPL_LUSERME, client.nick, fmt.Sprintf("I have %d clients and %d servers", totalcount, 1)) return false } diff --git a/irc/types.go b/irc/types.go index ce0a8380..b6780436 100644 --- a/irc/types.go +++ b/irc/types.go @@ -8,38 +8,60 @@ package irc import ( "fmt" "strings" + "sync" ) // // simple types // -type ChannelNameMap map[string]*Channel +type ChannelNameMap struct { + ChansLock sync.RWMutex + Chans map[string]*Channel +} + +func NewChannelNameMap() ChannelNameMap { + var channels ChannelNameMap + channels.Chans = make(map[string]*Channel) + return channels +} func (channels ChannelNameMap) Get(name string) *Channel { name, err := CasefoldChannel(name) if err == nil { - return channels[name] + channels.ChansLock.RLock() + defer channels.ChansLock.RUnlock() + return channels.Chans[name] } return nil } func (channels ChannelNameMap) Add(channel *Channel) error { - if channels[channel.nameCasefolded] != nil { + channels.ChansLock.Lock() + defer channels.ChansLock.Unlock() + if channels.Chans[channel.nameCasefolded] != nil { return fmt.Errorf("%s: already set", channel.name) } - channels[channel.nameCasefolded] = channel + channels.Chans[channel.nameCasefolded] = channel return nil } func (channels ChannelNameMap) Remove(channel *Channel) error { - if channel != channels[channel.nameCasefolded] { + channels.ChansLock.Lock() + defer channels.ChansLock.Unlock() + if channel != channels.Chans[channel.nameCasefolded] { return fmt.Errorf("%s: mismatch", channel.name) } - delete(channels, channel.nameCasefolded) + delete(channels.Chans, channel.nameCasefolded) return nil } +func (channels ChannelNameMap) Len() int { + channels.ChansLock.RLock() + defer channels.ChansLock.RUnlock() + return len(channels.Chans) +} + type ModeSet map[Mode]bool func (set ModeSet) String() string {