3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-10 22:19:31 +01:00

types: Make ChannelNameMap use mutexes to fix crash

This commit is contained in:
Daniel Oaks 2017-04-17 21:01:39 +10:00
parent ff3a864aa3
commit e0035dfa04
3 changed files with 40 additions and 12 deletions

View File

@ -75,7 +75,7 @@ func restStatus(w http.ResponseWriter, r *http.Request) {
rs := restStatusResp{
Clients: restAPIServer.clients.Count(),
Opers: len(restAPIServer.operators),
Channels: len(restAPIServer.channels),
Channels: restAPIServer.channels.Len(),
}
b, err := json.Marshal(rs)
if err != nil {

View File

@ -192,7 +192,7 @@ func NewServer(configFilename string, config *Config, logger *logger.Manager) (*
accountAuthenticationEnabled: config.Accounts.AuthenticationEnabled,
accounts: make(map[string]*ClientAccount),
channelRegistrationEnabled: config.Channels.Registration.Enabled,
channels: make(ChannelNameMap),
channels: NewChannelNameMap(),
checkIdent: config.Server.CheckIdent,
clients: NewClientLookupSet(),
commands: make(chan Command),
@ -1196,9 +1196,11 @@ func whoHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
//}
if mask == "" {
for _, channel := range server.channels {
server.channels.ChansLock.RLock()
for _, channel := range server.channels.Chans {
whoChannel(client, channel, friends)
}
server.channels.ChansLock.RUnlock()
} else if mask[0] == '#' {
// TODO implement wildcard matching
//TODO(dan): ^ only for opers
@ -1748,12 +1750,14 @@ func listHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
}
if len(channels) == 0 {
for _, channel := range server.channels {
server.channels.ChansLock.RLock()
for _, channel := range server.channels.Chans {
if !client.flags[Operator] && channel.flags[Secret] {
continue
}
client.RplList(channel)
}
server.channels.ChansLock.RUnlock()
} else {
// limit regular users to only listing one channel
if !client.flags[Operator] {
@ -1807,9 +1811,11 @@ func namesHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
//}
if len(channels) == 0 {
for _, channel := range server.channels {
server.channels.ChansLock.RLock()
for _, channel := range server.channels.Chans {
channel.Names(client)
}
server.channels.ChansLock.RUnlock()
return false
}
@ -1958,7 +1964,7 @@ func lusersHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
}
client.Send(nil, server.name, RPL_LUSERCLIENT, client.nick, fmt.Sprintf("There are %d users and %d invisible on %d server(s)", totalcount, invisiblecount, 1))
client.Send(nil, server.name, RPL_LUSEROP, client.nick, fmt.Sprintf("%d IRC Operators online", opercount))
client.Send(nil, server.name, RPL_LUSERCHANNELS, client.nick, fmt.Sprintf("%d channels formed", len(server.channels)))
client.Send(nil, server.name, RPL_LUSERCHANNELS, client.nick, fmt.Sprintf("%d channels formed", server.channels.Len()))
client.Send(nil, server.name, RPL_LUSERME, client.nick, fmt.Sprintf("I have %d clients and %d servers", totalcount, 1))
return false
}

View File

@ -8,38 +8,60 @@ package irc
import (
"fmt"
"strings"
"sync"
)
//
// simple types
//
type ChannelNameMap map[string]*Channel
type ChannelNameMap struct {
ChansLock sync.RWMutex
Chans map[string]*Channel
}
func NewChannelNameMap() ChannelNameMap {
var channels ChannelNameMap
channels.Chans = make(map[string]*Channel)
return channels
}
func (channels ChannelNameMap) Get(name string) *Channel {
name, err := CasefoldChannel(name)
if err == nil {
return channels[name]
channels.ChansLock.RLock()
defer channels.ChansLock.RUnlock()
return channels.Chans[name]
}
return nil
}
func (channels ChannelNameMap) Add(channel *Channel) error {
if channels[channel.nameCasefolded] != nil {
channels.ChansLock.Lock()
defer channels.ChansLock.Unlock()
if channels.Chans[channel.nameCasefolded] != nil {
return fmt.Errorf("%s: already set", channel.name)
}
channels[channel.nameCasefolded] = channel
channels.Chans[channel.nameCasefolded] = channel
return nil
}
func (channels ChannelNameMap) Remove(channel *Channel) error {
if channel != channels[channel.nameCasefolded] {
channels.ChansLock.Lock()
defer channels.ChansLock.Unlock()
if channel != channels.Chans[channel.nameCasefolded] {
return fmt.Errorf("%s: mismatch", channel.name)
}
delete(channels, channel.nameCasefolded)
delete(channels.Chans, channel.nameCasefolded)
return nil
}
func (channels ChannelNameMap) Len() int {
channels.ChansLock.RLock()
defer channels.ChansLock.RUnlock()
return len(channels.Chans)
}
type ModeSet map[Mode]bool
func (set ModeSet) String() string {