types: Make ChannelNameMap use mutexes to fix crash

This commit is contained in:
Daniel Oaks 2017-04-17 21:01:39 +10:00
parent ff3a864aa3
commit e0035dfa04
3 changed files with 40 additions and 12 deletions

View File

@ -75,7 +75,7 @@ func restStatus(w http.ResponseWriter, r *http.Request) {
rs := restStatusResp{ rs := restStatusResp{
Clients: restAPIServer.clients.Count(), Clients: restAPIServer.clients.Count(),
Opers: len(restAPIServer.operators), Opers: len(restAPIServer.operators),
Channels: len(restAPIServer.channels), Channels: restAPIServer.channels.Len(),
} }
b, err := json.Marshal(rs) b, err := json.Marshal(rs)
if err != nil { if err != nil {

View File

@ -192,7 +192,7 @@ func NewServer(configFilename string, config *Config, logger *logger.Manager) (*
accountAuthenticationEnabled: config.Accounts.AuthenticationEnabled, accountAuthenticationEnabled: config.Accounts.AuthenticationEnabled,
accounts: make(map[string]*ClientAccount), accounts: make(map[string]*ClientAccount),
channelRegistrationEnabled: config.Channels.Registration.Enabled, channelRegistrationEnabled: config.Channels.Registration.Enabled,
channels: make(ChannelNameMap), channels: NewChannelNameMap(),
checkIdent: config.Server.CheckIdent, checkIdent: config.Server.CheckIdent,
clients: NewClientLookupSet(), clients: NewClientLookupSet(),
commands: make(chan Command), commands: make(chan Command),
@ -1196,9 +1196,11 @@ func whoHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
//} //}
if mask == "" { if mask == "" {
for _, channel := range server.channels { server.channels.ChansLock.RLock()
for _, channel := range server.channels.Chans {
whoChannel(client, channel, friends) whoChannel(client, channel, friends)
} }
server.channels.ChansLock.RUnlock()
} else if mask[0] == '#' { } else if mask[0] == '#' {
// TODO implement wildcard matching // TODO implement wildcard matching
//TODO(dan): ^ only for opers //TODO(dan): ^ only for opers
@ -1748,12 +1750,14 @@ func listHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
} }
if len(channels) == 0 { if len(channels) == 0 {
for _, channel := range server.channels { server.channels.ChansLock.RLock()
for _, channel := range server.channels.Chans {
if !client.flags[Operator] && channel.flags[Secret] { if !client.flags[Operator] && channel.flags[Secret] {
continue continue
} }
client.RplList(channel) client.RplList(channel)
} }
server.channels.ChansLock.RUnlock()
} else { } else {
// limit regular users to only listing one channel // limit regular users to only listing one channel
if !client.flags[Operator] { if !client.flags[Operator] {
@ -1807,9 +1811,11 @@ func namesHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
//} //}
if len(channels) == 0 { if len(channels) == 0 {
for _, channel := range server.channels { server.channels.ChansLock.RLock()
for _, channel := range server.channels.Chans {
channel.Names(client) channel.Names(client)
} }
server.channels.ChansLock.RUnlock()
return false return false
} }
@ -1958,7 +1964,7 @@ func lusersHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
} }
client.Send(nil, server.name, RPL_LUSERCLIENT, client.nick, fmt.Sprintf("There are %d users and %d invisible on %d server(s)", totalcount, invisiblecount, 1)) client.Send(nil, server.name, RPL_LUSERCLIENT, client.nick, fmt.Sprintf("There are %d users and %d invisible on %d server(s)", totalcount, invisiblecount, 1))
client.Send(nil, server.name, RPL_LUSEROP, client.nick, fmt.Sprintf("%d IRC Operators online", opercount)) client.Send(nil, server.name, RPL_LUSEROP, client.nick, fmt.Sprintf("%d IRC Operators online", opercount))
client.Send(nil, server.name, RPL_LUSERCHANNELS, client.nick, fmt.Sprintf("%d channels formed", len(server.channels))) client.Send(nil, server.name, RPL_LUSERCHANNELS, client.nick, fmt.Sprintf("%d channels formed", server.channels.Len()))
client.Send(nil, server.name, RPL_LUSERME, client.nick, fmt.Sprintf("I have %d clients and %d servers", totalcount, 1)) client.Send(nil, server.name, RPL_LUSERME, client.nick, fmt.Sprintf("I have %d clients and %d servers", totalcount, 1))
return false return false
} }

View File

@ -8,38 +8,60 @@ package irc
import ( import (
"fmt" "fmt"
"strings" "strings"
"sync"
) )
// //
// simple types // simple types
// //
type ChannelNameMap map[string]*Channel type ChannelNameMap struct {
ChansLock sync.RWMutex
Chans map[string]*Channel
}
func NewChannelNameMap() ChannelNameMap {
var channels ChannelNameMap
channels.Chans = make(map[string]*Channel)
return channels
}
func (channels ChannelNameMap) Get(name string) *Channel { func (channels ChannelNameMap) Get(name string) *Channel {
name, err := CasefoldChannel(name) name, err := CasefoldChannel(name)
if err == nil { if err == nil {
return channels[name] channels.ChansLock.RLock()
defer channels.ChansLock.RUnlock()
return channels.Chans[name]
} }
return nil return nil
} }
func (channels ChannelNameMap) Add(channel *Channel) error { func (channels ChannelNameMap) Add(channel *Channel) error {
if channels[channel.nameCasefolded] != nil { channels.ChansLock.Lock()
defer channels.ChansLock.Unlock()
if channels.Chans[channel.nameCasefolded] != nil {
return fmt.Errorf("%s: already set", channel.name) return fmt.Errorf("%s: already set", channel.name)
} }
channels[channel.nameCasefolded] = channel channels.Chans[channel.nameCasefolded] = channel
return nil return nil
} }
func (channels ChannelNameMap) Remove(channel *Channel) error { func (channels ChannelNameMap) Remove(channel *Channel) error {
if channel != channels[channel.nameCasefolded] { channels.ChansLock.Lock()
defer channels.ChansLock.Unlock()
if channel != channels.Chans[channel.nameCasefolded] {
return fmt.Errorf("%s: mismatch", channel.name) return fmt.Errorf("%s: mismatch", channel.name)
} }
delete(channels, channel.nameCasefolded) delete(channels.Chans, channel.nameCasefolded)
return nil return nil
} }
func (channels ChannelNameMap) Len() int {
channels.ChansLock.RLock()
defer channels.ChansLock.RUnlock()
return len(channels.Chans)
}
type ModeSet map[Mode]bool type ModeSet map[Mode]bool
func (set ModeSet) String() string { func (set ModeSet) String() string {