File size: 4,682 Bytes
6fefda3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
package jetbrains

import (
	"context"
	"fmt"
	"github.com/go-resty/resty/v2"
	"jetbrains-ai-proxy/internal/balancer"
	"jetbrains-ai-proxy/internal/config"
	"jetbrains-ai-proxy/internal/types"
	"jetbrains-ai-proxy/internal/utils"
	"log"
	"sync"
)

// Package-level state shared by the functions in this file.
var (
	// jwtBalancer hands out JWT tokens for outgoing requests. In this file it
	// is only assigned by InitializeFromConfig and InitializeBalancer.
	jwtBalancer balancer.JWTBalancer

	// healthChecker is the health checker created together with jwtBalancer.
	healthChecker *balancer.HealthChecker

	// initOnce guards the body of InitializeFromConfig so it runs at most once.
	initOnce sync.Once

	// configManager is the global config manager captured by
	// InitializeFromConfig.
	configManager *config.Manager
)

// initFromConfigErr keeps the outcome of the one-time initialization performed
// by InitializeFromConfig so that every call reports it, not only the first.
var initFromConfigErr error

// InitializeFromConfig initializes the JWT load balancer and its health
// checker from the global config manager.
//
// The initialization body runs at most once (guarded by initOnce). Its result
// is remembered: if the first attempt failed, later calls return that same
// error instead of nil, so a caller can never mistake a failed initialization
// for a successful one.
//
// It returns an error when the config cannot be loaded (the cause is wrapped
// with %w) or when the loaded config contains no JWT tokens.
func InitializeFromConfig() error {
	initOnce.Do(func() {
		configManager = config.GetGlobalConfig()

		// Load the configuration.
		if err := configManager.LoadConfig(); err != nil {
			initFromConfigErr = fmt.Errorf("failed to load config: %w", err)
			return
		}

		// Read the values needed below.
		cfg := configManager.GetConfig()
		tokens := configManager.GetJWTTokens()

		if len(tokens) == 0 {
			initFromConfigErr = fmt.Errorf("no JWT tokens configured")
			return
		}

		// Create the load balancer.
		jwtBalancer = balancer.NewJWTBalancer(tokens, cfg.LoadBalanceStrategy)

		// Create and start the health checker; a non-positive interval keeps
		// whatever interval the checker was constructed with.
		healthChecker = balancer.NewHealthChecker(jwtBalancer)
		if cfg.HealthCheckInterval > 0 {
			healthChecker.SetCheckInterval(cfg.HealthCheckInterval)
		}
		healthChecker.Start()

		log.Printf("JWT balancer initialized from config:")
		log.Printf("  - Tokens: %d", len(tokens))
		log.Printf("  - Strategy: %s", cfg.LoadBalanceStrategy)
		log.Printf("  - Health check interval: %v", cfg.HealthCheckInterval)
	})

	// sync.Once guarantees the write above happens before Do returns, so this
	// read is safe from any goroutine.
	return initFromConfigErr
}

// InitializeBalancer initializes the JWT load balancer from an explicit token
// list (kept for backward compatibility).
//
// strategy selects the load-balancing strategy: "random" maps to
// config.Random; every other value, including "round_robin" and the empty
// string, maps to config.RoundRobin. An error is returned when tokens is
// empty.
func InitializeBalancer(tokens []string, strategy string) error {
	if len(tokens) == 0 {
		return fmt.Errorf("no JWT tokens provided")
	}

	// Round robin is the fallback for anything that is not "random".
	var picked config.LoadBalanceStrategy
	if strategy == "random" {
		picked = config.Random
	} else {
		picked = config.RoundRobin
	}

	// Build the balancer, then create and start its health checker.
	jwtBalancer = balancer.NewJWTBalancer(tokens, picked)
	healthChecker = balancer.NewHealthChecker(jwtBalancer)
	healthChecker.Start()

	log.Printf("JWT balancer initialized with %d tokens, strategy: %s", len(tokens), string(picked))
	return nil
}

