3
0
mirror of https://github.com/ergochat/ergo.git synced 2024-11-26 13:59:44 +01:00

Merge pull request #1424 from slingamn/import_enhancements

validate amode recipients during data import
This commit is contained in:
Shivaram Lingamneni 2020-12-04 01:55:01 -08:00 committed by GitHub
commit 7624936d8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -53,17 +53,18 @@ type databaseImport struct {
Channels map[string]channelImport Channels map[string]channelImport
} }
func serializeAmodes(raw map[string]string) (result []byte, err error) { func serializeAmodes(raw map[string]string, validCfUsernames utils.StringSet) (result []byte, err error) {
processed := make(map[string]int, len(raw)) processed := make(map[string]int, len(raw))
for accountName, mode := range raw { for accountName, mode := range raw {
if len(mode) != 1 { if len(mode) != 1 {
return nil, fmt.Errorf("invalid mode %s for account %s", mode, accountName) return nil, fmt.Errorf("invalid mode %s for account %s", mode, accountName)
} }
cfname, err := CasefoldName(accountName) cfname, err := CasefoldName(accountName)
if err != nil { if err != nil || !validCfUsernames.Has(cfname) {
return nil, fmt.Errorf("invalid amode recipient %s: %w", accountName, err) log.Printf("skipping invalid amode recipient %s\n", accountName)
} else {
processed[cfname] = int(mode[0])
} }
processed[cfname] = int(mode[0])
} }
result, err = json.Marshal(processed) result, err = json.Marshal(processed)
return return
@ -78,6 +79,8 @@ func doImportDBGeneric(config *Config, dbImport databaseImport, credsType Creden
tx.Set(keySchemaVersion, strconv.Itoa(importDBSchemaVersion), nil) tx.Set(keySchemaVersion, strconv.Itoa(importDBSchemaVersion), nil)
tx.Set(keyCloakSecret, utils.GenerateSecretKey(), nil) tx.Set(keyCloakSecret, utils.GenerateSecretKey(), nil)
cfUsernames := make(utils.StringSet)
for username, userInfo := range dbImport.Users { for username, userInfo := range dbImport.Users {
cfUsername, err := CasefoldName(username) cfUsername, err := CasefoldName(username)
if err != nil { if err != nil {
@ -118,6 +121,7 @@ func doImportDBGeneric(config *Config, dbImport databaseImport, credsType Creden
for _, certfp := range certfps { for _, certfp := range certfps {
tx.Set(fmt.Sprintf(keyCertToAccount, certfp), cfUsername, nil) tx.Set(fmt.Sprintf(keyCertToAccount, certfp), cfUsername, nil)
} }
cfUsernames.Add(cfUsername)
} }
for chname, chInfo := range dbImport.Channels { for chname, chInfo := range dbImport.Channels {
@ -149,7 +153,7 @@ func doImportDBGeneric(config *Config, dbImport databaseImport, credsType Creden
tx.Set(fmt.Sprintf(keyChannelTopicSetBy, cfchname), chInfo.TopicSetBy, nil) tx.Set(fmt.Sprintf(keyChannelTopicSetBy, cfchname), chInfo.TopicSetBy, nil)
} }
if len(chInfo.Amode) != 0 { if len(chInfo.Amode) != 0 {
m, err := serializeAmodes(chInfo.Amode) m, err := serializeAmodes(chInfo.Amode, cfUsernames)
if err == nil { if err == nil {
tx.Set(fmt.Sprintf(keyChannelAccountToUMode, cfchname), string(m), nil) tx.Set(fmt.Sprintf(keyChannelAccountToUMode, cfchname), string(m), nil)
} else { } else {