File size: 4,682 Bytes
6fefda3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
package jetbrains
import (
"context"
"fmt"
"github.com/go-resty/resty/v2"
"jetbrains-ai-proxy/internal/balancer"
"jetbrains-ai-proxy/internal/config"
"jetbrains-ai-proxy/internal/types"
"jetbrains-ai-proxy/internal/utils"
"log"
"sync"
)
var (
jwtBalancer balancer.JWTBalancer
healthChecker *balancer.HealthChecker
initOnce sync.Once
configManager *config.Manager
)
// initFromConfigErr records the outcome of the one-time initialization run by
// InitializeFromConfig, so that every call (not only the first) reports it.
var initFromConfigErr error

// InitializeFromConfig initializes the JWT load balancer from the global
// configuration manager.
//
// The work runs at most once per process (guarded by initOnce). The result of
// that single attempt, success or failure, is returned on every call. A
// failed attempt is not retried.
//
// It returns an error if the configuration cannot be loaded or contains no
// JWT tokens.
func InitializeFromConfig() error {
	initOnce.Do(func() {
		initFromConfigErr = loadBalancerFromConfig()
	})
	return initFromConfigErr
}

// loadBalancerFromConfig loads the config, builds the balancer, and starts its
// health checker.
func loadBalancerFromConfig() error {
	// Captured before loading so that ReloadConfig can still be used if this
	// first load fails.
	configManager = config.GetGlobalConfig()

	if err := configManager.LoadConfig(); err != nil {
		return fmt.Errorf("failed to load config: %w", err)
	}

	cfg := configManager.GetConfig()
	tokens := configManager.GetJWTTokens()
	if len(tokens) == 0 {
		return fmt.Errorf("no JWT tokens configured")
	}

	jwtBalancer = balancer.NewJWTBalancer(tokens, cfg.LoadBalanceStrategy)

	healthChecker = balancer.NewHealthChecker(jwtBalancer)
	// A non-positive interval keeps the health checker's own default.
	if cfg.HealthCheckInterval > 0 {
		healthChecker.SetCheckInterval(cfg.HealthCheckInterval)
	}
	healthChecker.Start()

	log.Printf("JWT balancer initialized from config:")
	log.Printf(" - Tokens: %d", len(tokens))
	log.Printf(" - Strategy: %s", cfg.LoadBalanceStrategy)
	log.Printf(" - Health check interval: %v", cfg.HealthCheckInterval)
	return nil
}
// InitializeBalancer initializes the JWT load balancer from an explicit token
// list (kept for backward compatibility with callers that do not use the
// config manager).
//
// strategy selects the load-balancing strategy: "random" picks
// config.Random, while "round_robin" or "" picks config.RoundRobin. Any other
// value also falls back to round robin, and a warning is logged so that
// misspelled strategies are visible.
//
// It returns an error if tokens is empty. On success it replaces the package
// balancer and starts a new health checker for it.
func InitializeBalancer(tokens []string, strategy string) error {
	if len(tokens) == 0 {
		return fmt.Errorf("no JWT tokens provided")
	}

	var balanceStrategy config.LoadBalanceStrategy
	switch strategy {
	case "random":
		balanceStrategy = config.Random
	case "round_robin", "":
		balanceStrategy = config.RoundRobin
	default:
		// Previously this fell back silently; surface the misconfiguration.
		log.Printf("unknown load balance strategy %q, falling back to round_robin", strategy)
		balanceStrategy = config.RoundRobin
	}

	jwtBalancer = balancer.NewJWTBalancer(tokens, balanceStrategy)

	healthChecker = balancer.NewHealthChecker(jwtBalancer)
	healthChecker.Start()

	log.Printf("JWT balancer initialized with %d tokens, strategy: %s", len(tokens), string(balanceStrategy))
	return nil
}
// ReloadConfig re-reads the configuration through the manager captured by
// InitializeFromConfig. It then applies the new JWT token list and the
// health-check interval.
//
// If no balancer exists yet (for example because InitializeFromConfig failed
// after capturing the manager), one is created from the reloaded config and a
// health checker is started for it. A successful reload therefore never
// leaves the package without a balancer.
//
// It returns an error if the manager was never captured, the config cannot
// be loaded, or the reloaded config has no JWT tokens.
//
// NOTE(review): an existing balancer only receives RefreshTokens, so a
// changed LoadBalanceStrategy does not appear to be applied to it here,
// although it is logged. Confirm against balancer.JWTBalancer.
func ReloadConfig() error {
	if configManager == nil {
		return fmt.Errorf("config manager not initialized")
	}

	if err := configManager.LoadConfig(); err != nil {
		return fmt.Errorf("failed to reload config: %w", err)
	}

	cfg := configManager.GetConfig()
	tokens := configManager.GetJWTTokens()
	if len(tokens) == 0 {
		return fmt.Errorf("no JWT tokens in reloaded config")
	}

	if jwtBalancer == nil {
		jwtBalancer = balancer.NewJWTBalancer(tokens, cfg.LoadBalanceStrategy)
	} else {
		jwtBalancer.RefreshTokens(tokens)
	}

	switch {
	case healthChecker == nil:
		healthChecker = balancer.NewHealthChecker(jwtBalancer)
		if cfg.HealthCheckInterval > 0 {
			healthChecker.SetCheckInterval(cfg.HealthCheckInterval)
		}
		healthChecker.Start()
	case cfg.HealthCheckInterval > 0:
		healthChecker.SetCheckInterval(cfg.HealthCheckInterval)
	}

	log.Printf("Config reloaded successfully:")
	log.Printf(" - Tokens: %d", len(tokens))
	log.Printf(" - Strategy: %s", cfg.LoadBalanceStrategy)
	return nil
}
// StopBalancer stops the health checker created by the Initialize*
// functions. It does nothing when no health checker has been created.
func StopBalancer() {
	if healthChecker == nil {
		return
	}
	healthChecker.Stop()
}
// GetConfigManager returns the configuration manager captured by
// InitializeFromConfig. It is nil until InitializeFromConfig has been called.
func GetConfigManager() *config.Manager {
	manager := configManager
	return manager
}
// SendJetbrainsRequest posts req to the types.ChatStreamV7 endpoint. It
// authenticates with a JWT token chosen by the package load balancer.
//
// The request is sent with SetDoNotParseResponse(true). When a response is
// returned, the caller owns its RawBody and must close it.
//
// Token health bookkeeping:
//   - A transport error marks the token unhealthy, unless ctx was already
//     cancelled or expired (that failure says nothing about the token).
//   - HTTP 401 marks the token unhealthy, closes the body, and returns an error.
//   - HTTP 200 marks the token healthy.
//
// Any other status is handed back to the caller unchanged.
//
// It returns an error if the balancer is not initialized or has no usable
// token.
func SendJetbrainsRequest(ctx context.Context, req *types.JetbrainsRequest) (*resty.Response, error) {
	if jwtBalancer == nil {
		return nil, fmt.Errorf("no available JWT tokens: balancer not initialized")
	}

	token, err := jwtBalancer.GetToken()
	if err != nil {
		log.Printf("failed to get JWT token: %v", err)
		return nil, fmt.Errorf("no available JWT tokens: %w", err)
	}

	resp, err := utils.RestySSEClient.R().
		SetContext(ctx).
		SetHeader(types.JwtTokenKey, token).
		SetDoNotParseResponse(true).
		SetBody(req).
		Post(types.ChatStreamV7)
	if err != nil {
		log.Printf("jetbrains ai req error: %v", err)
		// We return nil, so nobody else could close a body left open here.
		closeRawBody(resp)
		// A client-side cancellation or deadline must not poison a good token.
		if ctx == nil || ctx.Err() == nil {
			jwtBalancer.MarkTokenUnhealthy(token)
		}
		return nil, err
	}

	switch resp.StatusCode() {
	case 401:
		// The response is discarded, so release its connection.
		closeRawBody(resp)
		jwtBalancer.MarkTokenUnhealthy(token)
		log.Printf("JWT token invalid (401): %s...", token[:min(len(token), 10)])
		return nil, fmt.Errorf("JWT token invalid")
	case 200:
		jwtBalancer.MarkTokenHealthy(token)
	}
	return resp, nil
}

// closeRawBody closes the unparsed body of resp, if there is one.
func closeRawBody(resp *resty.Response) {
	if resp == nil {
		return
	}
	if body := resp.RawBody(); body != nil {
		_ = body.Close() // best effort: the response is being discarded
	}
}
// GetBalancerStats reports the balancer's token counts as (healthy, total).
// Both are 0 when no balancer has been initialized.
func GetBalancerStats() (int, int) {
	var healthy, total int
	if b := jwtBalancer; b != nil {
		healthy, total = b.GetHealthyTokenCount(), b.GetTotalTokenCount()
	}
	return healthy, total
}
// min returns the smaller of a and b. It is used to bound the token prefix
// written to the logs.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|