diff --git a/irc/database.go b/irc/database.go index 635ffee5..27974668 100644 --- a/irc/database.go +++ b/irc/database.go @@ -22,7 +22,7 @@ const ( // 'version' of the database schema keySchemaVersion = "db.version" // latest schema of the db - latestDbSchema = "3" + latestDbSchema = "4" ) type SchemaChanger func(*Config, *buntdb.Tx) error @@ -190,7 +190,7 @@ func UpgradeDB(config *Config) (err error) { }) if err != nil { - log.Println("database upgrade failed and was rolled back") + log.Printf("database upgrade failed and was rolled back: %v\n", err) } return err } @@ -278,6 +278,118 @@ func schemaChangeV2ToV3(config *Config, tx *buntdb.Tx) error { return nil } +// 1. ban info format changed (from `legacyBanInfo` below to `IPBanInfo`) +// 2. dlines against individual IPs are normalized into dlines against the appropriate /128 network +func schemaChangeV3ToV4(config *Config, tx *buntdb.Tx) error { + type ipRestrictTime struct { + Duration time.Duration + Expires time.Time + } + type legacyBanInfo struct { + Reason string `json:"reason"` + OperReason string `json:"oper_reason"` + OperName string `json:"oper_name"` + Time *ipRestrictTime `json:"time"` + } + + now := time.Now() + legacyToNewInfo := func(old legacyBanInfo) (new_ IPBanInfo) { + new_.Reason = old.Reason + new_.OperReason = old.OperReason + new_.OperName = old.OperName + + if old.Time == nil { + new_.TimeCreated = now + new_.Duration = 0 + } else { + new_.TimeCreated = old.Time.Expires.Add(-1 * old.Time.Duration) + new_.Duration = old.Time.Duration + } + return + } + + var keysToDelete []string + + prefix := "bans.dline " + dlines := make(map[string]IPBanInfo) + tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool { + if !strings.HasPrefix(key, prefix) { + return false + } + keysToDelete = append(keysToDelete, key) + + var lbinfo legacyBanInfo + id := strings.TrimPrefix(key, prefix) + err := json.Unmarshal([]byte(value), &lbinfo) + if err != nil { + log.Printf("error unmarshaling legacy dline: %v\n", err) + return true + } + // legacy keys can be either an IP or a CIDR + hostNet, err := utils.NormalizedNetFromString(id) + if err != nil { + log.Printf("error unmarshaling legacy dline network: %v\n", err) + return true + } + dlines[utils.NetToNormalizedString(hostNet)] = legacyToNewInfo(lbinfo) + + return true + }) + + setOptions := func(info IPBanInfo) *buntdb.SetOptions { + if info.Duration == 0 { + return nil + } + ttl := info.TimeCreated.Add(info.Duration).Sub(now) + return &buntdb.SetOptions{Expires: true, TTL: ttl} + } + + // store the new dlines + for id, info := range dlines { + b, err := json.Marshal(info) + if err != nil { + log.Printf("error marshaling migrated dline: %v\n", err) + continue + } + tx.Set(fmt.Sprintf("bans.dlinev2 %s", id), string(b), setOptions(info)) + } + + // same operations against klines + prefix = "bans.kline " + klines := make(map[string]IPBanInfo) + tx.AscendGreaterOrEqual("", prefix, func(key, value string) bool { + if !strings.HasPrefix(key, prefix) { + return false + } + keysToDelete = append(keysToDelete, key) + mask := strings.TrimPrefix(key, prefix) + var lbinfo legacyBanInfo + err := json.Unmarshal([]byte(value), &lbinfo) + if err != nil { + log.Printf("error unmarshaling legacy kline: %v\n", err) + return true + } + klines[mask] = legacyToNewInfo(lbinfo) + return true + }) + + for mask, info := range klines { + b, err := json.Marshal(info) + if err != nil { + log.Printf("error marshaling migrated kline: %v\n", err) + continue + } + tx.Set(fmt.Sprintf("bans.klinev2 %s", mask), string(b), setOptions(info)) + } + + // clean up all the old entries + for _, key := range keysToDelete { + tx.Delete(key) + } + + return nil +} + func init() { allChanges := []SchemaChange{ { @@ -290,6 +402,11 @@ func init() { TargetVersion: "3", Changer: schemaChangeV2ToV3, }, + { + InitialVersion: "3", + TargetVersion: "4", + Changer: schemaChangeV3ToV4, + }, } // build the index diff --git a/irc/dline.go b/irc/dline.go index e0491228..c723c204 100644 --- a/irc/dline.go +++ b/irc/dline.go @@ -4,33 +4,21 @@ package irc import ( + "encoding/json" "fmt" "net" + "strings" "sync" "time" - "encoding/json" - + "github.com/oragono/oragono/irc/utils" "github.com/tidwall/buntdb" ) const ( - keyDlineEntry = "bans.dline %s" + keyDlineEntry = "bans.dlinev2 %s" ) -// IPRestrictTime contains the expiration info about the given IP. -type IPRestrictTime struct { - // Duration is how long this block lasts for. - Duration time.Duration `json:"duration"` - // Expires is when this block expires. - Expires time.Time `json:"expires"` -} - -// IsExpired returns true if the time has expired. -func (iptime *IPRestrictTime) IsExpired() bool { - return iptime.Expires.Before(time.Now()) -} - // IPBanInfo holds info about an IP/net ban. type IPBanInfo struct { // Reason is the ban reason. @@ -39,30 +27,38 @@ type IPBanInfo struct { OperReason string `json:"oper_reason"` // OperName is the oper who set the ban. OperName string `json:"oper_name"` - // Time holds details about the duration, if it exists. - Time *IPRestrictTime `json:"time"` + // time of ban creation + TimeCreated time.Time + // duration of the ban; 0 means "permanent" + Duration time.Duration +} + +func (info IPBanInfo) timeLeft() time.Duration { + return info.TimeCreated.Add(info.Duration).Sub(time.Now()) +} + +func (info IPBanInfo) TimeLeft() string { + if info.Duration == 0 { + return "indefinite" + } else { + return info.timeLeft().Truncate(time.Second).String() + } } // BanMessage returns the ban message. func (info IPBanInfo) BanMessage(message string) string { message = fmt.Sprintf(message, info.Reason) - if info.Time != nil { - message += fmt.Sprintf(" [%s]", info.Time.Duration.String()) + if info.Duration != 0 { + message += fmt.Sprintf(" [%s]", info.TimeLeft()) } return message } -// dLineAddr contains the address itself and expiration time for a given network. -type dLineAddr struct { - // Address is the address that is blocked. - Address net.IP - // Info contains information on the ban. - Info IPBanInfo -} - // dLineNet contains the net itself and expiration time for a given network. type dLineNet struct { // Network is the network that is blocked. + // This is always an IPv6 CIDR; IPv4 CIDRs are translated with the 4-in-6 prefix, + // individual IPv4 and IPV6 addresses are translated to the relevant /128. Network net.IPNet // Info contains information on the ban. Info IPBanInfo @@ -70,18 +66,26 @@ type dLineNet struct { // DLineManager manages and dlines. type DLineManager struct { - sync.RWMutex // tier 1 - // addresses that are dlined - addresses map[string]*dLineAddr - // networks that are dlined - networks map[string]*dLineNet + sync.RWMutex // tier 1 + persistenceMutex sync.Mutex // tier 2 + // networks that are dlined: + // XXX: the keys of this map (which are also the database persistence keys) + // are the human-readable representations returned by NetToNormalizedString + networks map[string]dLineNet + // this keeps track of expiration timers for temporary bans + expirationTimers map[string]*time.Timer + server *Server } // NewDLineManager returns a new DLineManager. -func NewDLineManager() *DLineManager { +func NewDLineManager(server *Server) *DLineManager { var dm DLineManager - dm.addresses = make(map[string]*dLineAddr) - dm.networks = make(map[string]*dLineNet) + dm.networks = make(map[string]dLineNet) + dm.expirationTimers = make(map[string]*time.Timer) + dm.server = server + + dm.loadFromDatastore() + return &dm } @@ -92,154 +96,209 @@ func (dm *DLineManager) AllBans() map[string]IPBanInfo { dm.RLock() defer dm.RUnlock() - for name, info := range dm.addresses { - allb[name] = info.Info - } - for name, info := range dm.networks { - allb[name] = info.Info + // map keys are already the human-readable forms, just return a copy of the map + for key, info := range dm.networks { + allb[key] = info.Info } return allb } // AddNetwork adds a network to the blocked list. -func (dm *DLineManager) AddNetwork(network net.IPNet, length *IPRestrictTime, reason, operReason, operName string) { - netString := network.String() - dln := dLineNet{ - Network: network, - Info: IPBanInfo{ - Time: length, - Reason: reason, - OperReason: operReason, - OperName: operName, - }, +func (dm *DLineManager) AddNetwork(network net.IPNet, duration time.Duration, reason, operReason, operName string) error { + dm.persistenceMutex.Lock() + defer dm.persistenceMutex.Unlock() + + // assemble ban info + info := IPBanInfo{ + Reason: reason, + OperReason: operReason, + OperName: operName, + TimeCreated: time.Now(), + Duration: duration, } + + id := dm.addNetworkInternal(network, info) + return dm.persistDline(id, info) +} + +func (dm *DLineManager) addNetworkInternal(network net.IPNet, info IPBanInfo) (id string) { + network = utils.NormalizeNet(network) + id = utils.NetToNormalizedString(network) + + var timeLeft time.Duration + if info.Duration != 0 { + timeLeft = info.timeLeft() + if timeLeft <= 0 { + return + } + } + dm.Lock() - dm.networks[netString] = &dln - dm.Unlock() + defer dm.Unlock() + + dm.networks[id] = dLineNet{ + Network: network, + Info: info, + } + + dm.cancelTimer(id) + + if info.Duration == 0 { + return + } + + // set up new expiration timer + timeCreated := info.TimeCreated + processExpiration := func() { + dm.Lock() + defer dm.Unlock() + + netBan, ok := dm.networks[id] + if ok && netBan.Info.TimeCreated.Equal(timeCreated) { + delete(dm.networks, id) + // TODO(slingamn) here's where we'd remove it from the radix tree + delete(dm.expirationTimers, id) + } + } + dm.expirationTimers[id] = time.AfterFunc(timeLeft, processExpiration) + + return +} + +func (dm *DLineManager) cancelTimer(id string) { + oldTimer := dm.expirationTimers[id] + if oldTimer != nil { + oldTimer.Stop() + delete(dm.expirationTimers, id) + } +} + +func (dm *DLineManager) persistDline(id string, info IPBanInfo) error { + // save in datastore + dlineKey := fmt.Sprintf(keyDlineEntry, id) + // assemble json from ban info + b, err := json.Marshal(info) + if err != nil { + dm.server.logger.Error("internal", "couldn't marshal d-line", err.Error()) + return err + } + bstr := string(b) + var setOptions *buntdb.SetOptions + if info.Duration != 0 { + setOptions = &buntdb.SetOptions{Expires: true, TTL: info.Duration} + } + + err = dm.server.store.Update(func(tx *buntdb.Tx) error { + _, _, err := tx.Set(dlineKey, bstr, setOptions) + return err + }) + if err != nil { + dm.server.logger.Error("internal", "couldn't store d-line", err.Error()) + } + return err +} + +func (dm *DLineManager) unpersistDline(id string) error { + dlineKey := fmt.Sprintf(keyDlineEntry, id) + return dm.server.store.Update(func(tx *buntdb.Tx) error { + _, err := tx.Delete(dlineKey) + return err + }) } // RemoveNetwork removes a network from the blocked list. -func (dm *DLineManager) RemoveNetwork(network net.IPNet) { - netString := network.String() - dm.Lock() - delete(dm.networks, netString) - dm.Unlock() +func (dm *DLineManager) RemoveNetwork(network net.IPNet) error { + dm.persistenceMutex.Lock() + defer dm.persistenceMutex.Unlock() + + id := utils.NetToNormalizedString(utils.NormalizeNet(network)) + + present := func() bool { + dm.Lock() + defer dm.Unlock() + _, ok := dm.networks[id] + delete(dm.networks, id) + dm.cancelTimer(id) + return ok + }() + + if !present { + return errNoExistingBan + } + + return dm.unpersistDline(id) } // AddIP adds an IP address to the blocked list. -func (dm *DLineManager) AddIP(addr net.IP, length *IPRestrictTime, reason, operReason, operName string) { - addrString := addr.String() - dla := dLineAddr{ - Address: addr, - Info: IPBanInfo{ - Time: length, - Reason: reason, - OperReason: operReason, - OperName: operName, - }, - } - dm.Lock() - dm.addresses[addrString] = &dla - dm.Unlock() +func (dm *DLineManager) AddIP(addr net.IP, duration time.Duration, reason, operReason, operName string) error { + return dm.AddNetwork(utils.NormalizeIPToNet(addr), duration, reason, operReason, operName) } -// RemoveIP removes an IP from the blocked list. -func (dm *DLineManager) RemoveIP(addr net.IP) { - addrString := addr.String() - dm.Lock() - delete(dm.addresses, addrString) - dm.Unlock() +// RemoveIP removes an IP address from the blocked list. +func (dm *DLineManager) RemoveIP(addr net.IP) error { + return dm.RemoveNetwork(utils.NormalizeIPToNet(addr)) } // CheckIP returns whether or not an IP address was banned, and how long it is banned for. -func (dm *DLineManager) CheckIP(addr net.IP) (isBanned bool, info *IPBanInfo) { - // check IP addr - addrString := addr.String() - dm.RLock() - addrInfo := dm.addresses[addrString] - dm.RUnlock() - - if addrInfo != nil { - if addrInfo.Info.Time != nil { - if addrInfo.Info.Time.IsExpired() { - // ban on IP has expired, remove it from our blocked list - dm.RemoveIP(addr) - } else { - return true, &addrInfo.Info - } - } else { - return true, &addrInfo.Info - } - } - - // check networks - doCleanup := false - defer func() { - if doCleanup { - go func() { - dm.Lock() - defer dm.Unlock() - for key, netInfo := range dm.networks { - if netInfo.Info.Time.IsExpired() { - delete(dm.networks, key) - } - } - }() - } - }() +func (dm *DLineManager) CheckIP(addr net.IP) (isBanned bool, info IPBanInfo) { + addr = addr.To16() // almost certainly unnecessary dm.RLock() defer dm.RUnlock() - for _, netInfo := range dm.networks { - if netInfo.Info.Time != nil && netInfo.Info.Time.IsExpired() { - // expired ban, ignore and clean up later - doCleanup = true - } else if netInfo.Network.Contains(addr) { - return true, &netInfo.Info + // check networks + // TODO(slingamn) use a radix tree as the data plane for this + for _, netBan := range dm.networks { + if netBan.Network.Contains(addr) { + return true, netBan.Info } } // no matches! - return false, nil + isBanned = false + return } -func (s *Server) loadDLines() { - s.dlines = NewDLineManager() +func (dm *DLineManager) loadFromDatastore() { + dlinePrefix := fmt.Sprintf(keyDlineEntry, "") + dm.server.store.View(func(tx *buntdb.Tx) error { + tx.AscendGreaterOrEqual("", dlinePrefix, func(key, value string) bool { + if !strings.HasPrefix(key, dlinePrefix) { + return false + } - // load from datastore - s.store.View(func(tx *buntdb.Tx) error { - //TODO(dan): We could make this safer - tx.AscendKeys("bans.dline *", func(key, value string) bool { // get address name - key = key[len("bans.dline "):] + key = strings.TrimPrefix(key, dlinePrefix) // load addr/net - var hostAddr net.IP - var hostNet *net.IPNet - _, hostNet, err := net.ParseCIDR(key) + hostNet, err := utils.NormalizedNetFromString(key) if err != nil { - hostAddr = net.ParseIP(key) + dm.server.logger.Error("internal", "bad dline cidr", err.Error()) + return true } // load ban info var info IPBanInfo - json.Unmarshal([]byte(value), &info) + err = json.Unmarshal([]byte(value), &info) + if err != nil { + dm.server.logger.Error("internal", "bad dline data", err.Error()) + return true + } // set opername if it isn't already set if info.OperName == "" { - info.OperName = s.name + info.OperName = dm.server.name } // add to the server - if hostNet == nil { - s.dlines.AddIP(hostAddr, info.Time, info.Reason, info.OperReason, info.OperName) - } else { - s.dlines.AddNetwork(*hostNet, info.Time, info.Reason, info.OperReason, info.OperName) - } + dm.addNetworkInternal(hostNet, info) - return true // true to continue I guess? + return true }) return nil }) } + +func (s *Server) loadDLines() { + s.dlines = NewDLineManager(s) +} diff --git a/irc/handlers.go b/irc/handlers.go index 81af9c45..b8075296 100644 --- a/irc/handlers.go +++ b/irc/handlers.go @@ -9,9 +9,7 @@ package irc import ( "bytes" "encoding/base64" - "encoding/json" "fmt" - "net" "os" "runtime" "runtime/debug" @@ -30,7 +28,6 @@ import ( "github.com/oragono/oragono/irc/modes" "github.com/oragono/oragono/irc/sno" "github.com/oragono/oragono/irc/utils" - "github.com/tidwall/buntdb" "golang.org/x/crypto/bcrypt" ) @@ -578,6 +575,33 @@ func debugHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res return false } +// helper for parsing the reason args to DLINE and KLINE +func getReasonsFromParams(params []string, currentArg int) (reason, operReason string) { + reason = "No reason given" + operReason = "" + if len(params) > currentArg { + reasons := strings.SplitN(strings.Join(params[currentArg:], " "), "|", 2) + if len(reasons) == 1 { + reason = strings.TrimSpace(reasons[0]) + } else if len(reasons) == 2 { + reason = strings.TrimSpace(reasons[0]) + operReason = strings.TrimSpace(reasons[1]) + } + } + return +} + +func formatBanForListing(client *Client, key string, info IPBanInfo) string { + desc := info.Reason + if info.OperReason != "" && info.OperReason != info.Reason { + desc = fmt.Sprintf("%s | %s", info.Reason, info.OperReason) + } + if info.Duration != 0 { + desc = fmt.Sprintf("%s [%s]", desc, info.TimeLeft()) + } + return fmt.Sprintf(client.t("Ban - %[1]s - added by %[2]s - %[3]s"), key, info.OperName, desc) +} + // DLINE [ANDKILL] [MYSELF] [duration] / [ON ] [reason [| oper reason]] // DLINE LIST func dlineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *ResponseBuffer) bool { @@ -599,7 +623,7 @@ func dlineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res } for key, info := range bans { - rb.Notice(fmt.Sprintf(client.t("Ban - %[1]s - added by %[2]s - %[3]s"), key, info.OperName, info.BanMessage("%s"))) + client.Notice(formatBanForListing(client, key, info)) } return false @@ -622,8 +646,9 @@ func dlineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res // duration duration, err := custime.ParseDuration(msg.Params[currentArg]) - durationIsUsed := err == nil - if durationIsUsed { + if err != nil { + duration = 0 + } else { currentArg++ } @@ -636,31 +661,16 @@ func dlineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res currentArg++ // check host - var hostAddr net.IP - var hostNet *net.IPNet + hostNet, err := utils.NormalizedNetFromString(hostString) - _, hostNet, err = net.ParseCIDR(hostString) if err != nil { - hostAddr = net.ParseIP(hostString) - } - - if hostAddr == nil && hostNet == nil { rb.Add(nil, server.name, ERR_UNKNOWNERROR, client.nick, msg.Command, client.t("Could not parse IP address or CIDR network")) return false } - if hostNet == nil { - hostString = hostAddr.String() - if !dlineMyself && hostAddr.Equal(client.IP()) { - rb.Add(nil, server.name, ERR_UNKNOWNERROR, client.nick, msg.Command, client.t("This ban matches you. To DLINE yourself, you must use the command: /DLINE MYSELF ")) - return false - } - } else { - hostString = hostNet.String() - if !dlineMyself && hostNet.Contains(client.IP()) { - rb.Add(nil, server.name, ERR_UNKNOWNERROR, client.nick, msg.Command, client.t("This ban matches you. To DLINE yourself, you must use the command: /DLINE MYSELF ")) - return false - } + if !dlineMyself && hostNet.Contains(client.IP()) { + rb.Add(nil, server.name, ERR_UNKNOWNERROR, client.nick, msg.Command, client.t("This ban matches you. To DLINE yourself, you must use the command: /DLINE MYSELF ")) + return false } // check remote @@ -670,71 +680,23 @@ func dlineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res } // get comment(s) - reason := "No reason given" - operReason := "No reason given" - if len(msg.Params) > currentArg { - tempReason := strings.TrimSpace(msg.Params[currentArg]) - if len(tempReason) > 0 && tempReason != "|" { - tempReasons := strings.SplitN(tempReason, "|", 2) - if tempReasons[0] != "" { - reason = tempReasons[0] - } - if len(tempReasons) > 1 && tempReasons[1] != "" { - operReason = tempReasons[1] - } else { - operReason = reason - } - } - } + reason, operReason := getReasonsFromParams(msg.Params, currentArg) + operName := oper.Name if operName == "" { operName = server.name } - // assemble ban info - var banTime *IPRestrictTime - if durationIsUsed { - banTime = &IPRestrictTime{ - Duration: duration, - Expires: time.Now().Add(duration), - } - } - - info := IPBanInfo{ - Reason: reason, - OperReason: operReason, - OperName: operName, - Time: banTime, - } - - // save in datastore - err = server.store.Update(func(tx *buntdb.Tx) error { - dlineKey := fmt.Sprintf(keyDlineEntry, hostString) - - // assemble json from ban info - b, err := json.Marshal(info) - if err != nil { - return err - } - - tx.Set(dlineKey, string(b), nil) - - return nil - }) + err = server.dlines.AddNetwork(hostNet, duration, reason, operReason, operName) if err != nil { rb.Notice(fmt.Sprintf(client.t("Could not successfully save new D-LINE: %s"), err.Error())) return false } - if hostNet == nil { - server.dlines.AddIP(hostAddr, banTime, reason, operReason, operName) - } else { - server.dlines.AddNetwork(*hostNet, banTime, reason, operReason, operName) - } - var snoDescription string - if durationIsUsed { + hostString = utils.NetToNormalizedString(hostNet) + if duration != 0 { rb.Notice(fmt.Sprintf(client.t("Added temporary (%[1]s) D-Line for %[2]s"), duration.String(), hostString)) snoDescription = fmt.Sprintf(ircfmt.Unescape("%s [%s]$r added temporary (%s) D-Line for %s"), client.nick, operName, duration.String(), hostString) } else { @@ -747,16 +709,9 @@ func dlineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res if andKill { var clientsToKill []*Client var killedClientNicks []string - var toKill bool for _, mcl := range server.clients.AllClients() { - if hostNet == nil { - toKill = hostAddr.Equal(mcl.IP()) - } else { - toKill = hostNet.Contains(mcl.IP()) - } - - if toKill { + if hostNet.Contains(mcl.IP()) { clientsToKill = append(clientsToKill, mcl) killedClientNicks = append(killedClientNicks, mcl.nick) } @@ -1047,7 +1002,7 @@ func klineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res } for key, info := range bans { - client.Notice(fmt.Sprintf(client.t("Ban - %[1]s - added by %[2]s - %[3]s"), key, info.OperName, info.BanMessage("%s"))) + client.Notice(formatBanForListing(client, key, info)) } return false @@ -1070,8 +1025,9 @@ func klineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res // duration duration, err := custime.ParseDuration(msg.Params[currentArg]) - durationIsUsed := err == nil - if durationIsUsed { + if err != nil { + duration = 0 + } else { currentArg++ } @@ -1112,63 +1068,16 @@ func klineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *Res } // get comment(s) - reason := "No reason given" - operReason := "No reason given" - if len(msg.Params) > currentArg { - tempReason := strings.TrimSpace(msg.Params[currentArg]) - if len(tempReason) > 0 && tempReason != "|" { - tempReasons := strings.SplitN(tempReason, "|", 2) - if tempReasons[0] != "" { - reason = tempReasons[0] - } - if len(tempReasons) > 1 && tempReasons[1] != "" { - operReason = tempReasons[1] - } else { - operReason = reason - } - } - } - - // assemble ban info - var banTime *IPRestrictTime - if durationIsUsed { - banTime = &IPRestrictTime{ - Duration: duration, - Expires: time.Now().Add(duration), - } - } - - info := IPBanInfo{ - Reason: reason, - OperReason: operReason, - OperName: operName, - Time: banTime, - } - - // save in datastore - err = server.store.Update(func(tx *buntdb.Tx) error { - klineKey := fmt.Sprintf(keyKlineEntry, mask) - - // assemble json from ban info - b, err := json.Marshal(info) - if err != nil { - return err - } - - tx.Set(klineKey, string(b), nil) - - return nil - }) + reason, operReason := getReasonsFromParams(msg.Params, currentArg) + err = server.klines.AddMask(mask, duration, reason, operReason, operName) if err != nil { rb.Notice(fmt.Sprintf(client.t("Could not successfully save new K-LINE: %s"), err.Error())) return false } - server.klines.AddMask(mask, banTime, reason, operReason, operName) - var snoDescription string - if durationIsUsed { + if duration != 0 { rb.Notice(fmt.Sprintf(client.t("Added temporary (%[1]s) K-Line for %[2]s"), duration.String(), mask)) snoDescription = fmt.Sprintf(ircfmt.Unescape("%s [%s]$r added temporary (%s) K-Line for %s"), client.nick, operName, duration.String(), mask) } else { @@ -2219,52 +2128,21 @@ func unDLineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *R hostString := msg.Params[0] // check host - var hostAddr net.IP - var hostNet *net.IPNet + hostNet, err := utils.NormalizedNetFromString(hostString) - _, hostNet, err := net.ParseCIDR(hostString) if err != nil { - hostAddr = net.ParseIP(hostString) - } - - if hostAddr == nil && hostNet == nil { rb.Add(nil, server.name, ERR_UNKNOWNERROR, client.nick, msg.Command, client.t("Could not parse IP address or CIDR network")) return false } - if hostNet == nil { - hostString = hostAddr.String() - } else { - hostString = hostNet.String() - } - - // save in datastore - err = server.store.Update(func(tx *buntdb.Tx) error { - dlineKey := fmt.Sprintf(keyDlineEntry, hostString) - - // check if it exists or not - val, err := tx.Get(dlineKey) - if val == "" { - return errNoExistingBan - } else if err != nil { - return err - } - - tx.Delete(dlineKey) - return nil - }) + err = server.dlines.RemoveNetwork(hostNet) if err != nil { rb.Add(nil, server.name, ERR_UNKNOWNERROR, client.nick, msg.Command, fmt.Sprintf(client.t("Could not remove ban [%s]"), err.Error())) return false } - if hostNet == nil { - server.dlines.RemoveIP(hostAddr) - } else { - server.dlines.RemoveNetwork(*hostNet) - } - + hostString = utils.NetToNormalizedString(hostNet) rb.Notice(fmt.Sprintf(client.t("Removed D-Line for %s"), hostString)) server.snomasks.Send(sno.LocalXline, fmt.Sprintf(ircfmt.Unescape("%s$r removed D-Line for %s"), client.nick, hostString)) return false @@ -2288,29 +2166,13 @@ func unKLineHandler(server *Server, client *Client, msg ircmsg.IrcMessage, rb *R mask = mask + "@*" } - // save in datastore - err := server.store.Update(func(tx *buntdb.Tx) error { - klineKey := fmt.Sprintf(keyKlineEntry, mask) - - // check if it exists or not - val, err := tx.Get(klineKey) - if val == "" { - return errNoExistingBan - } else if err != nil { - return err - } - - tx.Delete(klineKey) - return nil - }) + err := server.klines.RemoveMask(mask) if err != nil { rb.Add(nil, server.name, ERR_UNKNOWNERROR, client.nick, msg.Command, fmt.Sprintf(client.t("Could not remove ban [%s]"), err.Error())) return false } - server.klines.RemoveMask(mask) - rb.Notice(fmt.Sprintf(client.t("Removed K-Line for %s"), mask)) server.snomasks.Send(sno.LocalXline, fmt.Sprintf(ircfmt.Unescape("%s$r removed K-Line for %s"), client.nick, mask)) return false diff --git a/irc/kline.go b/irc/kline.go index 22e28c69..29ee7b79 100644 --- a/irc/kline.go +++ b/irc/kline.go @@ -5,14 +5,17 @@ package irc import ( "encoding/json" + "fmt" + "strings" "sync" + "time" "github.com/goshuirc/irc-go/ircmatch" "github.com/tidwall/buntdb" ) const ( - keyKlineEntry = "bans.kline %s" + keyKlineEntry = "bans.klinev2 %s" ) // KLineInfo contains the address itself and expiration time for a given network. @@ -27,15 +30,23 @@ type KLineInfo struct { // KLineManager manages and klines. type KLineManager struct { - sync.RWMutex // tier 1 + sync.RWMutex // tier 1 + persistenceMutex sync.Mutex // tier 2 // kline'd entries - entries map[string]*KLineInfo + entries map[string]KLineInfo + expirationTimers map[string]*time.Timer + server *Server } // NewKLineManager returns a new KLineManager. -func NewKLineManager() *KLineManager { +func NewKLineManager(s *Server) *KLineManager { var km KLineManager - km.entries = make(map[string]*KLineInfo) + km.entries = make(map[string]KLineInfo) + km.expirationTimers = make(map[string]*time.Timer) + km.server = s + + km.loadFromDatastore() + return &km } @@ -53,97 +64,177 @@ func (km *KLineManager) AllBans() map[string]IPBanInfo { } // AddMask adds to the blocked list. -func (km *KLineManager) AddMask(mask string, length *IPRestrictTime, reason, operReason, operName string) { +func (km *KLineManager) AddMask(mask string, duration time.Duration, reason, operReason, operName string) error { + km.persistenceMutex.Lock() + defer km.persistenceMutex.Unlock() + + info := IPBanInfo{ + Reason: reason, + OperReason: operReason, + OperName: operName, + TimeCreated: time.Now(), + Duration: duration, + } + km.addMaskInternal(mask, info) + return km.persistKLine(mask, info) +} + +func (km *KLineManager) addMaskInternal(mask string, info IPBanInfo) { kln := KLineInfo{ Mask: mask, Matcher: ircmatch.MakeMatch(mask), - Info: IPBanInfo{ - Time: length, - Reason: reason, - OperReason: operReason, - OperName: operName, - }, + Info: info, } + + var timeLeft time.Duration + if info.Duration > 0 { + timeLeft = info.timeLeft() + if timeLeft <= 0 { + return + } + } + km.Lock() - km.entries[mask] = &kln - km.Unlock() + defer km.Unlock() + + km.entries[mask] = kln + km.cancelTimer(mask) + + if info.Duration == 0 { + return + } + + // set up new expiration timer + timeCreated := info.TimeCreated + processExpiration := func() { + km.Lock() + defer km.Unlock() + + maskBan, ok := km.entries[mask] + if ok && maskBan.Info.TimeCreated.Equal(timeCreated) { + delete(km.entries, mask) + delete(km.expirationTimers, mask) + } + } + km.expirationTimers[mask] = time.AfterFunc(timeLeft, processExpiration) +} + +func (km *KLineManager) cancelTimer(id string) { + oldTimer := km.expirationTimers[id] + if oldTimer != nil { + oldTimer.Stop() + delete(km.expirationTimers, id) + } +} + +func (km *KLineManager) persistKLine(mask string, info IPBanInfo) error { + // save in datastore + klineKey := fmt.Sprintf(keyKlineEntry, mask) + // assemble json from ban info + b, err := json.Marshal(info) + if err != nil { + return err + } + bstr := string(b) + var setOptions *buntdb.SetOptions + if info.Duration != 0 { + setOptions = &buntdb.SetOptions{Expires: true, TTL: info.Duration} + } + + err = km.server.store.Update(func(tx *buntdb.Tx) error { + _, _, err := tx.Set(klineKey, bstr, setOptions) + return err + }) + + return err + +} + +func (km *KLineManager) unpersistKLine(mask string) error { + // save in datastore + klineKey := fmt.Sprintf(keyKlineEntry, mask) + return km.server.store.Update(func(tx *buntdb.Tx) error { + _, err := tx.Delete(klineKey) + return err + }) } // RemoveMask removes a mask from the blocked list. -func (km *KLineManager) RemoveMask(mask string) { - km.Lock() - delete(km.entries, mask) - km.Unlock() +func (km *KLineManager) RemoveMask(mask string) error { + km.persistenceMutex.Lock() + defer km.persistenceMutex.Unlock() + + present := func() bool { + km.Lock() + defer km.Unlock() + _, ok := km.entries[mask] + if ok { + delete(km.entries, mask) + } + km.cancelTimer(mask) + return ok + }() + + if !present { + return errNoExistingBan + } + + return km.unpersistKLine(mask) } // CheckMasks returns whether or not the hostmask(s) are banned, and how long they are banned for. -func (km *KLineManager) CheckMasks(masks ...string) (isBanned bool, info *IPBanInfo) { - doCleanup := false - defer func() { - // asynchronously remove expired bans - if doCleanup { - go func() { - km.Lock() - defer km.Unlock() - for key, entry := range km.entries { - if entry.Info.Time.IsExpired() { - delete(km.entries, key) - } - } - }() - } - }() - +func (km *KLineManager) CheckMasks(masks ...string) (isBanned bool, info IPBanInfo) { km.RLock() defer km.RUnlock() for _, entryInfo := range km.entries { - if entryInfo.Info.Time != nil && entryInfo.Info.Time.IsExpired() { - doCleanup = true - continue - } - - matches := false for _, mask := range masks { if entryInfo.Matcher.Match(mask) { - matches = true - break + return true, entryInfo.Info } } - if matches { - return true, &entryInfo.Info - } } // no matches! - return false, nil + isBanned = false + return } -func (s *Server) loadKLines() { - s.klines = NewKLineManager() - +func (km *KLineManager) loadFromDatastore() { // load from datastore - s.store.View(func(tx *buntdb.Tx) error { - //TODO(dan): We could make this safer - tx.AscendKeys("bans.kline *", func(key, value string) bool { + klinePrefix := fmt.Sprintf(keyKlineEntry, "") + km.server.store.View(func(tx *buntdb.Tx) error { + tx.AscendGreaterOrEqual("", klinePrefix, func(key, value string) bool { + if !strings.HasPrefix(key, klinePrefix) { + return false + } + // get address name - key = key[len("bans.kline "):] - mask := key + mask := strings.TrimPrefix(key, klinePrefix) // load ban info var info IPBanInfo - json.Unmarshal([]byte(value), &info) + err := json.Unmarshal([]byte(value), &info) + if err != nil { + km.server.logger.Error("internal", "couldn't unmarshal kline", err.Error()) + return true + } // add oper name if it doesn't exist already if info.OperName == "" { - info.OperName = s.name + info.OperName = km.server.name } // add to the server - s.klines.AddMask(mask, info.Time, info.Reason, info.OperReason, info.OperName) + km.addMaskInternal(mask, info) - return true // true to continue I guess? + return true }) return nil }) + +} + +func (s *Server) loadKLines() { + s.klines = NewKLineManager(s) } diff --git a/irc/server.go b/irc/server.go index bd396032..9647b609 100644 --- a/irc/server.go +++ b/irc/server.go @@ -285,11 +285,10 @@ func (server *Server) checkBans(ipaddr net.IP) (banned bool, message string) { if err != nil { // too many connections too quickly from client, tell them and close the connection duration := server.connectionThrottler.BanDuration() - length := &IPRestrictTime{ - Duration: duration, - Expires: time.Now().Add(duration), + if duration == 0 { + return false, "" } - server.dlines.AddIP(ipaddr, length, server.connectionThrottler.BanMessage(), "Exceeded automated connection throttle", "auto.connection.throttler") + server.dlines.AddIP(ipaddr, duration, server.connectionThrottler.BanMessage(), "Exceeded automated connection throttle", "auto.connection.throttler") // they're DLINE'd for 15 minutes or whatever, so we can reset the connection throttle now, // and once their temporary DLINE is finished they can fill up the throttler again @@ -409,11 +408,7 @@ func (server *Server) tryRegister(c *Client) { // check KLINEs isBanned, info := server.klines.CheckMasks(c.AllNickmasks()...) if isBanned { - reason := info.Reason - if info.Time != nil { - reason += fmt.Sprintf(" [%s]", info.Time.Duration.String()) - } - c.Quit(fmt.Sprintf(c.t("You are banned from this server (%s)"), reason)) + c.Quit(info.BanMessage(c.t("You are banned from this server (%s)"))) c.destroy(false) return } diff --git a/irc/utils/net.go b/irc/utils/net.go index bbc19898..c24f7057 100644 --- a/irc/utils/net.go +++ b/irc/utils/net.go @@ -9,6 +9,11 @@ import ( "strings" ) +var ( + // subnet mask for an ipv6 /128: + mask128 = net.CIDRMask(128, 128) +) + // IPString returns a simple IP string from the given net.Addr. func IPString(addr net.Addr) string { addrStr := addr.String() @@ -94,3 +99,60 @@ func IsHostname(name string) bool { return true } + +// NormalizeIPToNet represents an address (v4 or v6) as the v6 /128 CIDR +// containing only it. +func NormalizeIPToNet(addr net.IP) (network net.IPNet) { + // represent ipv4 addresses as ipv6 addresses, using the 4-in-6 prefix + // (actually this should be a no-op for any address returned by ParseIP) + addr = addr.To16() + // the network corresponding to this address is now an ipv6 /128: + return net.IPNet{ + IP: addr, + Mask: mask128, + } +} + +// NormalizeNet normalizes an IPNet to a v6 CIDR, using the 4-in-6 prefix. +// (this is like IP.To16(), but for IPNet instead of IP) +func NormalizeNet(network net.IPNet) (result net.IPNet) { + if len(network.IP) == 16 { + return network + } + ones, _ := network.Mask.Size() + return net.IPNet{ + IP: network.IP.To16(), + // include the 96 bits of the 4-in-6 prefix + Mask: net.CIDRMask(96+ones, 128), + } +} + +// Given a network, produce a human-readable string +// (i.e., CIDR if it's actually a network, IPv6 address if it's a v6 /128, +// dotted quad if it's a v4 /32). +func NetToNormalizedString(network net.IPNet) string { + ones, bits := network.Mask.Size() + if ones == bits && ones == len(network.IP)*8 { + // either a /32 or a /128, output the address: + return network.IP.String() + } + return network.String() +} + +// Parse a human-readable description (an address or CIDR, either v4 or v6) +// into a normalized v6 net.IPNet. +func NormalizedNetFromString(str string) (result net.IPNet, err error) { + _, network, err := net.ParseCIDR(str) + if err == nil { + return NormalizeNet(*network), nil + } + ip := net.ParseIP(str) + if ip == nil { + err = &net.AddrError{ + Err: "Couldn't interpret as either CIDR or address", + Addr: str, + } + return + } + return NormalizeIPToNet(ip), nil +} diff --git a/irc/utils/net_test.go b/irc/utils/net_test.go index 0094d143..0a8d595f 100644 --- a/irc/utils/net_test.go +++ b/irc/utils/net_test.go @@ -4,8 +4,16 @@ package utils +import "net" +import "reflect" import "testing" +func assertEqual(supplied, expected interface{}, t *testing.T) { + if !reflect.DeepEqual(supplied, expected) { + t.Errorf("expected %v but got %v", expected, supplied) + } +} + // hostnames from https://github.com/DanielOaks/irc-parser-tests var ( goodHostnames = []string{ @@ -47,3 +55,94 @@ func TestIsHostname(t *testing.T) { } } } + +func TestNormalizeToNet(t *testing.T) { + a := net.ParseIP("8.8.8.8") + b := net.ParseIP("8.8.4.4") + if a == nil || b == nil { + panic("something has gone very wrong") + } + + aNetwork := NormalizeIPToNet(a) + bNetwork := NormalizeIPToNet(b) + + assertEqual(aNetwork.Contains(a), true, t) + assertEqual(bNetwork.Contains(b), true, t) + assertEqual(aNetwork.Contains(b), false, t) + assertEqual(bNetwork.Contains(a), false, t) + + c := net.ParseIP("2001:4860:4860::8888") + d := net.ParseIP("2001:db8::1") + if c == nil || d == nil { + panic("something has gone very wrong") + } + + cNetwork := NormalizeIPToNet(c) + dNetwork := NormalizeIPToNet(d) + + assertEqual(cNetwork.Contains(c), true, t) + assertEqual(dNetwork.Contains(d), true, t) + assertEqual(dNetwork.Contains(c), false, t) + assertEqual(dNetwork.Contains(a), false, t) + assertEqual(cNetwork.Contains(b), false, t) + assertEqual(aNetwork.Contains(c), false, t) + assertEqual(bNetwork.Contains(c), false, t) + + assertEqual(NetToNormalizedString(aNetwork), "8.8.8.8", t) + assertEqual(NetToNormalizedString(bNetwork), "8.8.4.4", t) + assertEqual(NetToNormalizedString(cNetwork), "2001:4860:4860::8888", t) + assertEqual(NetToNormalizedString(dNetwork), "2001:db8::1", t) +} + +func TestNormalizedNetToString(t *testing.T) { + _, network, err := net.ParseCIDR("8.8.0.0/16") + if err != nil { + panic(err) + } + assertEqual(NetToNormalizedString(*network), "8.8.0.0/16", t) + + normalized := NormalizeNet(*network) + assertEqual(normalized.Contains(net.ParseIP("8.8.4.4")), true, t) + assertEqual(normalized.Contains(net.ParseIP("1.1.1.1")), false, t) + assertEqual(NetToNormalizedString(normalized), "8.8.0.0/16", t) + + _, network, err = net.ParseCIDR("8.8.4.4/32") + if err != nil { + panic(err) + } + assertEqual(NetToNormalizedString(*network), "8.8.4.4", t) + + normalized = NormalizeNet(*network) + assertEqual(normalized.Contains(net.ParseIP("8.8.4.4")), true, t) + assertEqual(normalized.Contains(net.ParseIP("8.8.8.8")), false, t) + assertEqual(NetToNormalizedString(normalized), "8.8.4.4", t) +} + +func TestNormalizedNet(t *testing.T) { + _, network, err := net.ParseCIDR("::ffff:8.8.4.4/128") + assertEqual(err, nil, t) + assertEqual(NetToNormalizedString(*network), "8.8.4.4", t) + + normalizedNet := NormalizeIPToNet(net.ParseIP("8.8.4.4")) + assertEqual(NetToNormalizedString(normalizedNet), "8.8.4.4", t) + + _, network, err = net.ParseCIDR("::ffff:8.8.0.0/112") + assertEqual(err, nil, t) + assertEqual(NetToNormalizedString(*network), "8.8.0.0/16", t) + _, v4Network, err := net.ParseCIDR("8.8.0.0/16") + assertEqual(err, nil, t) + normalizedNet = NormalizeNet(*v4Network) + assertEqual(NetToNormalizedString(normalizedNet), "8.8.0.0/16", t) +} + +func TestNormalizedNetFromString(t *testing.T) { + network, err := NormalizedNetFromString("8.8.4.4/16") + assertEqual(err, nil, t) + assertEqual(NetToNormalizedString(network), "8.8.0.0/16", t) + assertEqual(network.Contains(net.ParseIP("8.8.8.8")), true, t) + + network, err = NormalizedNetFromString("2001:0db8::1") + assertEqual(err, nil, t) + assertEqual(NetToNormalizedString(network), "2001:db8::1", t) + assertEqual(network.Contains(net.ParseIP("2001:0db8::1")), true, t) +}