diff --git a/AGENTS.md b/AGENTS.md index 2787ab6..5de3dbb 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -31,7 +31,7 @@ tunnel/stream.go — Stream interface (io.ReadWriteCloser + Source/Target server/reconstructed_conn.go — Replays re-serialized headers + buffered body + raw conn after hijack store/store.go — In-memory request/response recorder with pub/sub (SSE) web/web.go — Local tunnel monitor (port 8181), SSE endpoint -config/config.go — Reflection-based config from struct tags → flags + env vars +config/config.go — Reflection-based config from struct tags → flags + env vars + client config file pkg/maps/map.go — Generic sync.RWMutex-guarded map ``` @@ -39,7 +39,7 @@ pkg/maps/map.go — Generic sync.RWMutex-guarded map - **Go style**: standard `gofmt`, golangci-lint with `.golangci.toml` - **Comment style**: Title Case heading above logical blocks (see root `AGENTS.md`) -- **Config**: add struct tags (`json`, `default`, `description`) to `ServerConfig` or `ClientConfig` — flags and env vars are auto-derived +- **Config**: add struct tags (`json`, `default`, `description`) to `ServerConfig` or `ClientConfig` — flags and env vars are auto-derived. Client config may also come from `./conduit.json` or `~/.config/conduit/config.json` for `server`, `api_key`, `log_level`, and `log_format` only. - **Logging**: use `logrus` (`log` alias); structured fields preferred - **Concurrency**: use `pkg/maps.Map` for shared maps; protect other shared state with `sync.Mutex` - **Error handling**: return errors up; log at command/entry-point level. Use `fmt.Errorf` with `%w` for wrapping diff --git a/config/config.go b/config/config.go index 7c969d4..572c1a3 100644 --- a/config/config.go +++ b/config/config.go @@ -1,10 +1,12 @@ package config import ( + "encoding/json" "errors" "fmt" "net/url" "os" + "path/filepath" "reflect" "strings" @@ -69,7 +71,7 @@ func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) { cfgValues := make(map[string]string) for _, def := range defs { - cfgValues[def.Key] = getConfigValue(cmdFlags, def) + cfgValues[def.Key] = getConfigValue(cmdFlags, nil, def) } cfg := &ServerConfig{ @@ -86,9 +88,15 @@ func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) { func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) { defs := GetConfigDefs[ClientConfig]() + // Load Client Config File + fileValues, err := getClientConfigFileValues() + if err != nil { + return nil, err + } + cfgValues := make(map[string]string) for _, def := range defs { - cfgValues[def.Key] = getConfigValue(cmdFlags, def) + cfgValues[def.Key] = getConfigValue(cmdFlags, fileValues, def) } cfg := &ClientConfig{ @@ -122,7 +130,33 @@ func getBaseConfig(cfgValues map[string]string) BaseConfig { } } -func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string { +func getClientConfigFileValues() (map[string]string, error) { + path, err := findConfigFile() + if err != nil { + return nil, err + } + if path == "" { + return nil, nil + } + + // Load Config File + values, err := loadConfigFile(path) + if err != nil { + return nil, err + } + + // Keep Client File Settings Explicit - Tunnel name and target are intentionally + // not read from the config file because they should be provided per invocation. + clientValues := make(map[string]string) + for key, value := range values { + if isClientFileConfigKey(key) { + clientValues[key] = value + } + } + return clientValues, nil +} + +func getConfigValue(cmdFlags *pflag.FlagSet, fileValues map[string]string, def ConfigDef) string { // 1. Get Flags First if cmdFlags != nil { if val, err := cmdFlags.GetString(def.Key); err == nil && val != "" && val != def.Default { @@ -135,10 +169,65 @@ func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string { return envVal } - // 3. Defaults Last + // 3. Config File Next + if fileValues != nil { + if val := fileValues[def.Key]; val != "" { + return val + } + } + + // 4. Defaults Last return def.Default } +func findConfigFile() (string, error) { + // Check Project Config + localPath := "conduit.json" + if _, err := os.Stat(localPath); err == nil { + return localPath, nil + } else if !errors.Is(err, os.ErrNotExist) { + return "", err + } + + // Check User Config + configDir, err := os.UserConfigDir() + if err != nil { + return "", nil + } + userPath := filepath.Join(configDir, "conduit", "config.json") + if _, err := os.Stat(userPath); err == nil { + return userPath, nil + } else if !errors.Is(err, os.ErrNotExist) { + return "", err + } + + return "", nil +} + +func loadConfigFile(path string) (map[string]string, error) { + // Read Config File + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + // Decode Config File + values := make(map[string]string) + if err := json.Unmarshal(data, &values); err != nil { + return nil, fmt.Errorf("failed to parse config file %s: %w", path, err) + } + return values, nil +} + +func isClientFileConfigKey(key string) bool { + switch key { + case "server", "api_key", "log_level", "log_format": + return true + default: + return false + } +} + func processFields(t reflect.Type, defs *[]ConfigDef) { for i := 0; i < t.NumField(); i++ { field := t.Field(i) diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..bcb62df --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,144 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/spf13/pflag" +) + +func TestLoadConfigFile(t *testing.T) { + path := filepath.Join(t.TempDir(), "config.json") + if err := os.WriteFile(path, []byte(`{"server":"https://example.com","api_key":"secret"}`), 0o600); err != nil { + t.Fatal(err) + } + + // Load Config File + values, err := loadConfigFile(path) + if err != nil { + t.Fatal(err) + } + + // Verify Values + if values["server"] != "https://example.com" { + t.Fatalf("expected server from config file, got %q", values["server"]) + } + if values["api_key"] != "secret" { + t.Fatalf("expected api_key from config file, got %q", values["api_key"]) + } +} + +func TestFindConfigFile(t *testing.T) { + workDir := t.TempDir() + configDir := t.TempDir() + oldDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { + if err := os.Chdir(oldDir); err != nil { + t.Fatal(err) + } + }() + if err := os.Chdir(workDir); err != nil { + t.Fatal(err) + } + t.Setenv("XDG_CONFIG_HOME", configDir) + + // Missing Config File + path, err := findConfigFile() + if err != nil { + t.Fatal(err) + } + if path != "" { + t.Fatalf("expected no config file, got %q", path) + } + + // User Config File + userPath := filepath.Join(configDir, "conduit", "config.json") + if err := os.MkdirAll(filepath.Dir(userPath), 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(userPath, []byte(`{}`), 0o600); err != nil { + t.Fatal(err) + } + path, err = findConfigFile() + if err != nil { + t.Fatal(err) + } + if path != userPath { + t.Fatalf("expected user config file %q, got %q", userPath, path) + } + + // Local Config File Precedence + localPath := "conduit.json" + if err := os.WriteFile(localPath, []byte(`{}`), 0o600); err != nil { + t.Fatal(err) + } + path, err = findConfigFile() + if err != nil { + t.Fatal(err) + } + if path != localPath { + t.Fatalf("expected local config file %q, got %q", localPath, path) + } +} + +func TestGetConfigValuePriority(t *testing.T) { + def := ConfigDef{Key: "server", Env: "CONDUIT_SERVER", Default: "default"} + fileValues := map[string]string{"server": "file"} + + // Config File Beats Default + if value := getConfigValue(nil, fileValues, def); value != "file" { + t.Fatalf("expected file value, got %q", value) + } + + // Environment Beats Config File + t.Setenv("CONDUIT_SERVER", "env") + if value := getConfigValue(nil, fileValues, def); value != "env" { + t.Fatalf("expected env value, got %q", value) + } + + // Flags Beat Environment + flags := pflag.NewFlagSet("test", pflag.ContinueOnError) + flags.String("server", "default", "server") + if err := flags.Set("server", "flag"); err != nil { + t.Fatal(err) + } + if value := getConfigValue(flags, fileValues, def); value != "flag" { + t.Fatalf("expected flag value, got %q", value) + } +} + +func TestGetClientConfigFileValuesIgnoresTunnelSettings(t *testing.T) { + workDir := t.TempDir() + oldDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { + if err := os.Chdir(oldDir); err != nil { + t.Fatal(err) + } + }() + if err := os.Chdir(workDir); err != nil { + t.Fatal(err) + } + + // Write Local Config File + if err := os.WriteFile("conduit.json", []byte(`{"server":"https://example.com","api_key":"secret","name":"saved","target":"localhost:3000"}`), 0o600); err != nil { + t.Fatal(err) + } + + values, err := getClientConfigFileValues() + if err != nil { + t.Fatal(err) + } + if values["server"] != "https://example.com" { + t.Fatalf("expected server from config file, got %q", values["server"]) + } + if values["name"] != "" || values["target"] != "" { + t.Fatalf("expected tunnel settings to be ignored, got name=%q target=%q", values["name"], values["target"]) + } +}