// ReloadConfig re-reads the configuration through the package-level config
// manager and applies it to the running balancer.
//
// When they exist, the balancer receives the new token list via RefreshTokens
// and the health checker receives the new check interval (only if that
// interval is positive).
//
// NOTE(review): the load-balance strategy of the reloaded config is logged
// below but is not pushed into the balancer by this function.
//
// It returns an error when the config manager has not been set, when loading
// fails (the cause is wrapped with %w so errors.Is/errors.As can reach it),
// or when the reloaded config contains no JWT tokens.
func ReloadConfig() error {
	if configManager == nil {
		return fmt.Errorf("config manager not initialized")
	}

	// Reload the configuration.
	if err := configManager.LoadConfig(); err != nil {
		return fmt.Errorf("failed to reload config: %w", err)
	}

	// Read the new values.
	cfg := configManager.GetConfig()
	tokens := configManager.GetJWTTokens()

	if len(tokens) == 0 {
		return fmt.Errorf("no JWT tokens in reloaded config")
	}

	// Update the load balancer.
	if jwtBalancer != nil {
		jwtBalancer.RefreshTokens(tokens)
	}

	// Update the health check interval.
	if healthChecker != nil && cfg.HealthCheckInterval > 0 {
		healthChecker.SetCheckInterval(cfg.HealthCheckInterval)
	}

	log.Printf("Config reloaded successfully:")
	log.Printf("  - Tokens: %d", len(tokens))
	log.Printf("  - Strategy: %s", cfg.LoadBalanceStrategy)

	return nil
}

// StopBalancer stops the health checker if one has been created; otherwise it
// does nothing.
func StopBalancer() {
	if healthChecker == nil {
		return
	}
	healthChecker.Stop()
}

// GetConfigManager returns the package-level config manager. Within this file
// the variable is only assigned by InitializeFromConfig, so the result is nil
// before that function has run.
func GetConfigManager() *config.Manager {
	return configManager
}

// SendJetbrainsRequest posts req to the JetBrains chat streaming endpoint
// (types.ChatStreamV7), authenticated with a JWT token obtained from the load
// balancer.
//
// The request is issued with SetDoNotParseResponse(true), so on success the
// body is left unread and the caller is responsible for consuming and closing
// resp.RawBody().
//
// Token health bookkeeping:
//   - request error: the token is marked unhealthy, unless ctx itself has
//     already been cancelled or has expired;
//   - HTTP 401: the token is marked unhealthy, the response body is closed and
//     an error is returned;
//   - HTTP 200: the token is marked healthy;
//   - any other status: the response is returned without touching the token.
//
// An error is returned when the balancer has not been initialized, when no
// token is available (the cause is wrapped with %w), when the request fails,
// or when the server answers 401.
func SendJetbrainsRequest(ctx context.Context, req *types.JetbrainsRequest) (*resty.Response, error) {
	// Calling a method on a nil balancer would panic; report a regular error
	// when the package has not been initialized yet.
	if jwtBalancer == nil {
		return nil, fmt.Errorf("JWT balancer not initialized")
	}

	// Pick an available JWT token.
	token, err := jwtBalancer.GetToken()
	if err != nil {
		log.Printf("failed to get JWT token: %v", err)
		return nil, fmt.Errorf("no available JWT tokens: %w", err)
	}

	resp, err := utils.RestySSEClient.R().
		SetContext(ctx).
		SetHeader(types.JwtTokenKey, token).
		SetDoNotParseResponse(true).
		SetBody(req).
		Post(types.ChatStreamV7)
	if err != nil {
		log.Printf("jetbrains ai req error: %v", err)
		// A request aborted by the caller's own context says nothing about
		// the token, so only penalize the token for other failures.
		if ctx == nil || ctx.Err() == nil {
			jwtBalancer.MarkTokenUnhealthy(token)
		}
		return nil, err
	}

	switch resp.StatusCode() {
	case 401:
		// 401 means the token is invalid: mark it unhealthy.
		jwtBalancer.MarkTokenUnhealthy(token)
		// The response is not handed to the caller on this path, so the
		// unparsed body must be closed here or the connection leaks.
		if body := resp.RawBody(); body != nil {
			_ = body.Close() // best effort; the request has already failed
		}
		// Only a short prefix of the token is logged.
		log.Printf("JWT token invalid (401): %s...", token[:min(len(token), 10)])
		return nil, fmt.Errorf("JWT token invalid")
	case 200:
		// Successful response: make sure the token is marked healthy.
		jwtBalancer.MarkTokenHealthy(token)
	}

	return resp, nil
}

// GetBalancerStats reports the balancer's token counters: first the value of
// GetHealthyTokenCount, then the value of GetTotalTokenCount. Both are 0 when
// no balancer has been created.
func GetBalancerStats() (int, int) {
	if jwtBalancer != nil {
		healthy := jwtBalancer.GetHealthyTokenCount()
		total := jwtBalancer.GetTotalTokenCount()
		return healthy, total
	}
	return 0, 0
}

// min returns the smaller of the two integers a and b.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}