|
package jetbrains |
|
|
|
import ( |
|
"context" |
|
"fmt" |
|
"github.com/go-resty/resty/v2" |
|
"jetbrains-ai-proxy/internal/balancer" |
|
"jetbrains-ai-proxy/internal/config" |
|
"jetbrains-ai-proxy/internal/types" |
|
"jetbrains-ai-proxy/internal/utils" |
|
"log" |
|
"sync" |
|
) |
|
|
|
var ( |
|
jwtBalancer balancer.JWTBalancer |
|
healthChecker *balancer.HealthChecker |
|
initOnce sync.Once |
|
configManager *config.Manager |
|
) |
|
|
|
|
|
func InitializeFromConfig() error { |
|
var initErr error |
|
|
|
initOnce.Do(func() { |
|
configManager = config.GetGlobalConfig() |
|
|
|
|
|
if err := configManager.LoadConfig(); err != nil { |
|
initErr = fmt.Errorf("failed to load config: %v", err) |
|
return |
|
} |
|
|
|
|
|
cfg := configManager.GetConfig() |
|
tokens := configManager.GetJWTTokens() |
|
|
|
if len(tokens) == 0 { |
|
initErr = fmt.Errorf("no JWT tokens configured") |
|
return |
|
} |
|
|
|
|
|
jwtBalancer = balancer.NewJWTBalancer(tokens, cfg.LoadBalanceStrategy) |
|
|
|
|
|
healthChecker = balancer.NewHealthChecker(jwtBalancer) |
|
if cfg.HealthCheckInterval > 0 { |
|
healthChecker.SetCheckInterval(cfg.HealthCheckInterval) |
|
} |
|
healthChecker.Start() |
|
|
|
log.Printf("JWT balancer initialized from config:") |
|
log.Printf(" - Tokens: %d", len(tokens)) |
|
log.Printf(" - Strategy: %s", cfg.LoadBalanceStrategy) |
|
log.Printf(" - Health check interval: %v", cfg.HealthCheckInterval) |
|
}) |
|
|
|
return initErr |
|
} |
|
|
|
|
|
func InitializeBalancer(tokens []string, strategy string) error { |
|
if len(tokens) == 0 { |
|
return fmt.Errorf("no JWT tokens provided") |
|
} |
|
|
|
var balanceStrategy config.LoadBalanceStrategy |
|
switch strategy { |
|
case "random": |
|
balanceStrategy = config.Random |
|
case "round_robin", "": |
|
balanceStrategy = config.RoundRobin |
|
default: |
|
balanceStrategy = config.RoundRobin |
|
} |
|
|
|
|
|
jwtBalancer = balancer.NewJWTBalancer(tokens, balanceStrategy) |
|
|
|
|
|
healthChecker = balancer.NewHealthChecker(jwtBalancer) |
|
healthChecker.Start() |
|
|
|
log.Printf("JWT balancer initialized with %d tokens, strategy: %s", len(tokens), string(balanceStrategy)) |
|
return nil |
|
} |
|
|
|
|
|
func ReloadConfig() error { |
|
if configManager == nil { |
|
return fmt.Errorf("config manager not initialized") |
|
} |
|
|
|
|
|
if err := configManager.LoadConfig(); err != nil { |
|
return fmt.Errorf("failed to reload config: %v", err) |
|
} |
|
|
|
|
|
cfg := configManager.GetConfig() |
|
tokens := configManager.GetJWTTokens() |
|
|
|
if len(tokens) == 0 { |
|
return fmt.Errorf("no JWT tokens in reloaded config") |
|
} |
|
|
|
|
|
if jwtBalancer != nil { |
|
jwtBalancer.RefreshTokens(tokens) |
|
} |
|
|
|
|
|
if healthChecker != nil && cfg.HealthCheckInterval > 0 { |
|
healthChecker.SetCheckInterval(cfg.HealthCheckInterval) |
|
} |
|
|
|
log.Printf("Config reloaded successfully:") |
|
log.Printf(" - Tokens: %d", len(tokens)) |
|
log.Printf(" - Strategy: %s", cfg.LoadBalanceStrategy) |
|
|
|
return nil |
|
} |
|
|
|
|
|
func StopBalancer() { |
|
if healthChecker != nil { |
|
healthChecker.Stop() |
|
} |
|
} |
|
|
|
|
|
func GetConfigManager() *config.Manager { |
|
return configManager |
|
} |
|
|
|
func SendJetbrainsRequest(ctx context.Context, req *types.JetbrainsRequest) (*resty.Response, error) { |
|
|
|
token, err := jwtBalancer.GetToken() |
|
if err != nil { |
|
log.Printf("failed to get JWT token: %v", err) |
|
return nil, fmt.Errorf("no available JWT tokens: %v", err) |
|
} |
|
|
|
resp, err := utils.RestySSEClient.R(). |
|
SetContext(ctx). |
|
SetHeader(types.JwtTokenKey, token). |
|
SetDoNotParseResponse(true). |
|
SetBody(req). |
|
Post(types.ChatStreamV7) |
|
|
|
if err != nil { |
|
log.Printf("jetbrains ai req error: %v", err) |
|
|
|
jwtBalancer.MarkTokenUnhealthy(token) |
|
return nil, err |
|
} |
|
|
|
|
|
if resp.StatusCode() == 401 { |
|
|
|
jwtBalancer.MarkTokenUnhealthy(token) |
|
log.Printf("JWT token invalid (401): %s...", token[:min(len(token), 10)]) |
|
return nil, fmt.Errorf("JWT token invalid") |
|
} else if resp.StatusCode() == 200 { |
|
|
|
jwtBalancer.MarkTokenHealthy(token) |
|
} |
|
|
|
return resp, nil |
|
} |
|
|
|
|
|
func GetBalancerStats() (int, int) { |
|
if jwtBalancer == nil { |
|
return 0, 0 |
|
} |
|
return jwtBalancer.GetHealthyTokenCount(), jwtBalancer.GetTotalTokenCount() |
|
} |
|
|
|
|
|
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
|