3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-12-22 18:52:41 +01:00

remove channelJoinPartMutex

This commit is contained in:
Shivaram Lingamneni 2017-10-30 05:21:47 -04:00
parent d715abf0f0
commit 94cf438f51
7 changed files with 260 additions and 162 deletions

View File

@ -7,7 +7,6 @@ package irc
import (
"fmt"
"log"
"strconv"
"time"
@ -41,7 +40,7 @@ type Channel struct {
func NewChannel(s *Server, name string, addDefaultModes bool) *Channel {
casefoldedName, err := CasefoldChannel(name)
if err != nil {
log.Println(fmt.Sprintf("ERROR: Channel name is bad: [%s]", name), err.Error())
s.logger.Error("internal", fmt.Sprintf("Bad channel name %s: %v", name, err))
return nil
}
@ -59,13 +58,11 @@ func NewChannel(s *Server, name string, addDefaultModes bool) *Channel {
}
if addDefaultModes {
for _, mode := range s.GetDefaultChannelModes() {
for _, mode := range s.DefaultChannelModes() {
channel.flags[mode] = true
}
}
s.channels.Add(channel)
return channel
}
@ -281,6 +278,12 @@ func (channel *Channel) CheckKey(key string) bool {
return (channel.key == "") || (channel.key == key)
}
func (channel *Channel) IsEmpty() bool {
channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock()
return len(channel.members) == 0
}
// Join joins the given client to this channel (if they can be joined).
//TODO(dan): /SAJOIN and maybe a ForceJoin function?
func (channel *Channel) Join(client *Client, key string) {
@ -684,16 +687,10 @@ func (channel *Channel) applyModeMask(client *Client, mode Mode, op ModeOp, mask
func (channel *Channel) Quit(client *Client) {
channel.stateMutex.Lock()
channel.members.Remove(client)
empty := len(channel.members) == 0
channel.stateMutex.Unlock()
channel.regenerateMembersCache()
client.removeChannel(channel)
//TODO(slingamn) fold this operation into a channelmanager type
if empty {
channel.server.channels.Remove(channel)
}
}
func (channel *Channel) Kick(client *Client, target *Client, comment string) {

162
irc/channelmanager.go Normal file
View File

@ -0,0 +1,162 @@
// Copyright (c) 2017 Shivaram Lingamneni <slingamn@cs.stanford.edu>
// released under the MIT license
package irc
import (
"errors"
"sync"
)
var (
InvalidChannelName = errors.New("Invalid channel name")
NoSuchChannel = errors.New("No such channel")
ChannelNameInUse = errors.New("Channel name in use")
)
type channelManagerEntry struct {
channel *Channel
// this is a refcount for joins, so we can avoid a race where we incorrectly
// think the channel is empty (without holding a lock across the entire Channel.Join()
// call)
pendingJoins int
}
// ChannelManager keeps track of all the channels on the server,
// providing synchronization for creation of new channels on first join,
// cleanup of empty channels on last part, and renames.
type ChannelManager struct {
sync.RWMutex // tier 2
chans map[string]*channelManagerEntry
}
// NewChannelManager returns a new ChannelManager.
func NewChannelManager() *ChannelManager {
return &ChannelManager{
chans: make(map[string]*channelManagerEntry),
}
}
// Get returns an existing channel with name equivalent to `name`, or nil
func (cm *ChannelManager) Get(name string) *Channel {
name, err := CasefoldChannel(name)
if err == nil {
cm.RLock()
defer cm.RUnlock()
return cm.chans[name].channel
}
return nil
}
// Join causes `client` to join the channel named `name`, creating it if necessary.
func (cm *ChannelManager) Join(client *Client, name string, key string) error {
server := client.server
casefoldedName, err := CasefoldChannel(name)
if err != nil || len(casefoldedName) > server.getLimits().ChannelLen {
return NoSuchChannel
}
cm.Lock()
entry := cm.chans[casefoldedName]
if entry == nil {
entry = &channelManagerEntry{
channel: NewChannel(server, name, true),
pendingJoins: 0,
}
cm.chans[casefoldedName] = entry
}
entry.pendingJoins += 1
cm.Unlock()
entry.channel.Join(client, key)
cm.maybeCleanup(entry, true)
return nil
}
func (cm *ChannelManager) maybeCleanup(entry *channelManagerEntry, afterJoin bool) {
cm.Lock()
defer cm.Unlock()
if entry.channel == nil {
return
}
if afterJoin {
entry.pendingJoins -= 1
}
if entry.channel.IsEmpty() && entry.pendingJoins == 0 {
// reread the name, handling the case where the channel was renamed
casefoldedName := entry.channel.NameCasefolded()
delete(cm.chans, casefoldedName)
// invalidate the entry (otherwise, a subsequent cleanup attempt could delete
// a valid, distinct entry under casefoldedName):
entry.channel = nil
}
}
// Part parts `client` from the channel named `name`, deleting it if it's empty.
func (cm *ChannelManager) Part(client *Client, name string, message string) error {
casefoldedName, err := CasefoldChannel(name)
if err != nil {
return NoSuchChannel
}
cm.RLock()
entry := cm.chans[casefoldedName]
cm.RUnlock()
if entry == nil {
return NoSuchChannel
}
entry.channel.Part(client, message)
cm.maybeCleanup(entry, false)
return nil
}
// Rename renames a channel (but does not notify the members)
func (cm *ChannelManager) Rename(name string, newname string) error {
cfname, err := CasefoldChannel(name)
if err != nil {
return NoSuchChannel
}
cfnewname, err := CasefoldChannel(newname)
if err != nil {
return InvalidChannelName
}
cm.Lock()
defer cm.Unlock()
if cm.chans[cfnewname] != nil {
return ChannelNameInUse
}
entry := cm.chans[cfname]
if entry == nil {
return NoSuchChannel
}
delete(cm.chans, cfname)
cm.chans[cfnewname] = entry
entry.channel.setName(newname)
entry.channel.setNameCasefolded(cfnewname)
return nil
}
// Len returns the number of channels
func (cm *ChannelManager) Len() int {
cm.RLock()
defer cm.RUnlock()
return len(cm.chans)
}
// Channels returns a slice containing all current channels
func (cm *ChannelManager) Channels() (result []*Channel) {
cm.RLock()
defer cm.RUnlock()
for _, entry := range cm.chans {
result = append(result, entry.channel)
}
return
}

View File

@ -548,14 +548,12 @@ func (client *Client) destroy() {
client.server.monitorManager.RemoveAll(client)
// clean up channels
client.server.channelJoinPartMutex.Lock()
for channel := range client.channels {
for _, channel := range client.Channels() {
channel.Quit(client)
for _, member := range channel.Members() {
friends.Add(member)
}
}
client.server.channelJoinPartMutex.Unlock()
// clean up server
client.server.clients.Remove(client)

View File

@ -41,6 +41,12 @@ func (server *Server) WebIRCConfig() []webircConfig {
return server.webirc
}
func (server *Server) DefaultChannelModes() Modes {
server.configurableStateMutex.RLock()
defer server.configurableStateMutex.RUnlock()
return server.defaultChannelModes
}
func (client *Client) getNick() string {
client.stateMutex.RLock()
defer client.stateMutex.RUnlock()
@ -114,6 +120,24 @@ func (channel *Channel) Name() string {
return channel.name
}
func (channel *Channel) setName(name string) {
channel.stateMutex.Lock()
defer channel.stateMutex.Unlock()
channel.name = name
}
func (channel *Channel) NameCasefolded() string {
channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock()
return channel.nameCasefolded
}
func (channel *Channel) setNameCasefolded(nameCasefolded string) {
channel.stateMutex.Lock()
defer channel.stateMutex.Unlock()
channel.nameCasefolded = nameCasefolded
}
func (channel *Channel) Members() (result []*Client) {
channel.stateMutex.RLock()
defer channel.stateMutex.RUnlock()

View File

@ -52,12 +52,9 @@ func (manager *MonitorManager) AlertAbout(client *Client, online bool) {
command = RPL_MONONLINE
}
// asynchronously send all the notifications
go func() {
for _, mClient := range watchers {
mClient.Send(nil, client.server.name, command, mClient.getNick(), nick)
}
}()
for _, mClient := range watchers {
mClient.Send(nil, client.server.name, command, mClient.getNick(), nick)
}
}
// Add registers `client` to receive notifications about `nick`.

View File

@ -9,6 +9,7 @@ import (
"bufio"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"log"
"math/rand"
@ -39,6 +40,8 @@ var (
// common error responses
couldNotParseIPMsg, _ = (&[]ircmsg.IrcMessage{ircmsg.MakeMessage(nil, "", "ERROR", "Unable to parse your IP address")}[0]).Line()
RenamePrivsNeeded = errors.New("Only chanops can rename channels")
)
const (
@ -80,8 +83,7 @@ type Server struct {
accountRegistration *AccountRegistration
accounts map[string]*ClientAccount
channelRegistrationEnabled bool
channels ChannelNameMap
channelJoinPartMutex sync.Mutex // used when joining/parting channels to prevent stomping over each others' access and all
channels *ChannelManager
checkIdent bool
clients *ClientLookupSet
commands chan Command
@ -147,7 +149,7 @@ func NewServer(config *Config, logger *logger.Manager) (*Server, error) {
// initialize data structures
server := &Server{
accounts: make(map[string]*ClientAccount),
channels: *NewChannelNameMap(),
channels: NewChannelManager(),
clients: NewClientLookupSet(),
commands: make(chan Command),
connectionLimiter: connection_limits.NewLimiter(),
@ -553,53 +555,62 @@ func pongHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
}
// RENAME <oldchan> <newchan> [<reason>]
//TODO(dan): Clean up this function so it doesn't look like an eldrich horror... prolly by putting it into a server.renameChannel function.
func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
// get lots of locks... make sure nobody touches anything while we're doing this
func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage) (result bool) {
result = false
// TODO(slingamn, #152) clean up locking here
server.registeredChannelsMutex.Lock()
defer server.registeredChannelsMutex.Unlock()
server.channels.ChansLock.Lock()
defer server.channels.ChansLock.Unlock()
errorResponse := func(err error, name string) {
// TODO: send correct error codes, e.g., ERR_CANNOTRENAME, ERR_CHANNAMEINUSE
var code string
switch err {
case NoSuchChannel:
code = ERR_NOSUCHCHANNEL
case RenamePrivsNeeded:
code = ERR_CHANOPRIVSNEEDED
case InvalidChannelName:
code = ERR_UNKNOWNERROR
case ChannelNameInUse:
code = ERR_UNKNOWNERROR
default:
code = ERR_UNKNOWNERROR
}
client.Send(nil, server.name, code, client.getNick(), "RENAME", name, err.Error())
}
oldName := strings.TrimSpace(msg.Params[0])
newName := strings.TrimSpace(msg.Params[1])
if oldName == "" || newName == "" {
errorResponse(InvalidChannelName, "<empty>")
return
}
casefoldedOldName, err := CasefoldChannel(oldName)
if err != nil {
errorResponse(InvalidChannelName, oldName)
return
}
casefoldedNewName, err := CasefoldChannel(newName)
if err != nil {
errorResponse(InvalidChannelName, newName)
return
}
reason := "No reason"
if 2 < len(msg.Params) {
reason = msg.Params[2]
}
// check for all the reasons why the rename couldn't happen
casefoldedOldName, err := CasefoldChannel(oldName)
if err != nil {
//TODO(dan): Change this to ERR_CANNOTRENAME
client.Send(nil, server.name, ERR_UNKNOWNERROR, client.nick, "RENAME", oldName, "Old channel name is invalid")
return false
}
channel := server.channels.Chans[casefoldedOldName]
channel := server.channels.Get(oldName)
if channel == nil {
client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, oldName, "No such channel")
return false
errorResponse(NoSuchChannel, oldName)
return
}
//TODO(dan): allow IRCops to do this?
if !channel.ClientIsAtLeast(client, Operator) {
client.Send(nil, server.name, ERR_CHANOPRIVSNEEDED, client.nick, oldName, "Only chanops can rename channels")
return false
}
casefoldedNewName, err := CasefoldChannel(newName)
if err != nil {
//TODO(dan): Change this to ERR_CANNOTRENAME
client.Send(nil, server.name, ERR_UNKNOWNERROR, client.nick, "RENAME", newName, "New channel name is invalid")
return false
}
newChannel := server.channels.Chans[casefoldedNewName]
if newChannel != nil {
//TODO(dan): Change this to ERR_CHANNAMEINUSE
client.Send(nil, server.name, ERR_UNKNOWNERROR, client.nick, "RENAME", newName, "New channel name is in use")
return false
errorResponse(RenamePrivsNeeded, oldName)
return
}
var canEdit bool
@ -622,11 +633,11 @@ func renameHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
}
// perform the channel rename
server.channels.Chans[casefoldedOldName] = nil
server.channels.Chans[casefoldedNewName] = channel
channel.name = strings.TrimSpace(msg.Params[1])
channel.nameCasefolded = casefoldedNewName
err = server.channels.Rename(oldName, newName)
if err != nil {
errorResponse(err, newName)
return
}
// rename stored channel info if any exists
server.store.Update(func(tx *buntdb.Tx) error {
@ -679,34 +690,15 @@ func joinHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
keys = strings.Split(msg.Params[1], ",")
}
// get lock
server.channelJoinPartMutex.Lock()
defer server.channelJoinPartMutex.Unlock()
for i, name := range channels {
casefoldedName, err := CasefoldChannel(name)
if err != nil {
if len(name) > 0 {
client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, name, "No such channel")
}
continue
}
channel := server.channels.Get(casefoldedName)
if channel == nil {
if len(casefoldedName) > server.getLimits().ChannelLen {
client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, name, "No such channel")
continue
}
channel = NewChannel(server, name, true)
}
var key string
if len(keys) > i {
key = keys[i]
}
channel.Join(client, key)
err := server.channels.Join(client, name, key)
if err == NoSuchChannel {
client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.getNick(), name, "No such channel")
}
}
return false
}
@ -719,22 +711,11 @@ func partHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
reason = msg.Params[1]
}
// get lock
server.channelJoinPartMutex.Lock()
defer server.channelJoinPartMutex.Unlock()
for _, chname := range channels {
casefoldedChannelName, err := CasefoldChannel(chname)
channel := server.channels.Get(casefoldedChannelName)
if err != nil || channel == nil {
if len(chname) > 0 {
client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, chname, "No such channel")
}
continue
err := server.channels.Part(client, chname, reason)
if err == NoSuchChannel {
client.Send(nil, server.name, ERR_NOSUCHCHANNEL, client.nick, chname, "No such channel")
}
channel.Part(client, reason)
}
return false
}
@ -1096,11 +1077,9 @@ func whoHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
//}
if mask == "" {
server.channels.ChansLock.RLock()
for _, channel := range server.channels.Chans {
for _, channel := range server.channels.Channels() {
whoChannel(client, channel, friends)
}
server.channels.ChansLock.RUnlock()
} else if mask[0] == '#' {
// TODO implement wildcard matching
//TODO(dan): ^ only for opers
@ -1859,8 +1838,7 @@ func listHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
}
if len(channels) == 0 {
server.channels.ChansLock.RLock()
for _, channel := range server.channels.Chans {
for _, channel := range server.channels.Channels() {
if !client.flags[Operator] && channel.flags[Secret] {
continue
}
@ -1868,7 +1846,6 @@ func listHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
client.RplList(channel)
}
}
server.channels.ChansLock.RUnlock()
} else {
// limit regular users to only listing one channel
if !client.flags[Operator] {
@ -1922,11 +1899,9 @@ func namesHandler(server *Server, client *Client, msg ircmsg.IrcMessage) bool {
//}
if len(channels) == 0 {
server.channels.ChansLock.RLock()
for _, channel := range server.channels.Chans {
for _, channel := range server.channels.Channels() {
channel.Names(client)
}
server.channels.ChansLock.RUnlock()
return false
}

View File

@ -6,64 +6,9 @@
package irc
import (
"fmt"
"strings"
"sync"
)
// ChannelNameMap is a map that converts channel names to actual channel objects.
type ChannelNameMap struct {
ChansLock sync.RWMutex
Chans map[string]*Channel
}
// NewChannelNameMap returns a new ChannelNameMap.
func NewChannelNameMap() *ChannelNameMap {
var channels ChannelNameMap
channels.Chans = make(map[string]*Channel)
return &channels
}
// Get returns the given channel if it exists.
func (channels *ChannelNameMap) Get(name string) *Channel {
name, err := CasefoldChannel(name)
if err == nil {
channels.ChansLock.RLock()
defer channels.ChansLock.RUnlock()
return channels.Chans[name]
}
return nil
}
// Add adds the given channel to our map.
func (channels *ChannelNameMap) Add(channel *Channel) error {
channels.ChansLock.Lock()
defer channels.ChansLock.Unlock()
if channels.Chans[channel.nameCasefolded] != nil {
return fmt.Errorf("%s: already set", channel.name)
}
channels.Chans[channel.nameCasefolded] = channel
return nil
}
// Remove removes the given channel from our map.
func (channels *ChannelNameMap) Remove(channel *Channel) error {
channels.ChansLock.Lock()
defer channels.ChansLock.Unlock()
if channel != channels.Chans[channel.nameCasefolded] {
return fmt.Errorf("%s: mismatch", channel.name)
}
delete(channels.Chans, channel.nameCasefolded)
return nil
}
// Len returns how many channels we have.
func (channels *ChannelNameMap) Len() int {
channels.ChansLock.RLock()
defer channels.ChansLock.RUnlock()
return len(channels.Chans)
}
// ModeSet holds a set of modes.
type ModeSet map[Mode]bool