feat(config): load client settings from config file
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
@@ -69,7 +71,7 @@ func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) {
|
||||
|
||||
cfgValues := make(map[string]string)
|
||||
for _, def := range defs {
|
||||
cfgValues[def.Key] = getConfigValue(cmdFlags, def)
|
||||
cfgValues[def.Key] = getConfigValue(cmdFlags, nil, def)
|
||||
}
|
||||
|
||||
cfg := &ServerConfig{
|
||||
@@ -86,9 +88,15 @@ func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) {
|
||||
func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) {
|
||||
defs := GetConfigDefs[ClientConfig]()
|
||||
|
||||
// Load Client Config File
|
||||
fileValues, err := getClientConfigFileValues()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfgValues := make(map[string]string)
|
||||
for _, def := range defs {
|
||||
cfgValues[def.Key] = getConfigValue(cmdFlags, def)
|
||||
cfgValues[def.Key] = getConfigValue(cmdFlags, fileValues, def)
|
||||
}
|
||||
|
||||
cfg := &ClientConfig{
|
||||
@@ -122,7 +130,33 @@ func getBaseConfig(cfgValues map[string]string) BaseConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string {
|
||||
func getClientConfigFileValues() (map[string]string, error) {
|
||||
path, err := findConfigFile()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if path == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Load Config File
|
||||
values, err := loadConfigFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Keep Client File Settings Explicit - Tunnel name and target are intentionally
|
||||
// not read from the config file because they should be provided per invocation.
|
||||
clientValues := make(map[string]string)
|
||||
for key, value := range values {
|
||||
if isClientFileConfigKey(key) {
|
||||
clientValues[key] = value
|
||||
}
|
||||
}
|
||||
return clientValues, nil
|
||||
}
|
||||
|
||||
func getConfigValue(cmdFlags *pflag.FlagSet, fileValues map[string]string, def ConfigDef) string {
|
||||
// 1. Get Flags First
|
||||
if cmdFlags != nil {
|
||||
if val, err := cmdFlags.GetString(def.Key); err == nil && val != "" && val != def.Default {
|
||||
@@ -135,10 +169,65 @@ func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string {
|
||||
return envVal
|
||||
}
|
||||
|
||||
// 3. Defaults Last
|
||||
// 3. Config File Next
|
||||
if fileValues != nil {
|
||||
if val := fileValues[def.Key]; val != "" {
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Defaults Last
|
||||
return def.Default
|
||||
}
|
||||
|
||||
func findConfigFile() (string, error) {
|
||||
// Check Project Config
|
||||
localPath := "conduit.json"
|
||||
if _, err := os.Stat(localPath); err == nil {
|
||||
return localPath, nil
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Check User Config
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
userPath := filepath.Join(configDir, "conduit", "config.json")
|
||||
if _, err := os.Stat(userPath); err == nil {
|
||||
return userPath, nil
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func loadConfigFile(path string) (map[string]string, error) {
|
||||
// Read Config File
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decode Config File
|
||||
values := make(map[string]string)
|
||||
if err := json.Unmarshal(data, &values); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file %s: %w", path, err)
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func isClientFileConfigKey(key string) bool {
|
||||
switch key {
|
||||
case "server", "api_key", "log_level", "log_format":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func processFields(t reflect.Type, defs *[]ConfigDef) {
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
|
||||
144
config/config_test.go
Normal file
144
config/config_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
func TestLoadConfigFile(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "config.json")
|
||||
if err := os.WriteFile(path, []byte(`{"server":"https://example.com","api_key":"secret"}`), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Load Config File
|
||||
values, err := loadConfigFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify Values
|
||||
if values["server"] != "https://example.com" {
|
||||
t.Fatalf("expected server from config file, got %q", values["server"])
|
||||
}
|
||||
if values["api_key"] != "secret" {
|
||||
t.Fatalf("expected api_key from config file, got %q", values["api_key"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindConfigFile(t *testing.T) {
|
||||
workDir := t.TempDir()
|
||||
configDir := t.TempDir()
|
||||
oldDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Chdir(oldDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
if err := os.Chdir(workDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Setenv("XDG_CONFIG_HOME", configDir)
|
||||
|
||||
// Missing Config File
|
||||
path, err := findConfigFile()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if path != "" {
|
||||
t.Fatalf("expected no config file, got %q", path)
|
||||
}
|
||||
|
||||
// User Config File
|
||||
userPath := filepath.Join(configDir, "conduit", "config.json")
|
||||
if err := os.MkdirAll(filepath.Dir(userPath), 0o700); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(userPath, []byte(`{}`), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
path, err = findConfigFile()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if path != userPath {
|
||||
t.Fatalf("expected user config file %q, got %q", userPath, path)
|
||||
}
|
||||
|
||||
// Local Config File Precedence
|
||||
localPath := "conduit.json"
|
||||
if err := os.WriteFile(localPath, []byte(`{}`), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
path, err = findConfigFile()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if path != localPath {
|
||||
t.Fatalf("expected local config file %q, got %q", localPath, path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConfigValuePriority(t *testing.T) {
|
||||
def := ConfigDef{Key: "server", Env: "CONDUIT_SERVER", Default: "default"}
|
||||
fileValues := map[string]string{"server": "file"}
|
||||
|
||||
// Config File Beats Default
|
||||
if value := getConfigValue(nil, fileValues, def); value != "file" {
|
||||
t.Fatalf("expected file value, got %q", value)
|
||||
}
|
||||
|
||||
// Environment Beats Config File
|
||||
t.Setenv("CONDUIT_SERVER", "env")
|
||||
if value := getConfigValue(nil, fileValues, def); value != "env" {
|
||||
t.Fatalf("expected env value, got %q", value)
|
||||
}
|
||||
|
||||
// Flags Beat Environment
|
||||
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
|
||||
flags.String("server", "default", "server")
|
||||
if err := flags.Set("server", "flag"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if value := getConfigValue(flags, fileValues, def); value != "flag" {
|
||||
t.Fatalf("expected flag value, got %q", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientConfigFileValuesIgnoresTunnelSettings(t *testing.T) {
|
||||
workDir := t.TempDir()
|
||||
oldDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Chdir(oldDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
if err := os.Chdir(workDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Write Local Config File
|
||||
if err := os.WriteFile("conduit.json", []byte(`{"server":"https://example.com","api_key":"secret","name":"saved","target":"localhost:3000"}`), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
values, err := getClientConfigFileValues()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if values["server"] != "https://example.com" {
|
||||
t.Fatalf("expected server from config file, got %q", values["server"])
|
||||
}
|
||||
if values["name"] != "" || values["target"] != "" {
|
||||
t.Fatalf("expected tunnel settings to be ignored, got name=%q target=%q", values["name"], values["target"])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user