3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-22 03:49:27 +01:00

upgrade to go 1.18, use generics

This commit is contained in:
Shivaram Lingamneni 2022-03-30 00:44:51 -04:00
parent 446c654dea
commit a549827f17
15 changed files with 60 additions and 67 deletions

2
go.mod
View File

@ -1,6 +1,6 @@
module github.com/ergochat/ergo module github.com/ergochat/ergo
go 1.17 go 1.18
require ( require (
code.cloudfoundry.org/bytefmt v0.0.0-20200131002437-cf55d5288a48 code.cloudfoundry.org/bytefmt v0.0.0-20200131002437-cf55d5288a48

View File

@ -177,10 +177,7 @@ func (channel *Channel) ExportRegistration(includeFlags uint) (info RegisteredCh
info.Bans = channel.lists[modes.BanMask].Masks() info.Bans = channel.lists[modes.BanMask].Masks()
info.Invites = channel.lists[modes.InviteMask].Masks() info.Invites = channel.lists[modes.InviteMask].Masks()
info.Excepts = channel.lists[modes.ExceptMask].Masks() info.Excepts = channel.lists[modes.ExceptMask].Masks()
info.AccountToUMode = make(map[string]modes.Mode) info.AccountToUMode = utils.CopyMap(channel.accountToUMode)
for account, mode := range channel.accountToUMode {
info.AccountToUMode[account] = mode
}
} }
if includeFlags&IncludeSettings != 0 { if includeFlags&IncludeSettings != 0 {

View File

@ -26,17 +26,17 @@ type ChannelManager struct {
sync.RWMutex // tier 2 sync.RWMutex // tier 2
// chans is the main data structure, mapping casefolded name -> *Channel // chans is the main data structure, mapping casefolded name -> *Channel
chans map[string]*channelManagerEntry chans map[string]*channelManagerEntry
chansSkeletons utils.StringSet // skeletons of *unregistered* chans chansSkeletons utils.HashSet[string] // skeletons of *unregistered* chans
registeredChannels utils.StringSet // casefolds of registered chans registeredChannels utils.HashSet[string] // casefolds of registered chans
registeredSkeletons utils.StringSet // skeletons of registered chans registeredSkeletons utils.HashSet[string] // skeletons of registered chans
purgedChannels utils.StringSet // casefolds of purged chans purgedChannels utils.HashSet[string] // casefolds of purged chans
server *Server server *Server
} }
// NewChannelManager returns a new ChannelManager. // NewChannelManager returns a new ChannelManager.
func (cm *ChannelManager) Initialize(server *Server) { func (cm *ChannelManager) Initialize(server *Server) {
cm.chans = make(map[string]*channelManagerEntry) cm.chans = make(map[string]*channelManagerEntry)
cm.chansSkeletons = make(utils.StringSet) cm.chansSkeletons = make(utils.HashSet[string])
cm.server = server cm.server = server
// purging should work even if registration is disabled // purging should work even if registration is disabled
@ -66,8 +66,8 @@ func (cm *ChannelManager) loadRegisteredChannels(config *Config) {
cm.Lock() cm.Lock()
defer cm.Unlock() defer cm.Unlock()
cm.registeredChannels = make(utils.StringSet, len(rawNames)) cm.registeredChannels = make(utils.HashSet[string], len(rawNames))
cm.registeredSkeletons = make(utils.StringSet, len(rawNames)) cm.registeredSkeletons = make(utils.HashSet[string], len(rawNames))
for _, name := range rawNames { for _, name := range rawNames {
cfname, err := CasefoldChannel(name) cfname, err := CasefoldChannel(name)
if err == nil { if err == nil {

View File

@ -145,8 +145,8 @@ func (reg *ChannelRegistry) AllChannels() (result []string) {
} }
// PurgedChannels returns the set of all casefolded channel names that have been purged // PurgedChannels returns the set of all casefolded channel names that have been purged
func (reg *ChannelRegistry) PurgedChannels() (result utils.StringSet) { func (reg *ChannelRegistry) PurgedChannels() (result utils.HashSet[string]) {
result = make(utils.StringSet) result = make(utils.HashSet[string])
prefix := fmt.Sprintf(keyChannelPurged, "") prefix := fmt.Sprintf(keyChannelPurged, "")
reg.server.store.View(func(tx *buntdb.Tx) error { reg.server.store.View(func(tx *buntdb.Tx) error {

View File

@ -994,8 +994,8 @@ func (client *Client) ModeString() (str string) {
} }
// Friends refers to clients that share a channel with this client. // Friends refers to clients that share a channel with this client.
func (client *Client) Friends(capabs ...caps.Capability) (result map[*Session]empty) { func (client *Client) Friends(capabs ...caps.Capability) (result utils.HashSet[*Session]) {
result = make(map[*Session]empty) result = make(utils.HashSet[*Session])
// look at the client's own sessions // look at the client's own sessions
addFriendsToSet(result, client, capabs...) addFriendsToSet(result, client, capabs...)
@ -1010,19 +1010,19 @@ func (client *Client) Friends(capabs ...caps.Capability) (result map[*Session]em
} }
// Friends refers to clients that share a channel or extended-monitor this client. // Friends refers to clients that share a channel or extended-monitor this client.
func (client *Client) FriendsMonitors(capabs ...caps.Capability) (result map[*Session]empty) { func (client *Client) FriendsMonitors(capabs ...caps.Capability) (result utils.HashSet[*Session]) {
result = client.Friends(capabs...) result = client.Friends(capabs...)
client.server.monitorManager.AddMonitors(result, client.nickCasefolded, capabs...) client.server.monitorManager.AddMonitors(result, client.nickCasefolded, capabs...)
return return
} }
// helper for Friends // helper for Friends
func addFriendsToSet(set map[*Session]empty, client *Client, capabs ...caps.Capability) { func addFriendsToSet(set utils.HashSet[*Session], client *Client, capabs ...caps.Capability) {
client.stateMutex.RLock() client.stateMutex.RLock()
defer client.stateMutex.RUnlock() defer client.stateMutex.RUnlock()
for _, session := range client.sessions { for _, session := range client.sessions {
if session.capabilities.HasAll(capabs...) { if session.capabilities.HasAll(capabs...) {
set[session] = empty{} set.Add(session)
} }
} }
} }
@ -1575,7 +1575,7 @@ func (client *Client) addChannel(channel *Channel, simulated bool) (err error) {
} else if client.oper == nil && len(client.channels) >= config.Channels.MaxChannelsPerClient { } else if client.oper == nil && len(client.channels) >= config.Channels.MaxChannelsPerClient {
err = errTooManyChannels err = errTooManyChannels
} else { } else {
client.channels[channel] = empty{} // success client.channels.Add(channel) // success
} }
client.stateMutex.Unlock() client.stateMutex.Unlock()

View File

@ -11,7 +11,7 @@ import (
func TestGenerateBatchID(t *testing.T) { func TestGenerateBatchID(t *testing.T) {
var session Session var session Session
s := make(utils.StringSet) s := make(utils.HashSet[string])
count := 100000 count := 100000
for i := 0; i < count; i++ { for i := 0; i < count; i++ {

View File

@ -705,7 +705,7 @@ type Config struct {
type OperClass struct { type OperClass struct {
Title string Title string
WhoisLine string `yaml:"whois-line"` WhoisLine string `yaml:"whois-line"`
Capabilities utils.StringSet // map to make lookups much easier Capabilities utils.HashSet[string] // map to make lookups much easier
} }
// OperatorClasses returns a map of assembled operator classes from the given config. // OperatorClasses returns a map of assembled operator classes from the given config.
@ -743,7 +743,7 @@ func (conf *Config) OperatorClasses() (map[string]*OperClass, error) {
// create new operclass // create new operclass
var oc OperClass var oc OperClass
oc.Capabilities = make(utils.StringSet) oc.Capabilities = make(utils.HashSet[string])
// get inhereted info from other operclasses // get inhereted info from other operclasses
if len(info.Extends) > 0 { if len(info.Extends) > 0 {

View File

@ -3424,7 +3424,7 @@ func whoHandler(server *Server, client *Client, msg ircmsg.Message, rb *Response
// Construct set of channels the client is in. // Construct set of channels the client is in.
userChannels := make(ChannelSet) userChannels := make(ChannelSet)
for _, channel := range client.Channels() { for _, channel := range client.Channels() {
userChannels[channel] = empty{} userChannels.Add(channel)
} }
// Another client is a friend if they share at least one channel, or they are the same client. // Another client is a friend if they share at least one channel, or they are the same client.
@ -3437,7 +3437,7 @@ func whoHandler(server *Server, client *Client, msg ircmsg.Message, rb *Response
if channel.flags.HasMode(modes.Auditorium) { if channel.flags.HasMode(modes.Auditorium) {
return false // TODO this should respect +v etc. return false // TODO this should respect +v etc.
} }
if _, present := userChannels[channel]; present { if userChannels.Has(channel) {
return true return true
} }
} }

View File

@ -203,7 +203,7 @@ func (list *Buffer) betweenHelper(start, end Selector, cutoff time.Time, pred Pr
// returns all correspondents, in reverse time order // returns all correspondents, in reverse time order
func (list *Buffer) allCorrespondents() (results []TargetListing) { func (list *Buffer) allCorrespondents() (results []TargetListing) {
seen := make(utils.StringSet) seen := make(utils.HashSet[string])
list.RLock() list.RLock()
defer list.RUnlock() defer list.RUnlock()

View File

@ -54,7 +54,7 @@ type databaseImport struct {
Channels map[string]channelImport Channels map[string]channelImport
} }
func serializeAmodes(raw map[string]string, validCfUsernames utils.StringSet) (result []byte, err error) { func serializeAmodes(raw map[string]string, validCfUsernames utils.HashSet[string]) (result []byte, err error) {
processed := make(map[string]int, len(raw)) processed := make(map[string]int, len(raw))
for accountName, mode := range raw { for accountName, mode := range raw {
if len(mode) != 1 { if len(mode) != 1 {
@ -80,7 +80,7 @@ func doImportDBGeneric(config *Config, dbImport databaseImport, credsType Creden
tx.Set(keySchemaVersion, strconv.Itoa(importDBSchemaVersion), nil) tx.Set(keySchemaVersion, strconv.Itoa(importDBSchemaVersion), nil)
tx.Set(keyCloakSecret, utils.GenerateSecretKey(), nil) tx.Set(keyCloakSecret, utils.GenerateSecretKey(), nil)
cfUsernames := make(utils.StringSet) cfUsernames := make(utils.HashSet[string])
skeletonToUsername := make(map[string]string) skeletonToUsername := make(map[string]string)
warnSkeletons := false warnSkeletons := false

View File

@ -7,6 +7,7 @@ import (
"sync" "sync"
"github.com/ergochat/ergo/irc/caps" "github.com/ergochat/ergo/irc/caps"
"github.com/ergochat/ergo/irc/utils"
"github.com/ergochat/irc-go/ircmsg" "github.com/ergochat/irc-go/ircmsg"
) )
@ -17,21 +18,21 @@ type MonitorManager struct {
// client -> (casefolded nick it's watching -> uncasefolded nick) // client -> (casefolded nick it's watching -> uncasefolded nick)
watching map[*Session]map[string]string watching map[*Session]map[string]string
// casefolded nick -> clients watching it // casefolded nick -> clients watching it
watchedby map[string]map[*Session]empty watchedby map[string]utils.HashSet[*Session]
} }
func (mm *MonitorManager) Initialize() { func (mm *MonitorManager) Initialize() {
mm.watching = make(map[*Session]map[string]string) mm.watching = make(map[*Session]map[string]string)
mm.watchedby = make(map[string]map[*Session]empty) mm.watchedby = make(map[string]utils.HashSet[*Session])
} }
// AddMonitors adds clients using extended-monitor monitoring `client`'s nick to the passed user set. // AddMonitors adds clients using extended-monitor monitoring `client`'s nick to the passed user set.
func (manager *MonitorManager) AddMonitors(users map[*Session]empty, cfnick string, capabs ...caps.Capability) { func (manager *MonitorManager) AddMonitors(users utils.HashSet[*Session], cfnick string, capabs ...caps.Capability) {
manager.RLock() manager.RLock()
defer manager.RUnlock() defer manager.RUnlock()
for session := range manager.watchedby[cfnick] { for session := range manager.watchedby[cfnick] {
if session.capabilities.Has(caps.ExtendedMonitor) && session.capabilities.HasAll(capabs...) { if session.capabilities.Has(caps.ExtendedMonitor) && session.capabilities.HasAll(capabs...) {
users[session] = empty{} users.Add(session)
} }
} }
} }
@ -70,7 +71,7 @@ func (manager *MonitorManager) Add(session *Session, nick string, limit int) err
manager.watching[session] = make(map[string]string) manager.watching[session] = make(map[string]string)
} }
if manager.watchedby[cfnick] == nil { if manager.watchedby[cfnick] == nil {
manager.watchedby[cfnick] = make(map[*Session]empty) manager.watchedby[cfnick] = make(utils.HashSet[*Session])
} }
if len(manager.watching[session]) >= limit { if len(manager.watching[session]) >= limit {
@ -78,7 +79,7 @@ func (manager *MonitorManager) Add(session *Session, nick string, limit int) err
} }
manager.watching[session][cfnick] = nick manager.watching[session][cfnick] = nick
manager.watchedby[cfnick][session] = empty{} manager.watchedby[cfnick].Add(session)
return nil return nil
} }
@ -92,7 +93,7 @@ func (manager *MonitorManager) Remove(session *Session, nick string) (err error)
manager.Lock() manager.Lock()
defer manager.Unlock() defer manager.Unlock()
delete(manager.watching[session], cfnick) delete(manager.watching[session], cfnick)
delete(manager.watchedby[cfnick], session) manager.watchedby[cfnick].Remove(session)
return nil return nil
} }
@ -102,7 +103,7 @@ func (manager *MonitorManager) RemoveAll(session *Session) {
defer manager.Unlock() defer manager.Unlock()
for cfnick := range manager.watching[session] { for cfnick := range manager.watching[session] {
delete(manager.watchedby[cfnick], session) manager.watchedby[cfnick].Remove(session)
} }
delete(manager.watching, session) delete(manager.watching, session)
} }

View File

@ -24,8 +24,8 @@ var (
"MemoServ", "BotServ", "OperServ", "MemoServ", "BotServ", "OperServ",
} }
restrictedCasefoldedNicks = make(utils.StringSet) restrictedCasefoldedNicks = make(utils.HashSet[string])
restrictedSkeletons = make(utils.StringSet) restrictedSkeletons = make(utils.HashSet[string])
) )
func performNickChange(server *Server, client *Client, target *Client, session *Session, nickname string, rb *ResponseBuffer) error { func performNickChange(server *Server, client *Client, target *Client, session *Session, nickname string, rb *ResponseBuffer) error {

View File

@ -9,28 +9,11 @@ import (
"time" "time"
"github.com/ergochat/ergo/irc/modes" "github.com/ergochat/ergo/irc/modes"
"github.com/ergochat/ergo/irc/utils"
) )
type empty struct{}
// ClientSet is a set of clients. // ClientSet is a set of clients.
type ClientSet map[*Client]empty type ClientSet = utils.HashSet[*Client]
// Add adds the given client to this set.
func (clients ClientSet) Add(client *Client) {
clients[client] = empty{}
}
// Remove removes the given client from this set.
func (clients ClientSet) Remove(client *Client) {
delete(clients, client)
}
// Has returns true if the given client is in this set.
func (clients ClientSet) Has(client *Client) bool {
_, ok := clients[client]
return ok
}
type memberData struct { type memberData struct {
modes *modes.ModeSet modes *modes.ModeSet
@ -60,4 +43,4 @@ func (members MemberSet) Has(member *Client) bool {
} }
// ChannelSet is a set of channels. // ChannelSet is a set of channels.
type ChannelSet map[*Channel]empty type ChannelSet = utils.HashSet[*Channel]

View File

@ -5,13 +5,25 @@ package utils
type empty struct{} type empty struct{}
type StringSet map[string]empty type HashSet[T comparable] map[T]empty
func (s StringSet) Has(str string) bool { func (s HashSet[T]) Has(elem T) bool {
_, ok := s[str] _, ok := s[elem]
return ok return ok
} }
func (s StringSet) Add(str string) { func (s HashSet[T]) Add(elem T) {
s[str] = empty{} s[elem] = empty{}
}
func (s HashSet[T]) Remove(elem T) {
delete(s, elem)
}
func CopyMap[K comparable, V any](input map[K]V) (result map[K]V) {
result = make(map[K]V, len(input))
for key, value := range input {
result[key] = value
}
return
} }

View File

@ -74,7 +74,7 @@ func timeToZncWireTime(t time.Time) (result string) {
type zncPlaybackTimes struct { type zncPlaybackTimes struct {
start time.Time start time.Time
end time.Time end time.Time
targets utils.StringSet // nil for "*" (everything), otherwise the channel names targets utils.HashSet[string] // nil for "*" (everything), otherwise the channel names
setAt time.Time setAt time.Time
} }
@ -134,7 +134,7 @@ func zncPlaybackPlayHandler(client *Client, command string, params []string, rb
end = zncWireTimeToTime(params[3]) end = zncWireTimeToTime(params[3])
} }
var targets utils.StringSet var targets utils.HashSet[string]
var nickTargets []string var nickTargets []string
// three cases: // three cases:
@ -157,7 +157,7 @@ func zncPlaybackPlayHandler(client *Client, command string, params []string, rb
if params[1] == "*" { if params[1] == "*" {
playPrivmsgs = true // XXX nil `targets` means "every channel" playPrivmsgs = true // XXX nil `targets` means "every channel"
} else { } else {
targets = make(utils.StringSet) targets = make(utils.HashSet[string])
for _, targetName := range strings.Split(targetString, ",") { for _, targetName := range strings.Split(targetString, ",") {
if strings.HasPrefix(targetName, "#") { if strings.HasPrefix(targetName, "#") {
if cfTarget, err := CasefoldChannel(targetName); err == nil { if cfTarget, err := CasefoldChannel(targetName); err == nil {