|
package balancer |
|
|
|
import ( |
|
"jetbrains-ai-proxy/internal/config" |
|
"sync" |
|
"testing" |
|
"time" |
|
) |
|
|
|
func TestNewJWTBalancer(t *testing.T) { |
|
tokens := []string{"token1", "token2", "token3"} |
|
|
|
|
|
balancer := NewJWTBalancer(tokens, config.RoundRobin) |
|
if balancer == nil { |
|
t.Fatal("Expected balancer to be created") |
|
} |
|
|
|
if balancer.GetTotalTokenCount() != 3 { |
|
t.Errorf("Expected 3 tokens, got %d", balancer.GetTotalTokenCount()) |
|
} |
|
|
|
if balancer.GetHealthyTokenCount() != 3 { |
|
t.Errorf("Expected 3 healthy tokens, got %d", balancer.GetHealthyTokenCount()) |
|
} |
|
} |
|
|
|
func TestRoundRobinStrategy(t *testing.T) { |
|
tokens := []string{"token1", "token2", "token3"} |
|
balancer := NewJWTBalancer(tokens, config.RoundRobin) |
|
|
|
|
|
expectedOrder := []string{"token1", "token2", "token3", "token1", "token2", "token3"} |
|
|
|
for i, expected := range expectedOrder { |
|
token, err := balancer.GetToken() |
|
if err != nil { |
|
t.Fatalf("Unexpected error at iteration %d: %v", i, err) |
|
} |
|
|
|
if token != expected { |
|
t.Errorf("At iteration %d, expected %s, got %s", i, expected, token) |
|
} |
|
} |
|
} |
|
|
|
func TestRandomStrategy(t *testing.T) { |
|
tokens := []string{"token1", "token2", "token3"} |
|
balancer := NewJWTBalancer(tokens, config.Random) |
|
|
|
|
|
tokenCounts := make(map[string]int) |
|
iterations := 100 |
|
|
|
for i := 0; i < iterations; i++ { |
|
token, err := balancer.GetToken() |
|
if err != nil { |
|
t.Fatalf("Unexpected error at iteration %d: %v", i, err) |
|
} |
|
|
|
|
|
found := false |
|
for _, expectedToken := range tokens { |
|
if token == expectedToken { |
|
found = true |
|
break |
|
} |
|
} |
|
|
|
if !found { |
|
t.Errorf("Got unexpected token: %s", token) |
|
} |
|
|
|
tokenCounts[token]++ |
|
} |
|
|
|
|
|
for _, token := range tokens { |
|
if tokenCounts[token] == 0 { |
|
t.Errorf("Token %s was never selected", token) |
|
} |
|
} |
|
} |
|
|
|
func TestMarkTokenUnhealthy(t *testing.T) { |
|
tokens := []string{"token1", "token2", "token3"} |
|
balancer := NewJWTBalancer(tokens, config.RoundRobin) |
|
|
|
|
|
balancer.MarkTokenUnhealthy("token2") |
|
|
|
if balancer.GetHealthyTokenCount() != 2 { |
|
t.Errorf("Expected 2 healthy tokens, got %d", balancer.GetHealthyTokenCount()) |
|
} |
|
|
|
|
|
for i := 0; i < 10; i++ { |
|
token, err := balancer.GetToken() |
|
if err != nil { |
|
t.Fatalf("Unexpected error: %v", err) |
|
} |
|
|
|
if token == "token2" { |
|
t.Errorf("Got unhealthy token: %s", token) |
|
} |
|
} |
|
} |
|
|
|
func TestMarkTokenHealthy(t *testing.T) { |
|
tokens := []string{"token1", "token2", "token3"} |
|
balancer := NewJWTBalancer(tokens, config.RoundRobin) |
|
|
|
|
|
balancer.MarkTokenUnhealthy("token2") |
|
if balancer.GetHealthyTokenCount() != 2 { |
|
t.Errorf("Expected 2 healthy tokens after marking unhealthy, got %d", balancer.GetHealthyTokenCount()) |
|
} |
|
|
|
balancer.MarkTokenHealthy("token2") |
|
if balancer.GetHealthyTokenCount() != 3 { |
|
t.Errorf("Expected 3 healthy tokens after marking healthy, got %d", balancer.GetHealthyTokenCount()) |
|
} |
|
} |
|
|
|
func TestNoHealthyTokens(t *testing.T) { |
|
tokens := []string{"token1", "token2"} |
|
balancer := NewJWTBalancer(tokens, config.RoundRobin) |
|
|
|
|
|
balancer.MarkTokenUnhealthy("token1") |
|
balancer.MarkTokenUnhealthy("token2") |
|
|
|
|
|
_, err := balancer.GetToken() |
|
if err == nil { |
|
t.Error("Expected error when no healthy tokens available") |
|
} |
|
} |
|
|
|
func TestConcurrentAccess(t *testing.T) { |
|
tokens := []string{"token1", "token2", "token3", "token4", "token5"} |
|
balancer := NewJWTBalancer(tokens, config.RoundRobin) |
|
|
|
var wg sync.WaitGroup |
|
numGoroutines := 10 |
|
tokensPerGoroutine := 100 |
|
|
|
|
|
for i := 0; i < numGoroutines; i++ { |
|
wg.Add(1) |
|
go func() { |
|
defer wg.Done() |
|
for j := 0; j < tokensPerGoroutine; j++ { |
|
_, err := balancer.GetToken() |
|
if err != nil { |
|
t.Errorf("Unexpected error in concurrent access: %v", err) |
|
} |
|
} |
|
}() |
|
} |
|
|
|
|
|
for i := 0; i < numGoroutines; i++ { |
|
wg.Add(1) |
|
go func(index int) { |
|
defer wg.Done() |
|
token := tokens[index%len(tokens)] |
|
for j := 0; j < 10; j++ { |
|
if j%2 == 0 { |
|
balancer.MarkTokenUnhealthy(token) |
|
} else { |
|
balancer.MarkTokenHealthy(token) |
|
} |
|
time.Sleep(time.Millisecond) |
|
} |
|
}(i) |
|
} |
|
|
|
wg.Wait() |
|
|
|
|
|
if balancer.GetTotalTokenCount() != len(tokens) { |
|
t.Errorf("Expected %d total tokens, got %d", len(tokens), balancer.GetTotalTokenCount()) |
|
} |
|
} |
|
|
|
func TestRefreshTokens(t *testing.T) { |
|
tokens := []string{"token1", "token2"} |
|
balancer := NewJWTBalancer(tokens, config.RoundRobin) |
|
|
|
if balancer.GetTotalTokenCount() != 2 { |
|
t.Errorf("Expected 2 tokens initially, got %d", balancer.GetTotalTokenCount()) |
|
} |
|
|
|
|
|
newTokens := []string{"token3", "token4", "token5"} |
|
balancer.RefreshTokens(newTokens) |
|
|
|
if balancer.GetTotalTokenCount() != 3 { |
|
t.Errorf("Expected 3 tokens after refresh, got %d", balancer.GetTotalTokenCount()) |
|
} |
|
|
|
if balancer.GetHealthyTokenCount() != 3 { |
|
t.Errorf("Expected 3 healthy tokens after refresh, got %d", balancer.GetHealthyTokenCount()) |
|
} |
|
|
|
|
|
for i := 0; i < 6; i++ { |
|
token, err := balancer.GetToken() |
|
if err != nil { |
|
t.Fatalf("Unexpected error: %v", err) |
|
} |
|
|
|
found := false |
|
for _, newToken := range newTokens { |
|
if token == newToken { |
|
found = true |
|
break |
|
} |
|
} |
|
|
|
if !found { |
|
t.Errorf("Got unexpected token after refresh: %s", token) |
|
} |
|
} |
|
} |
|
|