strings: Follow latest advice on PRECIS regarding string stabilizing

This commit is contained in:
Daniel Oaks 2017-08-17 18:23:24 +10:00
parent f9ef97b204
commit ac91a3e484
2 changed files with 36 additions and 13 deletions

View File

@ -17,13 +17,31 @@ const (
) )
var ( var (
errCouldNotStabilize = errors.New("Could not stabilize string while casefolding")
errInvalidCharacter = errors.New("Invalid character") errInvalidCharacter = errors.New("Invalid character")
errEmpty = errors.New("String is empty") errEmpty = errors.New("String is empty")
) )
// Casefold returns a casefolded string, without doing any name or channel character checks. // Casefold returns a casefolded string, without doing any name or channel character checks.
func Casefold(str string) (string, error) { func Casefold(str string) (string, error) {
return precis.UsernameCaseMapped.CompareKey(str) var err error
oldStr := str
// follow the stabilizing rules laid out here:
// https://tools.ietf.org/html/draft-ietf-precis-7564bis-10.html#section-7
for i := 0; i < 4; i++ {
str, err = precis.UsernameCaseMapped.CompareKey(str)
if err != nil {
return "", err
}
if oldStr == str {
break
}
oldStr = str
}
if oldStr != str {
return "", errCouldNotStabilize
}
return str, nil
} }
// CasefoldChannel returns a casefolded version of a channel name. // CasefoldChannel returns a casefolded version of a channel name.

View File

@ -1,4 +1,5 @@
// Copyright (c) 2017 Euan Kemp // Copyright (c) 2017 Euan Kemp
// Copyright (c) 2017 Daniel Oaks
// released under the MIT license // released under the MIT license
package irc package irc
@ -50,14 +51,16 @@ func TestCasefoldChannel(t *testing.T) {
for i, tt := range testCases { for i, tt := range testCases {
t.Run(fmt.Sprintf("case %d: %s", i, tt.channel), func(t *testing.T) { t.Run(fmt.Sprintf("case %d: %s", i, tt.channel), func(t *testing.T) {
res, err := CasefoldChannel(tt.channel) res, err := CasefoldChannel(tt.channel)
if tt.err { if tt.err && err == nil {
if err == nil { t.Errorf("expected error when casefolding [%s], but did not receive one", tt.channel)
t.Errorf("expected error") return
} }
if !tt.err && err != nil {
t.Errorf("unexpected error while casefolding [%s]: %s", tt.channel, err.Error())
return return
} }
if tt.folded != res { if tt.folded != res {
t.Errorf("expected %v to be %v", tt.folded, res) t.Errorf("expected [%v] to be [%v]", res, tt.folded)
} }
}) })
} }
@ -91,14 +94,16 @@ func TestCasefoldName(t *testing.T) {
for i, tt := range testCases { for i, tt := range testCases {
t.Run(fmt.Sprintf("case %d: %s", i, tt.name), func(t *testing.T) { t.Run(fmt.Sprintf("case %d: %s", i, tt.name), func(t *testing.T) {
res, err := CasefoldName(tt.name) res, err := CasefoldName(tt.name)
if tt.err { if tt.err && err == nil {
if err == nil { t.Errorf("expected error when casefolding [%s], but did not receive one", tt.name)
t.Errorf("expected error") return
} }
if !tt.err && err != nil {
t.Errorf("unexpected error while casefolding [%s]: %s", tt.name, err.Error())
return return
} }
if tt.folded != res { if tt.folded != res {
t.Errorf("expected %v to be %v", tt.folded, res) t.Errorf("expected [%v] to be [%v]", res, tt.folded)
} }
}) })
} }