github-actions[bot]
Update from GitHub Actions
6fefda3
package jetbrains
import (
"bufio"
"context"
"fmt"
"github.com/bytedance/sonic"
"github.com/sashabaranov/go-openai"
"io"
"jetbrains-ai-proxy/internal/utils"
"log"
"math"
"net/http"
"strconv"
"strings"
"time"
)
const (
	// sseObject is the OpenAI object type of each streamed chunk.
	sseObject = "chat.completion.chunk"
	// completionsObject is the OpenAI object type of a non-streaming chat
	// completion; the OpenAI API spells it "chat.completion" (singular).
	completionsObject = "chat.completion"
	// sseFinish is the terminal SSE payload, written as "data: [DONE]".
	sseFinish = "[DONE]"

	// initialBufferSize is the bufio reader/writer size used when streaming.
	initialBufferSize = 4096
	// maxBufferSize caps the total bytes read during one streaming call (1MB).
	maxBufferSize = 1024 * 1024
	// flushThreshold is the number of events between periodic flushes.
	flushThreshold = 10
	// heartbeatInterval is the period of the ticker that triggers ": keepalive".
	heartbeatInterval = 30 * time.Second
)
// SSEData is one decoded "data: " payload from the JetBrains AI SSE stream.
//
// ResponseJetbrainsAIToClient and processMessage read only Type, Content and
// Spent: Type selects the handling ("Content" or "QuotaMetadata"), Content is
// the text fragment of a "Content" event, and Spent.Amount is parsed as a
// float for a "QuotaMetadata" event. EventType, Reason and Updated are decoded
// but not consulted by those functions.
type SSEData struct {
	Type      string       `json:"type"`
	EventType string       `json:"event_type"`
	Content   string       `json:"content,omitempty"`
	Reason    string       `json:"reason,omitempty"`
	Updated   *UpdatedData `json:"updated,omitempty"`
	Spent     *SpentData   `json:"spent,omitempty"`
}
// Payload types nested inside SSEData: the quota snapshot carried in its
// "updated" field and the consumption carried in its "spent" field. Amounts
// arrive from upstream as strings and are left undecoded here.
type (
	// UpdatedData is the quota snapshot referenced by SSEData.Updated.
	UpdatedData struct {
		License string     `json:"license"`
		Current AmountData `json:"current"`
		Maximum AmountData `json:"maximum"`
		Until   int64      `json:"until"`
		QuotaID QuotaInfo  `json:"quotaID"`
	}

	// AmountData wraps a quantity transmitted as a string.
	AmountData struct {
		Amount string `json:"amount"`
	}

	// QuotaInfo identifies the quota an UpdatedData snapshot belongs to.
	QuotaInfo struct {
		QuotaId string `json:"quotaId"`
	}

	// SpentData is the amount consumed, referenced by SSEData.Spent.
	SpentData struct {
		Amount string `json:"amount"`
	}
)
// ResponseJetbrainsAIToClient handles a non-streaming request. It consumes the
// JetBrains AI SSE stream from r and folds it into a single OpenAI-style
// ChatCompletionResponse.
//
// Only lines prefixed with "data: " are considered. Empty payloads, "[DONE]"
// and "end" are skipped, and payloads that fail to decode are logged and
// skipped. The Content of every "Content" event is concatenated. The first
// "QuotaMetadata" event ends processing: its Spent.Amount is parsed as a float
// (0 if absent or unparsable), rounded, and passed with the accumulated text
// to utils.CalculateJetbrainsUsage. If the stream ends without a
// "QuotaMetadata" event, the response is built with a spent amount of 0.
//
// A final line that is not terminated by '\n' is still processed, because
// ReadString returns it together with io.EOF.
//
// It returns ctx.Err() if ctx is cancelled (checked between lines) and a
// wrapped error for any read failure other than io.EOF.
func ResponseJetbrainsAIToClient(ctx context.Context, req openai.ChatCompletionRequest, r io.Reader, fp string) (openai.ChatCompletionResponse, error) {
	reader := bufio.NewReader(r)
	var fullContent strings.Builder
	now := time.Now().Unix()
	// FormatInt avoids truncating the timestamp on platforms where int is 32 bits.
	chatId := strconv.FormatInt(now, 10)

	// atEOF is set once the reader is exhausted. The loop condition is
	// re-evaluated on every `continue`, so the last (possibly unterminated)
	// line is fully processed before the loop exits.
	for atEOF := false; !atEOF; {
		select {
		case <-ctx.Done():
			return openai.ChatCompletionResponse{}, ctx.Err()
		default:
		}

		line, err := reader.ReadString('\n')
		if err != nil {
			if err != io.EOF {
				return openai.ChatCompletionResponse{}, fmt.Errorf("读取错误: %w", err)
			}
			log.Printf("Reached EOF for non-streaming response")
			atEOF = true
		}

		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		jsonStr := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
		if jsonStr == "" || jsonStr == sseFinish || jsonStr == "end" {
			continue
		}
		var sseData SSEData
		if err := sonic.UnmarshalString(jsonStr, &sseData); err != nil {
			log.Printf("解析SSE数据错误: %v", err)
			continue
		}

		switch sseData.Type {
		case "Content":
			fullContent.WriteString(sseData.Content)
		case "QuotaMetadata":
			var spentAmount float64
			if sseData.Spent != nil {
				if amount, err := strconv.ParseFloat(sseData.Spent.Amount, 64); err == nil {
					spentAmount = amount
				} else {
					log.Printf("Warning: failed to parse spent amount '%s': %v", sseData.Spent.Amount, err)
				}
			}
			usage := utils.CalculateJetbrainsUsage(fullContent.String(), int(math.Round(spentAmount)))
			return createMessage(chatId, now, req, usage, fullContent.String(), fp), nil
		}
	}

	// No QuotaMetadata event was received: return a response whose usage is
	// computed with a spent amount of 0.
	usage := utils.CalculateJetbrainsUsage(fullContent.String(), 0)
	return createMessage(chatId, now, req, usage, fullContent.String(), fp), nil
}
// StreamJetbrainsAISSEToClient handles a streaming request. It relays the
// JetBrains AI SSE stream read from r to w as OpenAI-style
// "chat.completion.chunk" events.
//
// Each "data: " line is decoded into SSEData and passed to processMessage.
// Empty, "[DONE]" and "end" payloads are skipped, and undecodable payloads
// are logged and skipped. A "QuotaMetadata" event completes the stream:
// processMessage emits the final chunk with usage, then "data: [DONE]" is
// written and nil is returned. If r reaches EOF first, nil is returned without
// writing "[DONE]". A final line lacking '\n' is still processed before that.
//
// It returns:
//   - ctx.Err() on cancellation (checked between lines);
//   - read errors other than io.EOF;
//   - an error once more than maxBufferSize bytes have been read in total
//     during this call (the limit is cumulative, not per line);
//   - write and flush failures.
func StreamJetbrainsAISSEToClient(ctx context.Context, req openai.ChatCompletionRequest, w io.Writer, r io.Reader, fp string) error {
	log.Printf("=== Starting SSE Stream Processing for model: %s ===", req.Model)
	reader := bufio.NewReaderSize(r, initialBufferSize)
	writer := bufio.NewWriterSize(w, initialBufferSize)
	now := time.Now().Unix()
	// FormatInt avoids truncating the timestamp on platforms where int is 32 bits.
	chatId := strconv.FormatInt(now, 10)
	fingerprint := fp
	log.Printf("Session initialized - ChatID: %s, Fingerprint: %s", chatId, fingerprint)

	var completionBuilder strings.Builder
	messageCount := 0 // events decoded over the whole stream (for logging)
	pendingFlush := 0 // events decoded since the last periodic flush
	totalBufferSize := 0

	// The ticker is only polled between reads. ReadString blocks, so no
	// keepalive is written while waiting on a silent upstream.
	heartbeat := time.NewTicker(heartbeatInterval)
	defer heartbeat.Stop()

	// atEOF is set once the reader is exhausted. The loop condition is
	// re-evaluated on every `continue`, so the last (possibly unterminated)
	// line is fully processed before the loop exits.
	for atEOF := false; !atEOF; {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-heartbeat.C:
			if err := sendHeartbeat(writer, w); err != nil {
				log.Printf("Heartbeat error: %v", err)
			}
			continue
		default:
		}

		line, err := reader.ReadString('\n')
		if err != nil {
			if err != io.EOF {
				return fmt.Errorf("read error: %w", err)
			}
			atEOF = true
		}
		if line == "" {
			// Only possible at EOF: nothing left to process.
			continue
		}
		log.Printf("Received line: %s", strings.TrimSpace(line))

		// Enforce the cumulative size limit across the whole stream.
		totalBufferSize += len(line)
		if totalBufferSize > maxBufferSize {
			log.Printf("Buffer overflow: current size %d exceeds max size %d", totalBufferSize, maxBufferSize)
			return fmt.Errorf("buffer overflow: exceeded maximum buffer size of %d bytes", maxBufferSize)
		}

		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		jsonStr := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
		if jsonStr == "" || jsonStr == sseFinish || jsonStr == "end" {
			continue
		}
		var sseData SSEData
		if err := sonic.UnmarshalString(jsonStr, &sseData); err != nil {
			log.Printf("Error unmarshaling SSE data: %v", err)
			continue
		}
		log.Printf("Received SSE data: %+v", sseData)
		messageCount++
		pendingFlush++

		if err := processMessage(writer, w, sseData, chatId, fingerprint, now, &completionBuilder, req); err != nil {
			log.Printf("Failed to process message: %v", err)
			return err
		}

		// Periodic flush of the buffered writer.
		if pendingFlush >= flushThreshold {
			if err := flushWriter(writer, w); err != nil {
				return fmt.Errorf("flush error: %w", err)
			}
			pendingFlush = 0
		}

		// QuotaMetadata is the last event of a completed upstream response.
		if sseData.Type == "QuotaMetadata" {
			if err := sendFinishSignal(writer, w); err != nil {
				return fmt.Errorf("finish signal error: %w", err)
			}
			log.Printf("Stream completed successfully")
			return nil
		}
	}

	log.Printf("Reached EOF after %d messages", messageCount)
	return nil
}
// processMessage converts one upstream SSE event into an OpenAI stream chunk
// and writes it to the client via sendMessage.
//
//   - "Content": the text is appended to completionBuilder and forwarded as a
//     delta chunk.
//   - "QuotaMetadata": a content-less chunk is sent with FinishReason "stop" and
//     a Usage computed by utils.CalculateJetbrainsUsage. The inputs are the
//     accumulated completion text and the spent amount, parsed as a float and
//     rounded to the nearest integer (0 if absent or unparsable).
//   - any other type: logged and ignored, returning nil.
func processMessage(writer *bufio.Writer, w io.Writer, sseData SSEData, chatId, fingerprint string, now int64, completionBuilder *strings.Builder, req openai.ChatCompletionRequest) error {
	if sseData.Type == "Content" {
		completionBuilder.WriteString(sseData.Content)
		return sendMessage(writer, w, createStreamMessage(chatId, now, req, fingerprint, sseData.Content, ""))
	}
	if sseData.Type != "QuotaMetadata" {
		// Event types other than Content/QuotaMetadata carry nothing to forward.
		log.Printf("Ignoring message type: %s", sseData.Type)
		return nil
	}

	spent := 0.0
	if sseData.Spent != nil {
		parsed, parseErr := strconv.ParseFloat(sseData.Spent.Amount, 64)
		if parseErr != nil {
			log.Printf("Warning: failed to parse spent amount '%s': %v", sseData.Spent.Amount, parseErr)
		} else {
			spent = parsed
		}
	}

	usage := utils.CalculateJetbrainsUsage(completionBuilder.String(), int(math.Round(spent)))
	final := createStreamMessage(chatId, now, req, fingerprint, "", "")
	final.Choices[0].FinishReason = openai.FinishReasonStop
	final.Usage = &usage
	return sendMessage(writer, w, final)
}
// createStreamMessage builds one OpenAI "chat.completion.chunk" carrying a
// single assistant delta with the given content and reasoning content.
// FinishReason is left as FinishReasonNull; callers overwrite it for the
// final chunk.
func createStreamMessage(chatId string, now int64, req openai.ChatCompletionRequest, fingerPrint string, content string, reasoningContent string) openai.ChatCompletionStreamResponse {
	delta := openai.ChatCompletionStreamChoiceDelta{
		Role:             openai.ChatMessageRoleAssistant,
		Content:          content,
		ReasoningContent: reasoningContent,
	}

	chunk := openai.ChatCompletionStreamResponse{
		ID:                "chatcmpl-" + chatId,
		Object:            sseObject,
		Created:           now,
		Model:             req.Model,
		SystemFingerprint: fingerPrint,
	}
	chunk.Choices = []openai.ChatCompletionStreamChoice{{
		Index:        0,
		Delta:        delta,
		FinishReason: openai.FinishReasonNull,
	}}
	return chunk
}
// createMessage builds a non-streaming OpenAI chat completion response. It
// contains a single assistant message with the given content, finish reason
// "stop", and the supplied usage.
func createMessage(chatId string, now int64, req openai.ChatCompletionRequest, usage openai.Usage, content string, fp string) openai.ChatCompletionResponse {
	reply := openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleAssistant,
		Content: content,
	}

	resp := openai.ChatCompletionResponse{
		ID:                "chatcmpl-" + chatId,
		Object:            completionsObject,
		Created:           now,
		Model:             req.Model,
		SystemFingerprint: fp,
		Usage:             usage,
	}
	resp.Choices = []openai.ChatCompletionChoice{{
		Index:        0,
		Message:      reply,
		FinishReason: openai.FinishReasonStop,
	}}
	return resp
}
// sendMessage serializes sseMsg to JSON and writes it as one SSE event
// ("data: <json>\n\n"), then flushes it to the client via flushWriter.
// Marshal and write failures are returned wrapped.
func sendMessage(writer *bufio.Writer, w io.Writer, sseMsg openai.ChatCompletionStreamResponse) error {
	payload, marshalErr := sonic.MarshalString(sseMsg)
	if marshalErr != nil {
		return fmt.Errorf("marshal error: %w", marshalErr)
	}

	event := "data: " + payload + "\n\n"
	if _, writeErr := writer.WriteString(event); writeErr != nil {
		return fmt.Errorf("write error: %w", writeErr)
	}
	return flushWriter(writer, w)
}
// sendHeartbeat writes an SSE comment line (": keepalive") to keep the client
// connection active, then flushes it via flushWriter.
func sendHeartbeat(writer *bufio.Writer, w io.Writer) error {
	const keepalive = ": keepalive\n\n"
	_, err := writer.WriteString(keepalive)
	if err != nil {
		return fmt.Errorf("heartbeat write error: %w", err)
	}
	return flushWriter(writer, w)
}
// sendFinishSignal writes the terminal SSE event ("data: [DONE]\n\n") and
// flushes it via flushWriter.
func sendFinishSignal(writer *bufio.Writer, w io.Writer) error {
	_, err := writer.WriteString("data: " + sseFinish + "\n\n")
	if err != nil {
		return fmt.Errorf("write finish signal error: %w", err)
	}
	return flushWriter(writer, w)
}
// flushWriter pushes everything buffered in writer down to w. If w also
// implements http.Flusher, it then asks w to flush to the network. A bufio
// flush failure is returned wrapped as "flush error: ...", and in that case
// w is not flushed.
func flushWriter(writer *bufio.Writer, w io.Writer) error {
	err := writer.Flush()
	if err != nil {
		return fmt.Errorf("flush error: %w", err)
	}
	flusher, canFlush := w.(http.Flusher)
	if canFlush {
		flusher.Flush()
	}
	return nil
}