|
package balancer |
|
|
|
import (
	"fmt"
	"math/rand"
	"sort"
	"sync"
	"sync/atomic"
	"time"

	"jetbrains-ai-proxy/internal/config"
)
|
|
|
|
|
// JWTBalancer hands out JWT tokens from a managed pool and tracks which
// of them are currently usable.
type JWTBalancer interface {
	// GetToken returns a token chosen from the currently healthy ones, or
	// an error when none is available.
	GetToken() (string, error)

	// RefreshTokens replaces the managed token set with tokens.
	RefreshTokens(tokens []string)

	// MarkTokenHealthy puts token back into rotation.
	MarkTokenHealthy(token string)

	// MarkTokenUnhealthy takes token out of rotation.
	MarkTokenUnhealthy(token string)

	// GetTotalTokenCount reports how many tokens are managed.
	GetTotalTokenCount() int

	// GetHealthyTokenCount reports how many managed tokens are healthy.
	GetHealthyTokenCount() int
}
|
|
|
|
|
// TokenStatus is the per-token bookkeeping the balancer keeps for each
// managed JWT.
type TokenStatus struct {
	// Token is the raw JWT string; the balancer also uses it as the map key.
	Token string

	// Healthy reports whether GetToken may select this token.
	Healthy bool

	// LastUsed is set when the status is created and updated each time
	// GetToken selects this token.
	LastUsed time.Time

	// ErrorCount is incremented by MarkTokenUnhealthy and reset to zero by
	// MarkTokenHealthy.
	ErrorCount int64
}
|
|
|
|
|
// BaseBalancer is the JWTBalancer implementation returned by
// NewJWTBalancer. Construct it through NewJWTBalancer, which initializes
// the token map and the random source; the zero value has a nil rand.
type BaseBalancer struct {
	// tokens maps each raw token string to its status. Writers
	// (MarkTokenHealthy, MarkTokenUnhealthy, RefreshTokens) hold the
	// exclusive lock on mutex while touching it.
	tokens map[string]*TokenStatus

	// strategy selects how GetToken picks among healthy tokens.
	strategy config.LoadBalanceStrategy

	// mutex protects tokens and the statuses it points to.
	mutex sync.RWMutex

	// counter is the round-robin cursor, advanced with sync/atomic.
	counter int64

	// rand drives the Random strategy. A *rand.Rand built on a source from
	// rand.NewSource is not safe for concurrent use, so it must never be
	// drawn from by two goroutines at once.
	rand *rand.Rand
}
|
|
|
|
|
func NewJWTBalancer(tokens []string, strategy config.LoadBalanceStrategy) JWTBalancer { |
|
balancer := &BaseBalancer{ |
|
tokens: make(map[string]*TokenStatus), |
|
strategy: strategy, |
|
rand: rand.New(rand.NewSource(time.Now().UnixNano())), |
|
} |
|
|
|
|
|
for _, token := range tokens { |
|
balancer.tokens[token] = &TokenStatus{ |
|
Token: token, |
|
Healthy: true, |
|
LastUsed: time.Now(), |
|
ErrorCount: 0, |
|
} |
|
} |
|
|
|
return balancer |
|
} |
|
|
|
|
|
func (b *BaseBalancer) GetToken() (string, error) { |
|
b.mutex.RLock() |
|
defer b.mutex.RUnlock() |
|
|
|
|
|
healthyTokens := make([]*TokenStatus, 0) |
|
for _, status := range b.tokens { |
|
if status.Healthy { |
|
healthyTokens = append(healthyTokens, status) |
|
} |
|
} |
|
|
|
if len(healthyTokens) == 0 { |
|
return "", fmt.Errorf("no healthy JWT tokens available") |
|
} |
|
|
|
var selectedToken *TokenStatus |
|
|
|
switch b.strategy { |
|
case config.RoundRobin: |
|
|
|
index := atomic.AddInt64(&b.counter, 1) % int64(len(healthyTokens)) |
|
selectedToken = healthyTokens[index] |
|
case config.Random: |
|
|
|
index := b.rand.Intn(len(healthyTokens)) |
|
selectedToken = healthyTokens[index] |
|
default: |
|
|
|
index := atomic.AddInt64(&b.counter, 1) % int64(len(healthyTokens)) |
|
selectedToken = healthyTokens[index] |
|
} |
|
|
|
|
|
selectedToken.LastUsed = time.Now() |
|
|
|
return selectedToken.Token, nil |
|
} |
|
|
|
|
|
func (b *BaseBalancer) MarkTokenUnhealthy(token string) { |
|
b.mutex.Lock() |
|
defer b.mutex.Unlock() |
|
|
|
if status, exists := b.tokens[token]; exists { |
|
status.Healthy = false |
|
atomic.AddInt64(&status.ErrorCount, 1) |
|
fmt.Printf("JWT token marked as unhealthy: %s (errors: %d)\n", |
|
token[:min(len(token), 10)]+"...", status.ErrorCount) |
|
} |
|
} |
|
|
|
|
|
func (b *BaseBalancer) MarkTokenHealthy(token string) { |
|
b.mutex.Lock() |
|
defer b.mutex.Unlock() |
|
|
|
if status, exists := b.tokens[token]; exists { |
|
status.Healthy = true |
|
atomic.StoreInt64(&status.ErrorCount, 0) |
|
fmt.Printf("JWT token marked as healthy: %s\n", |
|
token[:min(len(token), 10)]+"...") |
|
} |
|
} |
|
|
|
|
|
func (b *BaseBalancer) GetHealthyTokenCount() int { |
|
b.mutex.RLock() |
|
defer b.mutex.RUnlock() |
|
|
|
count := 0 |
|
for _, status := range b.tokens { |
|
if status.Healthy { |
|
count++ |
|
} |
|
} |
|
return count |
|
} |
|
|
|
|
|
func (b *BaseBalancer) GetTotalTokenCount() int { |
|
b.mutex.RLock() |
|
defer b.mutex.RUnlock() |
|
|
|
return len(b.tokens) |
|
} |
|
|
|
|
|
func (b *BaseBalancer) RefreshTokens(tokens []string) { |
|
b.mutex.Lock() |
|
defer b.mutex.Unlock() |
|
|
|
|
|
b.tokens = make(map[string]*TokenStatus) |
|
|
|
|
|
for _, token := range tokens { |
|
b.tokens[token] = &TokenStatus{ |
|
Token: token, |
|
Healthy: true, |
|
LastUsed: time.Now(), |
|
ErrorCount: 0, |
|
} |
|
} |
|
|
|
fmt.Printf("JWT tokens refreshed, total: %d\n", len(tokens)) |
|
} |
|
|
|
|
|
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
|