// jb2api/internal/balancer/jwt_balancer_test.go
// github-actions[bot]
// Update from GitHub Actions
// 6fefda3

package balancer
import (
"jetbrains-ai-proxy/internal/config"
"sync"
"testing"
"time"
)
// TestNewJWTBalancer checks that a balancer built from three tokens with the
// round-robin strategy is non-nil and reports three total and three healthy
// tokens.
func TestNewJWTBalancer(t *testing.T) {
	b := NewJWTBalancer([]string{"token1", "token2", "token3"}, config.RoundRobin)
	if b == nil {
		t.Fatal("Expected balancer to be created")
	}

	if total := b.GetTotalTokenCount(); total != 3 {
		t.Errorf("Expected 3 tokens, got %d", total)
	}

	if healthy := b.GetHealthyTokenCount(); healthy != 3 {
		t.Errorf("Expected 3 healthy tokens, got %d", healthy)
	}
}
// TestRoundRobinStrategy checks that the round-robin strategy hands tokens out
// in the order they were supplied and wraps around to the first token after
// the last one.
func TestRoundRobinStrategy(t *testing.T) {
	pool := []string{"token1", "token2", "token3"}
	b := NewJWTBalancer(pool, config.RoundRobin)

	// Two full passes over the pool: call i must yield pool[i mod len(pool)].
	const passes = 2
	for i := 0; i < passes*len(pool); i++ {
		want := pool[i%len(pool)]

		got, err := b.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error at iteration %d: %v", i, err)
		}
		if got != want {
			t.Errorf("At iteration %d, expected %s, got %s", i, want, got)
		}
	}
}
// TestRandomStrategy draws 100 tokens using the random strategy and checks
// that every draw is one of the configured tokens and that each configured
// token is drawn at least once.
func TestRandomStrategy(t *testing.T) {
	pool := []string{"token1", "token2", "token3"}
	b := NewJWTBalancer(pool, config.Random)

	// Set of configured tokens, for a direct membership check on each draw.
	known := make(map[string]bool, len(pool))
	for _, tok := range pool {
		known[tok] = true
	}

	const draws = 100
	hits := make(map[string]int)
	for i := 0; i < draws; i++ {
		got, err := b.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error at iteration %d: %v", i, err)
		}
		if !known[got] {
			t.Errorf("Got unexpected token: %s", got)
		}
		hits[got]++
	}

	// With 100 draws over three tokens, each one is expected to show up at
	// least once.
	for _, tok := range pool {
		if hits[tok] == 0 {
			t.Errorf("Token %s was never selected", tok)
		}
	}
}
// TestMarkTokenUnhealthy checks that marking one token unhealthy lowers the
// healthy count by one and that GetToken no longer returns that token.
func TestMarkTokenUnhealthy(t *testing.T) {
	const sidelined = "token2"
	b := NewJWTBalancer([]string{"token1", sidelined, "token3"}, config.RoundRobin)

	b.MarkTokenUnhealthy(sidelined)

	if healthy := b.GetHealthyTokenCount(); healthy != 2 {
		t.Errorf("Expected 2 healthy tokens, got %d", healthy)
	}

	// Ten draws cover several passes over the remaining tokens; none of them
	// may be the sidelined one.
	for attempt := 0; attempt < 10; attempt++ {
		got, err := b.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if got == sidelined {
			t.Errorf("Got unhealthy token: %s", got)
		}
	}
}
// TestMarkTokenHealthy checks that a token marked unhealthy and then healthy
// again is counted as healthy once more.
func TestMarkTokenHealthy(t *testing.T) {
	const target = "token2"
	b := NewJWTBalancer([]string{"token1", target, "token3"}, config.RoundRobin)

	// Take the token out of rotation first.
	b.MarkTokenUnhealthy(target)
	if healthy := b.GetHealthyTokenCount(); healthy != 2 {
		t.Errorf("Expected 2 healthy tokens after marking unhealthy, got %d", healthy)
	}

	// Then restore it and expect the full healthy count again.
	b.MarkTokenHealthy(target)
	if healthy := b.GetHealthyTokenCount(); healthy != 3 {
		t.Errorf("Expected 3 healthy tokens after marking healthy, got %d", healthy)
	}
}
// TestNoHealthyTokens checks that GetToken returns an error once every
// configured token has been marked unhealthy.
func TestNoHealthyTokens(t *testing.T) {
	pool := []string{"token1", "token2"}
	b := NewJWTBalancer(pool, config.RoundRobin)

	// Mark every token unhealthy.
	for _, tok := range pool {
		b.MarkTokenUnhealthy(tok)
	}

	if _, err := b.GetToken(); err == nil {
		t.Error("Expected error when no healthy tokens available")
	}
}
// TestConcurrentAccess exercises the balancer from many goroutines at once:
// one group repeatedly calls GetToken while another group flips tokens
// between unhealthy and healthy. It is mainly useful under `go test -race`.
//
// The first token is deliberately never toggled. Previously the togglers
// covered every token and each started by marking its token unhealthy, so all
// tokens could be unhealthy at the same moment; a concurrent GetToken would
// then return an error and fail the test depending on scheduling alone.
// NOTE(review): this assumes the balancer does not mark tokens unhealthy on
// its own during GetToken — confirm against the implementation.
func TestConcurrentAccess(t *testing.T) {
	tokens := []string{"token1", "token2", "token3", "token4", "token5"}
	balancer := NewJWTBalancer(tokens, config.RoundRobin)

	// Only these tokens have their health flipped; tokens[0] is left alone so
	// the readers always have one token the test has not marked unhealthy.
	toggled := tokens[1:]

	var wg sync.WaitGroup
	numGoroutines := 10
	tokensPerGoroutine := 100

	// Readers: fetch tokens concurrently.
	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < tokensPerGoroutine; j++ {
				// t.Errorf (unlike t.Fatalf) may be called from a goroutine
				// other than the one running the test.
				if _, err := balancer.GetToken(); err != nil {
					t.Errorf("Unexpected error in concurrent access: %v", err)
				}
			}
		}()
	}

	// Writers: flip the health state of the toggled tokens concurrently.
	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go func(index int) {
			defer wg.Done()
			token := toggled[index%len(toggled)]
			for j := 0; j < 10; j++ {
				if j%2 == 0 {
					balancer.MarkTokenUnhealthy(token)
				} else {
					balancer.MarkTokenHealthy(token)
				}
				// Brief pause so the writes interleave with the readers.
				time.Sleep(time.Millisecond)
			}
		}(i)
	}

	wg.Wait()

	// Toggling health must not add or remove tokens.
	if balancer.GetTotalTokenCount() != len(tokens) {
		t.Errorf("Expected %d total tokens, got %d", len(tokens), balancer.GetTotalTokenCount())
	}
}
// TestRefreshTokens checks that RefreshTokens replaces the token set: the
// total and healthy counts match the new set, and subsequent GetToken calls
// only return tokens from it.
func TestRefreshTokens(t *testing.T) {
	b := NewJWTBalancer([]string{"token1", "token2"}, config.RoundRobin)
	if total := b.GetTotalTokenCount(); total != 2 {
		t.Errorf("Expected 2 tokens initially, got %d", total)
	}

	// Swap in a new, larger token set.
	replacement := []string{"token3", "token4", "token5"}
	b.RefreshTokens(replacement)

	if total := b.GetTotalTokenCount(); total != 3 {
		t.Errorf("Expected 3 tokens after refresh, got %d", total)
	}
	if healthy := b.GetHealthyTokenCount(); healthy != 3 {
		t.Errorf("Expected 3 healthy tokens after refresh, got %d", healthy)
	}

	// Set of tokens that may legitimately be returned after the refresh.
	allowed := make(map[string]bool, len(replacement))
	for _, tok := range replacement {
		allowed[tok] = true
	}

	// Two full round-robin passes over the new set.
	for i := 0; i < 2*len(replacement); i++ {
		got, err := b.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if !allowed[got] {
			t.Errorf("Got unexpected token after refresh: %s", got)
		}
	}
}