| package amp |
|
|
| import ( |
| "context" |
| "encoding/json" |
| "os" |
| "path/filepath" |
| "sync" |
| "testing" |
| "time" |
|
|
| "github.com/router-for-me/CLIProxyAPI/v6/internal/config" |
| log "github.com/sirupsen/logrus" |
| "github.com/sirupsen/logrus/hooks/test" |
| ) |
|
|
| func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) { |
| ctx := context.Background() |
|
|
| cases := []struct { |
| name string |
| configKey string |
| envKey string |
| fileJSON string |
| want string |
| }{ |
| {"config_wins", "cfg", "env", `{"apiKey@https://ampcode.com/":"file"}`, "cfg"}, |
| {"env_wins_when_no_cfg", "", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"}, |
| {"file_when_no_cfg_env", "", "", `{"apiKey@https://ampcode.com/":"file"}`, "file"}, |
| {"empty_cfg_trims_then_env", " ", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"}, |
| {"empty_env_then_file", "", " ", `{"apiKey@https://ampcode.com/":"file"}`, "file"}, |
| {"missing_file_returns_empty", "", "", "", ""}, |
| {"all_empty_returns_empty", " ", " ", `{"apiKey@https://ampcode.com/":" "}`, ""}, |
| } |
|
|
| for _, tc := range cases { |
| tc := tc |
| t.Run(tc.name, func(t *testing.T) { |
| tmpDir := t.TempDir() |
| secretsPath := filepath.Join(tmpDir, "secrets.json") |
|
|
| if tc.fileJSON != "" { |
| if err := os.WriteFile(secretsPath, []byte(tc.fileJSON), 0600); err != nil { |
| t.Fatal(err) |
| } |
| } |
|
|
| t.Setenv("AMP_API_KEY", tc.envKey) |
|
|
| s := NewMultiSourceSecretWithPath(tc.configKey, secretsPath, 100*time.Millisecond) |
| got, err := s.Get(ctx) |
| if err != nil && tc.fileJSON != "" && json.Valid([]byte(tc.fileJSON)) { |
| t.Fatalf("unexpected error: %v", err) |
| } |
| if got != tc.want { |
| t.Fatalf("want %q, got %q", tc.want, got) |
| } |
| }) |
| } |
| } |
|
|
| func TestMultiSourceSecret_CacheBehavior(t *testing.T) { |
| ctx := context.Background() |
| tmpDir := t.TempDir() |
| p := filepath.Join(tmpDir, "secrets.json") |
|
|
| |
| if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil { |
| t.Fatal(err) |
| } |
|
|
| s := NewMultiSourceSecretWithPath("", p, 50*time.Millisecond) |
|
|
| |
| got1, err := s.Get(ctx) |
| if err != nil { |
| t.Fatalf("Get failed: %v", err) |
| } |
| if got1 != "v1" { |
| t.Fatalf("expected v1, got %s", got1) |
| } |
|
|
| |
| if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v2"}`), 0600); err != nil { |
| t.Fatal(err) |
| } |
| got2, _ := s.Get(ctx) |
| if got2 != "v1" { |
| t.Fatalf("cache hit expected v1, got %s", got2) |
| } |
|
|
| |
| time.Sleep(60 * time.Millisecond) |
| got3, _ := s.Get(ctx) |
| if got3 != "v2" { |
| t.Fatalf("cache miss expected v2, got %s", got3) |
| } |
|
|
| |
| if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v3"}`), 0600); err != nil { |
| t.Fatal(err) |
| } |
| s.InvalidateCache() |
| got4, _ := s.Get(ctx) |
| if got4 != "v3" { |
| t.Fatalf("invalidate expected v3, got %s", got4) |
| } |
| } |
|
|
| func TestMultiSourceSecret_FileHandling(t *testing.T) { |
| ctx := context.Background() |
|
|
| t.Run("missing_file_no_error", func(t *testing.T) { |
| s := NewMultiSourceSecretWithPath("", "/nonexistent/path/secrets.json", 100*time.Millisecond) |
| got, err := s.Get(ctx) |
| if err != nil { |
| t.Fatalf("expected no error for missing file, got: %v", err) |
| } |
| if got != "" { |
| t.Fatalf("expected empty string, got %q", got) |
| } |
| }) |
|
|
| t.Run("invalid_json", func(t *testing.T) { |
| tmpDir := t.TempDir() |
| p := filepath.Join(tmpDir, "secrets.json") |
| if err := os.WriteFile(p, []byte(`{invalid json`), 0600); err != nil { |
| t.Fatal(err) |
| } |
|
|
| s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) |
| _, err := s.Get(ctx) |
| if err == nil { |
| t.Fatal("expected error for invalid JSON") |
| } |
| }) |
|
|
| t.Run("missing_key_in_json", func(t *testing.T) { |
| tmpDir := t.TempDir() |
| p := filepath.Join(tmpDir, "secrets.json") |
| if err := os.WriteFile(p, []byte(`{"other":"value"}`), 0600); err != nil { |
| t.Fatal(err) |
| } |
|
|
| s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) |
| got, err := s.Get(ctx) |
| if err != nil { |
| t.Fatalf("unexpected error: %v", err) |
| } |
| if got != "" { |
| t.Fatalf("expected empty string for missing key, got %q", got) |
| } |
| }) |
|
|
| t.Run("empty_key_value", func(t *testing.T) { |
| tmpDir := t.TempDir() |
| p := filepath.Join(tmpDir, "secrets.json") |
| if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":" "}`), 0600); err != nil { |
| t.Fatal(err) |
| } |
|
|
| s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) |
| got, _ := s.Get(ctx) |
| if got != "" { |
| t.Fatalf("expected empty after trim, got %q", got) |
| } |
| }) |
| } |
|
|
| func TestMultiSourceSecret_Concurrency(t *testing.T) { |
| tmpDir := t.TempDir() |
| p := filepath.Join(tmpDir, "secrets.json") |
| if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"concurrent"}`), 0600); err != nil { |
| t.Fatal(err) |
| } |
|
|
| s := NewMultiSourceSecretWithPath("", p, 5*time.Second) |
| ctx := context.Background() |
|
|
| |
| const goroutines = 50 |
| const iterations = 100 |
|
|
| var wg sync.WaitGroup |
| errors := make(chan error, goroutines) |
|
|
| for i := 0; i < goroutines; i++ { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| for j := 0; j < iterations; j++ { |
| val, err := s.Get(ctx) |
| if err != nil { |
| errors <- err |
| return |
| } |
| if val != "concurrent" { |
| errors <- err |
| return |
| } |
| } |
| }() |
| } |
|
|
| wg.Wait() |
| close(errors) |
|
|
| for err := range errors { |
| t.Errorf("concurrency error: %v", err) |
| } |
| } |
|
|
| func TestStaticSecretSource(t *testing.T) { |
| ctx := context.Background() |
|
|
| t.Run("returns_provided_key", func(t *testing.T) { |
| s := NewStaticSecretSource("test-key-123") |
| got, err := s.Get(ctx) |
| if err != nil { |
| t.Fatalf("unexpected error: %v", err) |
| } |
| if got != "test-key-123" { |
| t.Fatalf("want test-key-123, got %q", got) |
| } |
| }) |
|
|
| t.Run("trims_whitespace", func(t *testing.T) { |
| s := NewStaticSecretSource(" test-key ") |
| got, err := s.Get(ctx) |
| if err != nil { |
| t.Fatalf("unexpected error: %v", err) |
| } |
| if got != "test-key" { |
| t.Fatalf("want test-key, got %q", got) |
| } |
| }) |
|
|
| t.Run("empty_string", func(t *testing.T) { |
| s := NewStaticSecretSource("") |
| got, err := s.Get(ctx) |
| if err != nil { |
| t.Fatalf("unexpected error: %v", err) |
| } |
| if got != "" { |
| t.Fatalf("want empty string, got %q", got) |
| } |
| }) |
| } |
|
|
| func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) { |
| |
| tmpDir := t.TempDir() |
| p := filepath.Join(tmpDir, "nonexistent.json") |
|
|
| s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond) |
| ctx := context.Background() |
|
|
| |
| got1, err := s.Get(ctx) |
| if err != nil { |
| t.Fatalf("expected no error for missing file, got: %v", err) |
| } |
| if got1 != "" { |
| t.Fatalf("expected empty string, got %q", got1) |
| } |
|
|
| |
| if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"new-value"}`), 0600); err != nil { |
| t.Fatal(err) |
| } |
|
|
| |
| got2, _ := s.Get(ctx) |
| if got2 != "" { |
| t.Fatalf("cache should return empty, got %q", got2) |
| } |
|
|
| |
| time.Sleep(110 * time.Millisecond) |
| got3, _ := s.Get(ctx) |
| if got3 != "new-value" { |
| t.Fatalf("after cache expiry, expected new-value, got %q", got3) |
| } |
| } |
|
|
| func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) { |
| defaultSource := NewStaticSecretSource("default") |
| s := NewMappedSecretSource(defaultSource) |
| s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ |
| { |
| UpstreamAPIKey: "u1", |
| APIKeys: []string{"k1"}, |
| }, |
| }) |
|
|
| ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1") |
| got, err := s.Get(ctx) |
| if err != nil { |
| t.Fatalf("unexpected error: %v", err) |
| } |
| if got != "u1" { |
| t.Fatalf("want u1, got %q", got) |
| } |
|
|
| ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2") |
| got, err = s.Get(ctx) |
| if err != nil { |
| t.Fatalf("unexpected error: %v", err) |
| } |
| if got != "default" { |
| t.Fatalf("want default fallback, got %q", got) |
| } |
| } |
|
|
| func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) { |
| defaultSource := NewStaticSecretSource("default") |
| s := NewMappedSecretSource(defaultSource) |
| s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ |
| { |
| UpstreamAPIKey: "u1", |
| APIKeys: []string{"k1"}, |
| }, |
| { |
| UpstreamAPIKey: "u2", |
| APIKeys: []string{"k1"}, |
| }, |
| }) |
|
|
| ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1") |
| got, err := s.Get(ctx) |
| if err != nil { |
| t.Fatalf("unexpected error: %v", err) |
| } |
| if got != "u1" { |
| t.Fatalf("want u1 (first wins), got %q", got) |
| } |
| } |
|
|
| func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) { |
| hook := test.NewLocal(log.StandardLogger()) |
| defer hook.Reset() |
|
|
| defaultSource := NewStaticSecretSource("default") |
| s := NewMappedSecretSource(defaultSource) |
| s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{ |
| { |
| UpstreamAPIKey: "u1", |
| APIKeys: []string{"k1"}, |
| }, |
| { |
| UpstreamAPIKey: "u2", |
| APIKeys: []string{"k1"}, |
| }, |
| }) |
|
|
| foundWarning := false |
| for _, entry := range hook.AllEntries() { |
| if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." { |
| foundWarning = true |
| break |
| } |
| } |
| if !foundWarning { |
| t.Fatal("expected warning log for duplicate client key, but none was found") |
| } |
| } |
|
|