|
package balancer |
|
|
|
import ( |
|
"context" |
|
"github.com/go-resty/resty/v2" |
|
"jetbrains-ai-proxy/internal/types" |
|
"log" |
|
"sync" |
|
"time" |
|
) |
|
|
|
|
|
type HealthChecker struct { |
|
balancer JWTBalancer |
|
client *resty.Client |
|
checkInterval time.Duration |
|
timeout time.Duration |
|
maxRetries int |
|
stopChan chan struct{} |
|
wg sync.WaitGroup |
|
running bool |
|
mutex sync.RWMutex |
|
} |
|
|
|
|
|
func NewHealthChecker(balancer JWTBalancer) *HealthChecker { |
|
client := resty.New(). |
|
SetTimeout(10 * time.Second). |
|
SetHeaders(map[string]string{ |
|
"Content-Type": "application/json", |
|
}) |
|
|
|
return &HealthChecker{ |
|
balancer: balancer, |
|
client: client, |
|
checkInterval: 30 * time.Second, |
|
timeout: 10 * time.Second, |
|
maxRetries: 3, |
|
stopChan: make(chan struct{}), |
|
} |
|
} |
|
|
|
|
|
func (hc *HealthChecker) Start() { |
|
hc.mutex.Lock() |
|
defer hc.mutex.Unlock() |
|
|
|
if hc.running { |
|
return |
|
} |
|
|
|
hc.running = true |
|
hc.wg.Add(1) |
|
|
|
go hc.healthCheckLoop() |
|
log.Println("JWT health checker started") |
|
} |
|
|
|
|
|
func (hc *HealthChecker) Stop() { |
|
hc.mutex.Lock() |
|
defer hc.mutex.Unlock() |
|
|
|
if !hc.running { |
|
return |
|
} |
|
|
|
hc.running = false |
|
close(hc.stopChan) |
|
hc.wg.Wait() |
|
log.Println("JWT health checker stopped") |
|
} |
|
|
|
|
|
func (hc *HealthChecker) healthCheckLoop() { |
|
defer hc.wg.Done() |
|
|
|
ticker := time.NewTicker(hc.checkInterval) |
|
defer ticker.Stop() |
|
|
|
|
|
hc.performHealthCheck() |
|
|
|
for { |
|
select { |
|
case <-ticker.C: |
|
hc.performHealthCheck() |
|
case <-hc.stopChan: |
|
return |
|
} |
|
} |
|
} |
|
|
|
|
|
func (hc *HealthChecker) performHealthCheck() { |
|
log.Println("Performing JWT health check...") |
|
|
|
|
|
baseBalancer, ok := hc.balancer.(*BaseBalancer) |
|
if !ok { |
|
log.Println("Warning: Cannot access tokens for health check") |
|
return |
|
} |
|
|
|
baseBalancer.mutex.RLock() |
|
tokens := make([]string, 0, len(baseBalancer.tokens)) |
|
for token := range baseBalancer.tokens { |
|
tokens = append(tokens, token) |
|
} |
|
baseBalancer.mutex.RUnlock() |
|
|
|
|
|
var wg sync.WaitGroup |
|
for _, token := range tokens { |
|
wg.Add(1) |
|
go func(t string) { |
|
defer wg.Done() |
|
hc.checkTokenHealth(t) |
|
}(token) |
|
} |
|
wg.Wait() |
|
|
|
healthyCount := hc.balancer.GetHealthyTokenCount() |
|
totalCount := hc.balancer.GetTotalTokenCount() |
|
log.Printf("Health check completed: %d/%d tokens healthy", healthyCount, totalCount) |
|
} |
|
|
|
|
|
func (hc *HealthChecker) checkTokenHealth(token string) { |
|
ctx, cancel := context.WithTimeout(context.Background(), hc.timeout) |
|
defer cancel() |
|
|
|
|
|
testRequest := &types.JetbrainsRequest{ |
|
Prompt: types.PROMPT, |
|
Profile: "openai-gpt-4o", |
|
Chat: types.ChatField{ |
|
MessageField: []types.MessageField{ |
|
{ |
|
Type: "user_message", |
|
Content: "test", |
|
}, |
|
}, |
|
}, |
|
} |
|
|
|
success := false |
|
for retry := 0; retry < hc.maxRetries; retry++ { |
|
if hc.testTokenRequest(ctx, token, testRequest) { |
|
success = true |
|
break |
|
} |
|
|
|
|
|
if retry < hc.maxRetries-1 { |
|
time.Sleep(time.Second) |
|
} |
|
} |
|
|
|
if success { |
|
hc.balancer.MarkTokenHealthy(token) |
|
} else { |
|
hc.balancer.MarkTokenUnhealthy(token) |
|
log.Printf("JWT token health check failed: %s...", token[:min(len(token), 10)]) |
|
} |
|
} |
|
|
|
|
|
func (hc *HealthChecker) testTokenRequest(ctx context.Context, token string, req *types.JetbrainsRequest) bool { |
|
resp, err := hc.client.R(). |
|
SetContext(ctx). |
|
SetHeader(types.JwtTokenKey, token). |
|
SetBody(req). |
|
Post(types.ChatStreamV7) |
|
|
|
if err != nil { |
|
log.Printf("Health check request error for token %s...: %v", token[:min(len(token), 10)], err) |
|
return false |
|
} |
|
|
|
|
|
if resp.StatusCode() == 200 { |
|
return true |
|
} |
|
|
|
|
|
if resp.StatusCode() == 403 { |
|
|
|
return true |
|
} |
|
|
|
log.Printf("Health check failed for token %s...: status %d", |
|
token[:min(len(token), 10)], resp.StatusCode()) |
|
return false |
|
} |
|
|
|
|
|
func (hc *HealthChecker) SetCheckInterval(interval time.Duration) { |
|
hc.mutex.Lock() |
|
defer hc.mutex.Unlock() |
|
hc.checkInterval = interval |
|
} |
|
|
|
|
|
func (hc *HealthChecker) SetTimeout(timeout time.Duration) { |
|
hc.mutex.Lock() |
|
defer hc.mutex.Unlock() |
|
hc.timeout = timeout |
|
} |
|
|
|
|
|
func (hc *HealthChecker) SetMaxRetries(retries int) { |
|
hc.mutex.Lock() |
|
defer hc.mutex.Unlock() |
|
hc.maxRetries = retries |
|
} |
|
|