|
package jetbrains |
|
|
|
import ( |
|
"bufio" |
|
"context" |
|
"fmt" |
|
"github.com/bytedance/sonic" |
|
"github.com/sashabaranov/go-openai" |
|
"io" |
|
"jetbrains-ai-proxy/internal/utils" |
|
"log" |
|
"math" |
|
"net/http" |
|
"strconv" |
|
"strings" |
|
"time" |
|
) |
|
|
|
// Wire-format values and I/O tuning parameters used while translating the
// upstream JetBrains AI event stream into OpenAI-style responses.
const (
	// sseObject is the "object" value stamped on every streamed chunk.
	sseObject = "chat.completion.chunk"

	// completionsObject is the "object" value stamped on non-streaming
	// responses. The OpenAI chat completion object uses the singular
	// "chat.completion"; the previous plural spelling did not match it.
	completionsObject = "chat.completion"

	// sseFinish is the sentinel payload that marks the end of an SSE stream.
	sseFinish = "[DONE]"

	// initialBufferSize is the size, in bytes, given to the buffered reader
	// and writer created for streaming.
	initialBufferSize = 4096

	// maxBufferSize is the upper bound, in bytes, that the streaming reader
	// enforces on data received from upstream.
	maxBufferSize = 1024 * 1024

	// flushThreshold is the number of processed messages after which the
	// streaming loop forces a flush.
	flushThreshold = 10

	// heartbeatInterval is the period of the keep-alive ticker.
	heartbeatInterval = 30 * time.Second
)
|
|
|
type SSEData struct { |
|
Type string `json:"type"` |
|
EventType string `json:"event_type"` |
|
Content string `json:"content,omitempty"` |
|
Reason string `json:"reason,omitempty"` |
|
Updated *UpdatedData `json:"updated,omitempty"` |
|
Spent *SpentData `json:"spent,omitempty"` |
|
} |
|
|
|
type UpdatedData struct { |
|
License string `json:"license"` |
|
Current AmountData `json:"current"` |
|
Maximum AmountData `json:"maximum"` |
|
Until int64 `json:"until"` |
|
QuotaID QuotaInfo `json:"quotaID"` |
|
} |
|
|
|
// AmountData wraps a single amount value that upstream encodes as a JSON
// string under the "amount" key.
type AmountData struct {
	// Amount is kept as the raw string received; no parsing happens here.
	Amount string `json:"amount"`
}
|
|
|
// QuotaInfo wraps the quota identifier that upstream sends under the
// "quotaId" key.
type QuotaInfo struct {
	// QuotaId is the identifier string exactly as received.
	QuotaId string `json:"quotaId"`
}
|
|
|
// SpentData mirrors the "spent" object of an upstream event.
type SpentData struct {
	// Amount is the spent value as a string; the handlers for
	// "QuotaMetadata" events parse it as a float.
	Amount string `json:"amount"`
}
|
|
|
|
|
func ResponseJetbrainsAIToClient(ctx context.Context, req openai.ChatCompletionRequest, r io.Reader, fp string) (openai.ChatCompletionResponse, error) { |
|
reader := bufio.NewReader(r) |
|
var fullContent strings.Builder |
|
|
|
now := time.Now().Unix() |
|
chatId := strconv.Itoa(int(now)) |
|
|
|
for { |
|
select { |
|
case <-ctx.Done(): |
|
return openai.ChatCompletionResponse{}, ctx.Err() |
|
default: |
|
} |
|
|
|
line, err := reader.ReadString('\n') |
|
if err != nil { |
|
if err == io.EOF { |
|
log.Printf("Reached EOF for non-streaming response") |
|
break |
|
} |
|
return openai.ChatCompletionResponse{}, fmt.Errorf("读取错误: %w", err) |
|
} |
|
|
|
if !strings.HasPrefix(line, "data: ") { |
|
continue |
|
} |
|
|
|
jsonStr := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) |
|
if jsonStr == "" || jsonStr == sseFinish || jsonStr == "end" { |
|
continue |
|
} |
|
|
|
var sseData SSEData |
|
if err := sonic.UnmarshalString(jsonStr, &sseData); err != nil { |
|
log.Printf("解析SSE数据错误: %v", err) |
|
continue |
|
} |
|
|
|
if sseData.Type == "Content" { |
|
fullContent.WriteString(sseData.Content) |
|
} |
|
|
|
if sseData.Type == "QuotaMetadata" { |
|
var spentAmount float64 |
|
if sseData.Spent != nil { |
|
if amount, err := strconv.ParseFloat(sseData.Spent.Amount, 64); err == nil { |
|
spentAmount = amount |
|
} else { |
|
log.Printf("Warning: failed to parse spent amount '%s': %v", sseData.Spent.Amount, err) |
|
} |
|
} |
|
usage := utils.CalculateJetbrainsUsage(fullContent.String(), int(math.Round(spentAmount))) |
|
return createMessage(chatId, now, req, usage, fullContent.String(), fp), nil |
|
} |
|
} |
|
|
|
|
|
usage := utils.CalculateJetbrainsUsage(fullContent.String(), 0) |
|
return createMessage(chatId, now, req, usage, fullContent.String(), fp), nil |
|
} |
|
|
|
|
|
func StreamJetbrainsAISSEToClient(ctx context.Context, req openai.ChatCompletionRequest, w io.Writer, r io.Reader, fp string) error { |
|
log.Printf("=== Starting SSE Stream Processing for model: %s ===", req.Model) |
|
|
|
reader := bufio.NewReaderSize(r, initialBufferSize) |
|
writer := bufio.NewWriterSize(w, initialBufferSize) |
|
|
|
now := time.Now().Unix() |
|
chatId := strconv.Itoa(int(now)) |
|
fingerprint := fp |
|
|
|
log.Printf("Session initialized - ChatID: %s, Fingerprint: %s", chatId, fingerprint) |
|
|
|
var completionBuilder strings.Builder |
|
messageCount := 0 |
|
totalBufferSize := 0 |
|
|
|
|
|
heartbeat := time.NewTicker(heartbeatInterval) |
|
defer heartbeat.Stop() |
|
|
|
for { |
|
select { |
|
case <-ctx.Done(): |
|
return ctx.Err() |
|
case <-heartbeat.C: |
|
if err := sendHeartbeat(writer, w); err != nil { |
|
log.Printf("Heartbeat error: %v", err) |
|
} |
|
continue |
|
default: |
|
} |
|
|
|
line, err := reader.ReadString('\n') |
|
if err != nil { |
|
if err == io.EOF { |
|
log.Printf("Reached EOF after %d messages", messageCount) |
|
return nil |
|
} |
|
return fmt.Errorf("read error: %w", err) |
|
} |
|
|
|
log.Printf("Received line: %s", strings.TrimSpace(line)) |
|
|
|
|
|
totalBufferSize += len(line) |
|
if totalBufferSize > maxBufferSize { |
|
log.Printf("Buffer overflow: current size %d exceeds max size %d", totalBufferSize, maxBufferSize) |
|
return fmt.Errorf("buffer overflow: exceeded maximum buffer size of %d bytes", maxBufferSize) |
|
} |
|
|
|
if !strings.HasPrefix(line, "data: ") { |
|
continue |
|
} |
|
|
|
jsonStr := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) |
|
if jsonStr == "" || jsonStr == "end" { |
|
continue |
|
} |
|
|
|
var sseData SSEData |
|
if err := sonic.UnmarshalString(jsonStr, &sseData); err != nil { |
|
log.Printf("Error unmarshaling SSE data: %v", err) |
|
continue |
|
} |
|
|
|
log.Printf("Received SSE data: %+v", sseData) |
|
|
|
messageCount++ |
|
|
|
if err := processMessage(writer, w, sseData, chatId, fingerprint, now, &completionBuilder, req); err != nil { |
|
log.Printf("Failed to process message: %v", err) |
|
return err |
|
} |
|
|
|
|
|
if messageCount >= flushThreshold { |
|
if err := flushWriter(writer, w); err != nil { |
|
return fmt.Errorf("flush error: %w", err) |
|
} |
|
messageCount = 0 |
|
} |
|
|
|
|
|
if sseData.Type == "QuotaMetadata" { |
|
if err := sendFinishSignal(writer, w); err != nil { |
|
return fmt.Errorf("finish signal error: %w", err) |
|
} |
|
log.Printf("Stream completed successfully") |
|
return nil |
|
} |
|
} |
|
} |
|
|
|
|
|
func processMessage(writer *bufio.Writer, w io.Writer, sseData SSEData, chatId, fingerprint string, now int64, completionBuilder *strings.Builder, req openai.ChatCompletionRequest) error { |
|
switch sseData.Type { |
|
case "Content": |
|
completionBuilder.WriteString(sseData.Content) |
|
sseMsg := createStreamMessage(chatId, now, req, fingerprint, sseData.Content, "") |
|
return sendMessage(writer, w, sseMsg) |
|
|
|
case "QuotaMetadata": |
|
var spentAmount float64 |
|
if sseData.Spent != nil { |
|
if amount, err := strconv.ParseFloat(sseData.Spent.Amount, 64); err == nil { |
|
spentAmount = amount |
|
} else { |
|
log.Printf("Warning: failed to parse spent amount '%s': %v", sseData.Spent.Amount, err) |
|
} |
|
} |
|
|
|
usage := utils.CalculateJetbrainsUsage(completionBuilder.String(), int(math.Round(spentAmount))) |
|
sseMsg := createStreamMessage(chatId, now, req, fingerprint, "", "") |
|
sseMsg.Choices[0].FinishReason = openai.FinishReasonStop |
|
sseMsg.Usage = &usage |
|
return sendMessage(writer, w, sseMsg) |
|
|
|
default: |
|
|
|
log.Printf("Ignoring message type: %s", sseData.Type) |
|
return nil |
|
} |
|
} |
|
|
|
|
|
func createStreamMessage(chatId string, now int64, req openai.ChatCompletionRequest, fingerPrint string, content string, reasoningContent string) openai.ChatCompletionStreamResponse { |
|
choice := openai.ChatCompletionStreamChoice{ |
|
Index: 0, |
|
Delta: openai.ChatCompletionStreamChoiceDelta{ |
|
Role: openai.ChatMessageRoleAssistant, |
|
Content: content, |
|
ReasoningContent: reasoningContent, |
|
}, |
|
ContentFilterResults: openai.ContentFilterResults{}, |
|
FinishReason: openai.FinishReasonNull, |
|
} |
|
|
|
return openai.ChatCompletionStreamResponse{ |
|
ID: "chatcmpl-" + chatId, |
|
Object: sseObject, |
|
Created: now, |
|
Model: req.Model, |
|
Choices: []openai.ChatCompletionStreamChoice{choice}, |
|
SystemFingerprint: fingerPrint, |
|
} |
|
} |
|
|
|
|
|
func createMessage(chatId string, now int64, req openai.ChatCompletionRequest, usage openai.Usage, content string, fp string) openai.ChatCompletionResponse { |
|
choice := openai.ChatCompletionChoice{ |
|
Index: 0, |
|
Message: openai.ChatCompletionMessage{ |
|
Role: openai.ChatMessageRoleAssistant, |
|
Content: content, |
|
}, |
|
FinishReason: openai.FinishReasonStop, |
|
} |
|
|
|
return openai.ChatCompletionResponse{ |
|
ID: "chatcmpl-" + chatId, |
|
Object: completionsObject, |
|
Created: now, |
|
Model: req.Model, |
|
Choices: []openai.ChatCompletionChoice{choice}, |
|
SystemFingerprint: fp, |
|
Usage: usage, |
|
} |
|
} |
|
|
|
|
|
func sendMessage(writer *bufio.Writer, w io.Writer, sseMsg openai.ChatCompletionStreamResponse) error { |
|
sendLine, err := sonic.MarshalString(sseMsg) |
|
if err != nil { |
|
return fmt.Errorf("marshal error: %w", err) |
|
} |
|
|
|
outputMsg := fmt.Sprintf("data: %s\n\n", sendLine) |
|
if _, err := writer.WriteString(outputMsg); err != nil { |
|
return fmt.Errorf("write error: %w", err) |
|
} |
|
|
|
return flushWriter(writer, w) |
|
} |
|
|
|
|
|
func sendHeartbeat(writer *bufio.Writer, w io.Writer) error { |
|
if _, err := writer.WriteString(": keepalive\n\n"); err != nil { |
|
return fmt.Errorf("heartbeat write error: %w", err) |
|
} |
|
return flushWriter(writer, w) |
|
} |
|
|
|
|
|
func sendFinishSignal(writer *bufio.Writer, w io.Writer) error { |
|
finishMsg := fmt.Sprintf("data: %s\n\n", sseFinish) |
|
if _, err := writer.WriteString(finishMsg); err != nil { |
|
return fmt.Errorf("write finish signal error: %w", err) |
|
} |
|
return flushWriter(writer, w) |
|
} |
|
|
|
|
|
// flushWriter drains the buffered writer and, when w also implements
// http.Flusher, calls its Flush as well. A failure to flush the buffered
// writer is returned wrapped, and in that case w is not flushed.
func flushWriter(writer *bufio.Writer, w io.Writer) error {
	err := writer.Flush()
	if err != nil {
		return fmt.Errorf("flush error: %w", err)
	}

	flusher, canFlush := w.(http.Flusher)
	if !canFlush {
		return nil
	}
	flusher.Flush()
	return nil
}
|
|