3
0
mirror of https://github.com/ergochat/ergo.git synced 2026-03-16 04:08:12 +01:00

fix environment overrides against map fields

This commit is contained in:
Shivaram Lingamneni 2026-03-15 05:28:52 +00:00
parent 3bc010a6d9
commit 218cac3037
3 changed files with 152 additions and 36 deletions

View File

@ -181,7 +181,6 @@ However, settings that were overridden using this technique cannot be rehashed -
Due to implementation details, this technique has some limitations. Here are the known issues:
1. `opers` must be overridden in its entirety with `ERGO__OPERS` (you cannot override the properties of individual operators, e.g. with `ERGO__OPERS__ADMIN__PASSWORD`)
1. `accounts.auth-script` and `server.ip-check-script` do not work as expected (see [#2275](https://github.com/ergochat/ergo/issues/2275) for workarounds).
## Productionizing with systemd

View File

@ -1120,55 +1120,104 @@ func mungeFromEnvironment(config *Config, envPair string) (applied bool, name st
pathComponents[i] = screamingSnakeToKebab(pathComponent)
}
type mapInsertion struct {
m reflect.Value
k reflect.Value
v reflect.Value
}
var mapStack []mapInsertion
v := reflect.Indirect(reflect.ValueOf(config))
t := v.Type()
for _, component := range pathComponents {
if component == "" {
return false, "", &configPathError{name, "invalid", nil}
}
if v.Kind() != reflect.Struct {
return false, "", &configPathError{name, "index into non-struct", nil}
}
var nextField reflect.StructField
success := false
n := t.NumField()
// preferentially get a field with an exact yaml tag match,
// then fall back to case-insensitive comparison of field names
for i := 0; i < n; i++ {
field := t.Field(i)
if isExported(field) && field.Tag.Get("yaml") == component {
nextField = field
success = true
break
}
}
if !success {
if v.Kind() == reflect.Struct {
var nextField reflect.StructField
success := false
n := t.NumField()
// preferentially get a field with an exact yaml tag match,
// then fall back to case-insensitive comparison of field names
for i := 0; i < n; i++ {
field := t.Field(i)
if isExported(field) && strings.ToLower(field.Name) == component {
if isExported(field) && field.Tag.Get("yaml") == component {
nextField = field
success = true
break
}
}
}
if !success {
return false, "", &configPathError{name, fmt.Sprintf("couldn't resolve path component: `%s`", component), nil}
}
v = v.FieldByName(nextField.Name)
// dereference pointer field if necessary, initialize new value if necessary
if v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
if !success {
for i := 0; i < n; i++ {
field := t.Field(i)
if isExported(field) && strings.ToLower(field.Name) == component {
nextField = field
success = true
break
}
}
}
v = reflect.Indirect(v)
if !success {
return false, "", &configPathError{name, fmt.Sprintf("couldn't resolve path component: `%s`", component), nil}
}
v = v.FieldByName(nextField.Name)
// dereference pointer field if necessary, initialize new value if necessary
switch v.Kind() {
case reflect.Ptr:
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = reflect.Indirect(v)
case reflect.Map:
if v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
}
}
t = v.Type()
} else if v.Kind() == reflect.Map {
keyType := v.Type().Key()
valueType := v.Type().Elem()
if keyType.Kind() != reflect.String {
return false, "", &configPathError{name, "can't index into map unless its keys are strings", nil}
}
// index into the map, returns the zero value (invalid) if not found
key := reflect.ValueOf(component)
v2 := v.MapIndex(key)
if v2.IsValid() {
// make an addressable copy of the existing value:
v3 := reflect.New(valueType).Elem()
v3.Set(v2)
v2 = v3
} else {
// make an addressable value of the map value type:
v2 = reflect.New(valueType).Elem()
// if the map value type is *Baz, set it to a new(Baz):
if valueType.Kind() == reflect.Pointer {
v2.Set(reflect.New(valueType.Elem()))
}
}
// we are not operating directly on the current map member,
// we need to go back later and insert v2 into the map:
mapStack = append(mapStack, mapInsertion{m: v, k: key, v: v2})
if valueType.Kind() != reflect.Pointer {
v = v2
} else {
v = reflect.Indirect(v2)
}
t = v.Type()
} else {
return false, "", &configPathError{name, "can't index into fields other than struct or map", nil}
}
t = v.Type()
}
yamlErr := yaml.Unmarshal([]byte(value), v.Addr().Interface())
if yamlErr != nil {
return false, "", &configPathError{name, "couldn't deserialize YAML", yamlErr}
}
// go back and do all map assignments
for i := len(mapStack) - 1; i >= 0; i-- {
elem := mapStack[i]
elem.m.SetMapIndex(elem.k, elem.v)
}
return true, name, nil
}

View File

@ -8,6 +8,15 @@ import (
"testing"
)
func mungeEnvForTesting(config *Config, env []string, t *testing.T) {
for _, envPair := range env {
_, _, err := mungeFromEnvironment(config, envPair)
if err != nil {
t.Errorf("couldn't apply override `%s`: %v", envPair, err)
}
}
}
func TestEnvironmentOverrides(t *testing.T) {
var config Config
config.Server.Compatibility.SendUnprefixedSasl = true
@ -16,6 +25,12 @@ func TestEnvironmentOverrides(t *testing.T) {
config.Accounts.DefaultUserModes = &defaultUserModes
config.Server.WebSockets.AllowedOrigins = []string{"https://www.ircv3.net"}
config.Server.MOTD = "long.motd.txt" // overwrite this
config.Opers = map[string]*OperConfig{
"admin": {
Class: "server-admin",
Password: "adminpassword",
},
}
env := []string{
`USER=shivaram`, // unrelated var
`ORAGONO_USER=oragono`, // this should be ignored as well
@ -26,13 +41,11 @@ func TestEnvironmentOverrides(t *testing.T) {
`ORAGONO__ACCOUNTS__NICK_RESERVATION__ENABLED=true`,
`ERGO__ACCOUNTS__DEFAULT_USER_MODES="+iR"`,
`ORAGONO__SERVER__IP_CLOAKING={"enabled": true, "enabled-for-always-on": true, "netname": "irc", "cidr-len-ipv4": 32, "cidr-len-ipv6": 64, "num-bits": 64}`,
`ERGO__OPERS__ADMIN__PASSWORD="newadminpassword"`,
`ERGO__OPERS__OPERUSER={"class": "server-admin", "whois-line": "is a server admin", "password": "operpassword"}`,
}
for _, envPair := range env {
_, _, err := mungeFromEnvironment(&config, envPair)
if err != nil {
t.Errorf("couldn't apply override `%s`: %v", envPair, err)
}
}
mungeEnvForTesting(&config, env, t)
if config.Network.Name != "example.com" {
t.Errorf("unexpected value of network.name: %s", config.Network.Name)
@ -68,6 +81,61 @@ func TestEnvironmentOverrides(t *testing.T) {
if *config.Accounts.DefaultUserModes != "+iR" {
t.Errorf("couldn't override pre-set ptr field")
}
if (*config.Opers["admin"]).Password != "newadminpassword" {
t.Errorf("couldn't index into map and then overwrite")
}
if (*config.Opers["operuser"]).Password != "operpassword" {
t.Errorf("couldn't create new entry in map")
}
}
func TestEnvironmentInitializeNilMap(t *testing.T) {
var config Config
env := []string{
`ERGO__OPERS__OPERUSER={"class": "server-admin", "whois-line": "is a server admin", "password": "operpassword"}`,
}
mungeEnvForTesting(&config, env, t)
assertEqual((*config.Opers["operuser"]).Password, "operpassword")
// try with an initialized but empty map:
config.Opers = make(map[string]*OperConfig)
mungeEnvForTesting(&config, env, t)
assertEqual((*config.Opers["operuser"]).Password, "operpassword")
}
func TestEnvironmentCreateNewMap(t *testing.T) {
var config Config
env := []string{
`ERGO__OPERS={"operuser": {"class": "server-admin", "whois-line": "is a server admin", "password": "operpassword"}}`,
}
mungeEnvForTesting(&config, env, t)
operPassword := (*config.Opers["operuser"]).Password
if operPassword != "operpassword" {
t.Errorf("unexpected value of operator password: %s", operPassword)
}
// try with an initialized but empty map:
config.Opers = make(map[string]*OperConfig)
mungeEnvForTesting(&config, env, t)
assertEqual((*config.Opers["operuser"]).Password, "operpassword")
}
func TestEnvironmentNonPointerMap(t *testing.T) {
// edge cases that should not panic, even though the results are unusable
// since all "field names" get lowercased:
var config Config
config.Server.AdditionalISupport = map[string]string{"extban": "a"}
env := []string{
`ERGO__SERVER__ADDITIONAL_ISUPPORT__EXTBAN=~,a`,
`ERGO__FAKELAG__COMMAND_BUDGETS__PRIVMSG=10`,
}
mungeEnvForTesting(&config, env, t)
}
func TestEnvironmentOverrideErrors(t *testing.T) {