Merge pull request #1224 from slingamn/errors_again

minor refactoring
This commit is contained in:
Shivaram Lingamneni 2020-08-04 18:59:08 -07:00 committed by GitHub
commit 8f490ae298
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 60 additions and 52 deletions

View File

@ -5,6 +5,8 @@ package irc
import ( import (
"sync" "sync"
"github.com/oragono/oragono/irc/utils"
) )
type channelManagerEntry struct { type channelManagerEntry struct {
@ -23,17 +25,17 @@ type ChannelManager struct {
sync.RWMutex // tier 2 sync.RWMutex // tier 2
// chans is the main data structure, mapping casefolded name -> *Channel // chans is the main data structure, mapping casefolded name -> *Channel
chans map[string]*channelManagerEntry chans map[string]*channelManagerEntry
chansSkeletons StringSet // skeletons of *unregistered* chans chansSkeletons utils.StringSet // skeletons of *unregistered* chans
registeredChannels StringSet // casefolds of registered chans registeredChannels utils.StringSet // casefolds of registered chans
registeredSkeletons StringSet // skeletons of registered chans registeredSkeletons utils.StringSet // skeletons of registered chans
purgedChannels StringSet // casefolds of purged chans purgedChannels utils.StringSet // casefolds of purged chans
server *Server server *Server
} }
// NewChannelManager returns a new ChannelManager. // NewChannelManager returns a new ChannelManager.
func (cm *ChannelManager) Initialize(server *Server) { func (cm *ChannelManager) Initialize(server *Server) {
cm.chans = make(map[string]*channelManagerEntry) cm.chans = make(map[string]*channelManagerEntry)
cm.chansSkeletons = make(StringSet) cm.chansSkeletons = make(utils.StringSet)
cm.server = server cm.server = server
cm.loadRegisteredChannels(server.Config()) cm.loadRegisteredChannels(server.Config())
@ -47,8 +49,8 @@ func (cm *ChannelManager) loadRegisteredChannels(config *Config) {
} }
rawNames := cm.server.channelRegistry.AllChannels() rawNames := cm.server.channelRegistry.AllChannels()
registeredChannels := make(StringSet, len(rawNames)) registeredChannels := make(utils.StringSet, len(rawNames))
registeredSkeletons := make(StringSet, len(rawNames)) registeredSkeletons := make(utils.StringSet, len(rawNames))
for _, name := range rawNames { for _, name := range rawNames {
cfname, err := CasefoldChannel(name) cfname, err := CasefoldChannel(name)
if err == nil { if err == nil {

View File

@ -4,15 +4,16 @@
package irc package irc
import ( import (
"encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"encoding/json" "github.com/tidwall/buntdb"
"github.com/oragono/oragono/irc/modes" "github.com/oragono/oragono/irc/modes"
"github.com/tidwall/buntdb" "github.com/oragono/oragono/irc/utils"
) )
// this is exclusively the *persistence* layer for channel registration; // this is exclusively the *persistence* layer for channel registration;
@ -140,8 +141,8 @@ func (reg *ChannelRegistry) AllChannels() (result []string) {
} }
// PurgedChannels returns the set of all casefolded channel names that have been purged // PurgedChannels returns the set of all casefolded channel names that have been purged
func (reg *ChannelRegistry) PurgedChannels() (result map[string]empty) { func (reg *ChannelRegistry) PurgedChannels() (result utils.StringSet) {
result = make(map[string]empty) result = make(utils.StringSet)
prefix := fmt.Sprintf(keyChannelPurged, "") prefix := fmt.Sprintf(keyChannelPurged, "")
reg.server.store.View(func(tx *buntdb.Tx) error { reg.server.store.View(func(tx *buntdb.Tx) error {
@ -150,7 +151,7 @@ func (reg *ChannelRegistry) PurgedChannels() (result map[string]empty) {
return false return false
} }
channel := strings.TrimPrefix(key, prefix) channel := strings.TrimPrefix(key, prefix)
result[channel] = empty{} result.Add(channel)
return true return true
}) })
}) })

View File

@ -64,7 +64,7 @@ type Client struct {
destroyed bool destroyed bool
modes modes.ModeSet modes modes.ModeSet
hostname string hostname string
invitedTo StringSet invitedTo utils.StringSet
isSTSOnly bool isSTSOnly bool
languages []string languages []string
lastActive time.Time // last time they sent a command that wasn't PONG or similar lastActive time.Time // last time they sent a command that wasn't PONG or similar
@ -1641,7 +1641,7 @@ func (client *Client) Invite(casefoldedChannel string) {
defer client.stateMutex.Unlock() defer client.stateMutex.Unlock()
if client.invitedTo == nil { if client.invitedTo == nil {
client.invitedTo = make(StringSet) client.invitedTo = make(utils.StringSet)
} }
client.invitedTo.Add(casefoldedChannel) client.invitedTo.Add(casefoldedChannel)

View File

@ -5,11 +5,13 @@ package irc
import ( import (
"testing" "testing"
"github.com/oragono/oragono/irc/utils"
) )
func TestGenerateBatchID(t *testing.T) { func TestGenerateBatchID(t *testing.T) {
var session Session var session Session
s := make(StringSet) s := make(utils.StringSet)
count := 100000 count := 100000
for i := 0; i < count; i++ { for i := 0; i < count; i++ {

View File

@ -624,8 +624,8 @@ type Config struct {
// OperClass defines an assembled operator class. // OperClass defines an assembled operator class.
type OperClass struct { type OperClass struct {
Title string Title string
WhoisLine string `yaml:"whois-line"` WhoisLine string `yaml:"whois-line"`
Capabilities StringSet // map to make lookups much easier Capabilities utils.StringSet // map to make lookups much easier
} }
// OperatorClasses returns a map of assembled operator classes from the given config. // OperatorClasses returns a map of assembled operator classes from the given config.
@ -663,7 +663,7 @@ func (conf *Config) OperatorClasses() (map[string]*OperClass, error) {
// create new operclass // create new operclass
var oc OperClass var oc OperClass
oc.Capabilities = make(StringSet) oc.Capabilities = make(utils.StringSet)
// get inhereted info from other operclasses // get inhereted info from other operclasses
if len(info.Extends) > 0 { if len(info.Extends) > 0 {

View File

@ -73,13 +73,6 @@ var (
errRegisteredOnly = errors.New("Cannot join registered-only channel without an account") errRegisteredOnly = errors.New("Cannot join registered-only channel without an account")
) )
// Socket Errors
var (
errNoPeerCerts = errors.New("Client did not provide a certificate")
errNotTLS = errors.New("Not a TLS connection")
errReadQ = errors.New("ReadQ Exceeded")
)
// String Errors // String Errors
var ( var (
errCouldNotStabilize = errors.New("Could not stabilize string while casefolding") errCouldNotStabilize = errors.New("Could not stabilize string while casefolding")

View File

@ -2998,9 +2998,9 @@ func whoHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
} }
} else { } else {
// Construct set of channels the client is in. // Construct set of channels the client is in.
userChannels := make(map[*Channel]bool) userChannels := make(ChannelSet)
for _, channel := range client.Channels() { for _, channel := range client.Channels() {
userChannels[channel] = true userChannels[channel] = empty{}
} }
// Another client is a friend if they share at least one channel, or they are the same client. // Another client is a friend if they share at least one channel, or they are the same client.
@ -3010,7 +3010,7 @@ func whoHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Respo
} }
for _, channel := range otherClient.Channels() { for _, channel := range otherClient.Channels() {
if userChannels[channel] { if _, present := userChannels[channel]; present {
return true return true
} }
} }

View File

@ -3,6 +3,7 @@ package irc
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"net" "net"
"unicode/utf8" "unicode/utf8"
@ -17,7 +18,8 @@ const (
) )
var ( var (
crlf = []byte{'\r', '\n'} crlf = []byte{'\r', '\n'}
errReadQ = errors.New("ReadQ Exceeded")
) )
// IRCConn abstracts away the distinction between a regular // IRCConn abstracts away the distinction between a regular
@ -31,7 +33,7 @@ type IRCConn interface {
// these take an IRC line or lines, correctly terminated with CRLF: // these take an IRC line or lines, correctly terminated with CRLF:
WriteLine([]byte) error WriteLine([]byte) error
WriteLines([][]byte) error WriteLines([][]byte) error
// this returns an IRC line without the terminating CRLF: // this returns an IRC line, possibly terminated with CRLF, LF, or nothing:
ReadLine() (line []byte, err error) ReadLine() (line []byte, err error)
Close() error Close() error
@ -127,6 +129,9 @@ func (wc IRCWSConn) ReadLine() (line []byte, err error) {
messageType, line, err = wc.conn.ReadMessage() messageType, line, err = wc.conn.ReadMessage()
// on empty message or non-text message, try again, block if necessary // on empty message or non-text message, try again, block if necessary
if err != nil || (messageType == websocket.TextMessage && len(line) != 0) { if err != nil || (messageType == websocket.TextMessage && len(line) != 0) {
if err == websocket.ErrReadLimit {
err = errReadQ
}
return return
} }
} }

View File

@ -28,17 +28,6 @@ func (clients ClientSet) Has(client *Client) bool {
return ok return ok
} }
type StringSet map[string]empty
func (s StringSet) Has(str string) bool {
_, ok := s[str]
return ok
}
func (s StringSet) Add(str string) {
s[str] = empty{}
}
// MemberSet is a set of members with modes. // MemberSet is a set of members with modes.
type MemberSet map[*Client]*modes.ModeSet type MemberSet map[*Client]*modes.ModeSet

View File

@ -10,27 +10,25 @@ import (
"time" "time"
) )
type e struct{}
// Semaphore is a counting semaphore. // Semaphore is a counting semaphore.
// A semaphore of capacity 1 can be used as a trylock. // A semaphore of capacity 1 can be used as a trylock.
type Semaphore (chan e) type Semaphore (chan empty)
// Initialize initializes a semaphore to a given capacity. // Initialize initializes a semaphore to a given capacity.
func (semaphore *Semaphore) Initialize(capacity int) { func (semaphore *Semaphore) Initialize(capacity int) {
*semaphore = make(chan e, capacity) *semaphore = make(chan empty, capacity)
} }
// Acquire acquires a semaphore, blocking if necessary. // Acquire acquires a semaphore, blocking if necessary.
func (semaphore *Semaphore) Acquire() { func (semaphore *Semaphore) Acquire() {
(*semaphore) <- e{} (*semaphore) <- empty{}
} }
// TryAcquire tries to acquire a semaphore, returning whether the acquire was // TryAcquire tries to acquire a semaphore, returning whether the acquire was
// successful. It never blocks. // successful. It never blocks.
func (semaphore *Semaphore) TryAcquire() (acquired bool) { func (semaphore *Semaphore) TryAcquire() (acquired bool) {
select { select {
case (*semaphore) <- e{}: case (*semaphore) <- empty{}:
return true return true
default: default:
return false return false
@ -47,7 +45,7 @@ func (semaphore *Semaphore) AcquireWithTimeout(timeout time.Duration) (acquired
timer := time.NewTimer(timeout) timer := time.NewTimer(timeout)
select { select {
case (*semaphore) <- e{}: case (*semaphore) <- empty{}:
acquired = true acquired = true
case <-timer.C: case <-timer.C:
acquired = false acquired = false
@ -61,7 +59,7 @@ func (semaphore *Semaphore) AcquireWithTimeout(timeout time.Duration) (acquired
// Note that if the context is already expired, the acquire may succeed anyway. // Note that if the context is already expired, the acquire may succeed anyway.
func (semaphore *Semaphore) AcquireWithContext(ctx context.Context) (acquired bool) { func (semaphore *Semaphore) AcquireWithContext(ctx context.Context) (acquired bool) {
select { select {
case (*semaphore) <- e{}: case (*semaphore) <- empty{}:
acquired = true acquired = true
case <-ctx.Done(): case <-ctx.Done():
acquired = false acquired = false

17
irc/utils/types.go Normal file
View File

@ -0,0 +1,17 @@
// Copyright (c) 2020 Shivaram Lingamneni
// released under the MIT license
package utils
type empty struct{}
type StringSet map[string]empty
func (s StringSet) Has(str string) bool {
_, ok := s[str]
return ok
}
func (s StringSet) Add(str string) {
s[str] = empty{}
}

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/oragono/oragono/irc/history" "github.com/oragono/oragono/irc/history"
"github.com/oragono/oragono/irc/utils"
) )
const ( const (
@ -62,7 +63,7 @@ func timeToZncWireTime(t time.Time) (result string) {
type zncPlaybackTimes struct { type zncPlaybackTimes struct {
start time.Time start time.Time
end time.Time end time.Time
targets StringSet // nil for "*" (everything), otherwise the channel names targets utils.StringSet // nil for "*" (everything), otherwise the channel names
setAt time.Time setAt time.Time
} }
@ -122,7 +123,7 @@ func zncPlaybackPlayHandler(client *Client, command string, params []string, rb
end = zncWireTimeToTime(params[3]) end = zncWireTimeToTime(params[3])
} }
var targets StringSet var targets utils.StringSet
var nickTargets []string var nickTargets []string
// three cases: // three cases:
@ -145,7 +146,7 @@ func zncPlaybackPlayHandler(client *Client, command string, params []string, rb
if params[1] == "*" { if params[1] == "*" {
playPrivmsgs = true // XXX nil `targets` means "every channel" playPrivmsgs = true // XXX nil `targets` means "every channel"
} else { } else {
targets = make(StringSet) targets = make(utils.StringSet)
for _, targetName := range strings.Split(targetString, ",") { for _, targetName := range strings.Split(targetString, ",") {
if targetName == "*self" { if targetName == "*self" {
playPrivmsgs = true playPrivmsgs = true