diff --git a/irc/strings.go b/irc/strings.go index af0fdecf..d8f09502 100644 --- a/irc/strings.go +++ b/irc/strings.go @@ -44,8 +44,7 @@ func CasefoldChannel(name string) (string, error) { // , is used as a separator // * is used in mask matching // ? is used in mask matching - if strings.Contains(lowered, " ") || strings.Contains(lowered, ",") || - strings.Contains(lowered, "*") || strings.Contains(lowered, "?") { + if strings.ContainsAny(lowered, " ,*?") { return "", errInvalidCharacter } @@ -73,11 +72,7 @@ func CasefoldName(name string) (string, error) { // # is a channel prefix // ~&@%+ are channel membership prefixes // - I feel like disallowing - if strings.Contains(lowered, " ") || strings.Contains(lowered, ",") || - strings.Contains(lowered, "*") || strings.Contains(lowered, "?") || - strings.Contains(lowered, ".") || strings.Contains(lowered, "!") || - strings.Contains(lowered, "@") || - strings.Contains("#~&@%+-", string(lowered[0])) { + if strings.ContainsAny(lowered, " ,*?.!@:") || strings.ContainsAny(string(lowered[0]), "#~&@%+-") { return "", errInvalidCharacter } diff --git a/irc/strings_test.go b/irc/strings_test.go new file mode 100644 index 00000000..808f6ca3 --- /dev/null +++ b/irc/strings_test.go @@ -0,0 +1,105 @@ +// Copyright (c) 2017 Euan Kemp +// released under the MIT license + +package irc + +import ( + "fmt" + "testing" +) + +func TestCasefoldChannel(t *testing.T) { + type channelTest struct { + channel string + folded string + err bool + } + testCases := []channelTest{ + { + channel: "#foo", + folded: "#foo", + }, + { + channel: "#rfc1459[noncompliant]", + folded: "#rfc1459[noncompliant]", + }, + { + channel: "#{[]}", + folded: "#{[]}", + }, + { + channel: "#FOO", + folded: "#foo", + }, + { + channel: "#bang!", + folded: "#bang!", + }, + { + channel: "#", + folded: "#", + }, + } + + for _, errCase := range []string{ + "", "#*starpower", "# NASA", "#interro?", "OOF#", "foo", + } { + testCases = append(testCases, channelTest{channel: errCase, err: true}) + } + + for i, tt := range testCases { + t.Run(fmt.Sprintf("case %d: %s", i, tt.channel), func(t *testing.T) { + res, err := CasefoldChannel(tt.channel) + if tt.err { + if err == nil { + t.Errorf("expected error") + } + return + } + if tt.folded != res { + t.Errorf("expected %v to be %v", tt.folded, res) + } + }) + } +} + +func TestCasefoldName(t *testing.T) { + type nameTest struct { + name string + folded string + err bool + } + testCases := []nameTest{ + { + name: "foo", + folded: "foo", + }, + { + name: "FOO", + folded: "foo", + }, + } + + for _, errCase := range []string{ + "", "#", "foo,bar", "star*man*junior", "lo7t?", + "f.l", "excited!nick", "foo@bar", ":trail", + "~o", "&o", "@o", "%h", "+v", "-m", + } { + testCases = append(testCases, nameTest{name: errCase, err: true}) + } + + for i, tt := range testCases { + t.Run(fmt.Sprintf("case %d: %s", i, tt.name), func(t *testing.T) { + res, err := CasefoldName(tt.name) + if tt.err { + if err == nil { + t.Errorf("expected error") + } + return + } + if tt.folded != res { + t.Errorf("expected %v to be %v", tt.folded, res) + } + }) + } +}