| package main |
|
|
| import ( |
| "bytes" |
| "context" |
| "encoding/json" |
| "fmt" |
| "log" |
| "math/rand" |
| "net/http" |
| "os" |
| "strings" |
| "time" |
|
|
| "github.com/google/generative-ai-go/genai" |
| "github.com/rs/cors" |
| "google.golang.org/api/iterator" |
| "google.golang.org/api/option" |
| ) |
|
|
| |
|
|
| |
// apiKeys holds the Gemini API keys that requests are spread across.
// getRandomAPIKey picks from this list, and main overwrites it from the
// GEMINI_API_KEY environment variable when that variable is provided.
//
// SECURITY(review): the entries below are credentials committed in plain text.
// Anyone with read access to this source can use them. They should be revoked
// and supplied through GEMINI_API_KEY instead of living in the repository.
var apiKeys = []string{
	"AIzaSyBUIl9AisD8FHUn5HLQcriXZnF4n5MqnWU",
	"AIzaSyAId4YPsZSTLJ5_fA5BESjYxWZBzwADTJI",
}
|
|
| |
| var supportedModels = []ModelInfo{ |
| { |
| ID: "gemini-2.5-flash-preview-05-20", |
| Object: "model", |
| Created: time.Now().Unix(), |
| OwnedBy: "google", |
| Description: "Gemini 2.5 Flash Preview - 最新实验性模型", |
| }, |
| { |
| ID: "gemini-2.5-flash", |
| Object: "model", |
| Created: time.Now().Unix(), |
| OwnedBy: "google", |
| Description: "gemini-2.5-flash稳定经典专业模型", |
| }, |
| { |
| ID: "gemini-2.5-pro", |
| Object: "model", |
| Created: time.Now().Unix(), |
| OwnedBy: "google", |
| Description: "Gemini 2.5 Pro 专业模型", |
| }, |
| } |
|
|
| |
| |
// modelMapping translates a client-facing model name into the upstream Gemini
// model name. Every entry currently maps a name to itself.
//
// NOTE(review): no code visible in this file consults this table; the chat
// handler forwards the requested model name unchanged.
var modelMapping = func() map[string]string {
	names := []string{
		"gemini-2.5-flash-preview-05-20",
		"gemini-2.5-flash",
		"gemini-2.5-pro",
	}

	table := make(map[string]string, len(names))
	for _, name := range names {
		table[name] = name
	}
	return table
}()
|
|
| |
| var safetySettings = []*genai.SafetySetting{ |
| { |
| Category: genai.HarmCategoryHarassment, |
| Threshold: genai.HarmBlockNone, |
| }, |
| { |
| Category: genai.HarmCategoryHateSpeech, |
| Threshold: genai.HarmBlockNone, |
| }, |
| { |
| Category: genai.HarmCategorySexuallyExplicit, |
| Threshold: genai.HarmBlockNone, |
| }, |
| { |
| Category: genai.HarmCategoryDangerousContent, |
| Threshold: genai.HarmBlockNone, |
| }, |
| } |
|
|
// maxRetries is the maximum number of upstream attempts chatCompletionsHandler
// makes for a single chat completion request before reporting failure.
const maxRetries = 3
|
|
| |
|
|
// ChatMessage is one message of an OpenAI-style conversation. It is used both
// for incoming request messages and for the message/delta of a response.
//
// convertMessages treats the role "system" as the system instruction and
// "assistant" as a model turn; any other role is handled as a user turn.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
|
// ChatCompletionRequest is the JSON body accepted by the chat completions
// endpoint.
//
// Stream selects a server-sent-event response instead of a single JSON
// document. MaxTokens, Temperature and TopP are plain values, so a field that
// is absent from the request decodes as zero and cannot be told apart from an
// explicit zero.
type ChatCompletionRequest struct {
	Model       string        `json:"model"`
	Messages    []ChatMessage `json:"messages"`
	Stream      bool          `json:"stream"`
	MaxTokens   int32         `json:"max_tokens,omitempty"`
	Temperature float32       `json:"temperature,omitempty"`
	TopP        float32       `json:"top_p,omitempty"`
}
|
|
// ChatCompletionResponse is the JSON document returned for a non-streaming
// chat completion. handleNonStream fills it with Object "chat.completion", a
// single choice and the token usage of the exchange.
type ChatCompletionResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"`
}
|
|
// Choice is one generated alternative inside a ChatCompletionResponse: its
// position in the list, the assistant message and the reason generation ended.
type Choice struct {
	Index        int         `json:"index"`
	Message      ChatMessage `json:"message"`
	FinishReason string      `json:"finish_reason"`
}
|
|
// Usage reports token counts for one completion: tokens in the prompt, tokens
// in the generated reply, and their total.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
|
// ChatCompletionStreamResponse is one chunk of a streamed chat completion.
// handleStream serialises each chunk as JSON and sends it as the payload of a
// server-sent "data:" event with Object "chat.completion.chunk".
type ChatCompletionStreamResponse struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}
|
|
// StreamChoice is the per-choice payload of a streamed chunk. Delta carries the
// text added by this chunk. FinishReason is a pointer so it can be omitted from
// the JSON of intermediate chunks and set only on the final one.
type StreamChoice struct {
	Index        int         `json:"index"`
	Delta        ChatMessage `json:"delta"`
	FinishReason *string     `json:"finish_reason,omitempty"`
}
|
|
// ModelInfo describes one model in the model listing: its identifier, object
// kind, creation timestamp (Unix seconds), owner and a human-readable
// description.
type ModelInfo struct {
	ID          string `json:"id"`
	Object      string `json:"object"`
	Created     int64  `json:"created"`
	OwnedBy     string `json:"owned_by"`
	Description string `json:"description"`
}
|
|
// ModelListResponse is the JSON envelope returned by modelsHandler: an object
// kind (set to "list" by the handler) and the models themselves.
type ModelListResponse struct {
	Object string      `json:"object"`
	Data   []ModelInfo `json:"data"`
}
|
|
| |
|
|
| func getRandomAPIKey() string { |
| if len(apiKeys) == 0 { |
| log.Fatal("API密钥列表为空,请在 `apiKeys` 变量中配置密钥。") |
| } |
| r := rand.New(rand.NewSource(time.Now().UnixNano())) |
| return apiKeys[r.Intn(len(apiKeys))] |
| } |
|
|
| |
| func convertMessages(messages []ChatMessage) (history []*genai.Content, lastPrompt []genai.Part, systemInstruction *genai.Content) { |
| if len(messages) == 0 { |
| return nil, nil, nil |
| } |
|
|
| for i, msg := range messages { |
| var role string |
| if msg.Role == "system" { |
| systemInstruction = &genai.Content{Parts: []genai.Part{genai.Text(msg.Content)}} |
| continue |
| } |
|
|
| if i == len(messages)-1 && msg.Role == "user" { |
| lastPrompt = append(lastPrompt, genai.Text(msg.Content)) |
| continue |
| } |
|
|
| if msg.Role == "assistant" { |
| role = "model" |
| } else { |
| role = "user" |
| } |
|
|
| history = append(history, &genai.Content{ |
| Role: role, |
| Parts: []genai.Part{genai.Text(msg.Content)}, |
| }) |
| } |
| return history, lastPrompt, systemInstruction |
| } |
|
|
| func chatCompletionsHandler(w http.ResponseWriter, r *http.Request) { |
| if r.Method != http.MethodPost { |
| http.Error(w, "仅支持POST方法", http.StatusMethodNotAllowed) |
| return |
| } |
|
|
| var req ChatCompletionRequest |
| if err := json.NewDecoder(r.Body).Decode(&req); err != nil { |
| http.Error(w, fmt.Sprintf("解析请求体失败: %v", err), http.StatusBadRequest) |
| return |
| } |
|
|
| |
| modelName := req.Model |
| log.Printf("接收到模型请求: '%s',将直接使用该名称。", modelName) |
|
|
|
|
| history, lastPrompt, systemInstruction := convertMessages(req.Messages) |
|
|
| var lastErr error |
| usedKeys := make(map[string]bool) |
|
|
| for i := 0; i < maxRetries; i++ { |
| ctx := context.Background() |
| apiKey := getRandomAPIKey() |
|
|
| if len(usedKeys) < len(apiKeys) { |
| for usedKeys[apiKey] { |
| apiKey = getRandomAPIKey() |
| } |
| } |
| usedKeys[apiKey] = true |
|
|
| log.Printf("尝试第 %d 次, 使用密钥: ...%s", i+1, apiKey[len(apiKey)-4:]) |
|
|
| client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) |
| if err != nil { |
| lastErr = fmt.Errorf("创建客户端失败: %v", err) |
| log.Println(lastErr) |
| continue |
| } |
| defer client.Close() |
|
|
| model := client.GenerativeModel(modelName) |
| model.SystemInstruction = systemInstruction |
| model.SafetySettings = safetySettings |
| model.SetTemperature(req.Temperature) |
| model.SetTopP(req.TopP) |
| if req.MaxTokens > 0 { |
| model.SetMaxOutputTokens(req.MaxTokens) |
| } |
| |
| chat := model.StartChat() |
| chat.History = history |
|
|
| if req.Stream { |
| err = handleStream(w, ctx, chat, lastPrompt, req.Model) |
| } else { |
| err = handleNonStream(w, ctx, model, chat, lastPrompt, req.Model) |
| } |
|
|
| if err == nil { |
| return |
| } |
|
|
| lastErr = err |
| log.Printf("第 %d 次尝试失败: %v", i+1, err) |
| time.Sleep(1 * time.Second) |
| } |
|
|
| http.Error(w, fmt.Sprintf("所有重试均失败: %v", lastErr), http.StatusInternalServerError) |
| } |
|
|
| func handleStream(w http.ResponseWriter, ctx context.Context, chat *genai.ChatSession, prompt []genai.Part, modelID string) error { |
| w.Header().Set("Content-Type", "text/event-stream") |
| w.Header().Set("Cache-Control", "no-cache") |
| w.Header().Set("Connection", "keep-alive") |
|
|
| iter := chat.SendMessageStream(ctx, prompt...) |
| for { |
| resp, err := iter.Next() |
| if err == iterator.Done { |
| break |
| } |
| if err != nil { |
| return fmt.Errorf("流式生成内容失败: %v", err) |
| } |
|
|
| var contentBuilder strings.Builder |
| for _, part := range resp.Candidates[0].Content.Parts { |
| if txt, ok := part.(genai.Text); ok { |
| contentBuilder.WriteString(string(txt)) |
| } |
| } |
| |
| chunk := ChatCompletionStreamResponse{ |
| ID: fmt.Sprintf("chatcmpl-%d", time.Now().Unix()), |
| Object: "chat.completion.chunk", |
| Created: time.Now().Unix(), |
| Model: modelID, |
| Choices: []StreamChoice{ |
| { |
| Index: 0, |
| Delta: ChatMessage{ |
| Role: "assistant", |
| Content: contentBuilder.String(), |
| }, |
| }, |
| }, |
| } |
|
|
| var buf bytes.Buffer |
| if err := json.NewEncoder(&buf).Encode(chunk); err != nil { |
| return fmt.Errorf("序列化流式块失败: %v", err) |
| } |
|
|
| fmt.Fprintf(w, "data: %s\n\n", buf.String()) |
| if flusher, ok := w.(http.Flusher); ok { |
| flusher.Flush() |
| } |
| } |
|
|
| finishReason := "stop" |
| doneChunk := ChatCompletionStreamResponse{ |
| ID: fmt.Sprintf("chatcmpl-%d-done", time.Now().Unix()), |
| Object: "chat.completion.chunk", |
| Created: time.Now().Unix(), |
| Model: modelID, |
| Choices: []StreamChoice{ |
| { |
| Index: 0, |
| FinishReason: &finishReason, |
| }, |
| }, |
| } |
| var buf bytes.Buffer |
| json.NewEncoder(&buf).Encode(doneChunk) |
| fmt.Fprintf(w, "data: %s\n\n", buf.String()) |
| fmt.Fprintf(w, "data: [DONE]\n\n") |
| if flusher, ok := w.(http.Flusher); ok { |
| flusher.Flush() |
| } |
|
|
| return nil |
| } |
|
|
| func handleNonStream(w http.ResponseWriter, ctx context.Context, model *genai.GenerativeModel, chat *genai.ChatSession, prompt []genai.Part, modelID string) error { |
| resp, err := chat.SendMessage(ctx, prompt...) |
| if err != nil { |
| return fmt.Errorf("生成内容失败: %v", err) |
| } |
|
|
| var contentBuilder strings.Builder |
| if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { |
| for _, part := range resp.Candidates[0].Content.Parts { |
| if txt, ok := part.(genai.Text); ok { |
| contentBuilder.WriteString(string(txt)) |
| } |
| } |
| } |
| |
| |
| var promptParts []genai.Part |
| for _, c := range chat.History { |
| promptParts = append(promptParts, c.Parts...) |
| } |
| promptParts = append(promptParts, prompt...) |
|
|
| promptTokenCount, err := model.CountTokens(ctx, promptParts...) |
| if err != nil { |
| return fmt.Errorf("计算prompt tokens失败: %v", err) |
| } |
|
|
| completionTokenCount, err := model.CountTokens(ctx, resp.Candidates[0].Content.Parts...) |
| if err != nil { |
| return fmt.Errorf("计算completion tokens失败: %v", err) |
| } |
| |
| response := ChatCompletionResponse{ |
| ID: fmt.Sprintf("chatcmpl-%d", time.Now().Unix()), |
| Object: "chat.completion", |
| Created: time.Now().Unix(), |
| Model: modelID, |
| Choices: []Choice{ |
| { |
| Index: 0, |
| Message: ChatMessage{ |
| Role: "assistant", |
| Content: contentBuilder.String(), |
| }, |
| FinishReason: "stop", |
| }, |
| }, |
| Usage: Usage{ |
| PromptTokens: int(promptTokenCount.TotalTokens), |
| CompletionTokens: int(completionTokenCount.TotalTokens), |
| TotalTokens: int(promptTokenCount.TotalTokens) + int(completionTokenCount.TotalTokens), |
| }, |
| } |
|
|
| w.Header().Set("Content-Type", "application/json") |
| return json.NewEncoder(w).Encode(response) |
| } |
|
|
|
|
| func modelsHandler(w http.ResponseWriter, r *http.Request) { |
| resp := ModelListResponse{ |
| Object: "list", |
| Data: supportedModels, |
| } |
| w.Header().Set("Content-Type", "application/json") |
| json.NewEncoder(w).Encode(resp) |
| } |
|
|
// rootHandler responds with a small JSON document describing the service: its
// name, version, description and the paths of its endpoints. A failure to
// write the response is logged.
func rootHandler(w http.ResponseWriter, r *http.Request) {
	info := map[string]interface{}{
		"name":        "Gemini Official API (Go Version)",
		"version":     "1.3.0",
		"description": "Google Gemini官方API接口服务",
		"endpoints": map[string]string{
			"models": "/v1/models",
			"chat":   "/v1/chat/completions",
			"health": "/health",
		},
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(info); err != nil {
		log.Printf("写入服务信息响应失败: %v", err)
	}
}
|
|
| func healthHandler(w http.ResponseWriter, r *http.Request) { |
| var modelIDs []string |
| for _, m := range supportedModels { |
| modelIDs = append(modelIDs, m.ID) |
| } |
| health := map[string]interface{}{ |
| "status": "healthy", |
| "timestamp": time.Now().Unix(), |
| "api": "gemini-official-go", |
| "available_models": modelIDs, |
| "version": "1.3.0", |
| } |
| w.Header().Set("Content-Type", "application/json") |
| json.NewEncoder(w).Encode(health) |
| } |
|
|
| func main() { |
| mux := http.NewServeMux() |
|
|
| mux.HandleFunc("/", rootHandler) |
| mux.HandleFunc("/health", healthHandler) |
| mux.HandleFunc("/v1/models", modelsHandler) |
| mux.HandleFunc("/v1/chat/completions", chatCompletionsHandler) |
| mux.HandleFunc("/v1/chat/completions/v1/models", modelsHandler) |
|
|
| c := cors.New(cors.Options{ |
| AllowedOrigins: []string{"*"}, |
| AllowedMethods: []string{"GET", "POST", "OPTIONS"}, |
| AllowedHeaders: []string{"*"}, |
| AllowCredentials: true, |
| }) |
| handler := c.Handler(mux) |
|
|
| port := "7860" |
| log.Println("🚀 启动Gemini官方API服务器 (Go版本)") |
| log.Printf("📊 支持的模型: %v", func() []string { |
| var ids []string |
| for _, m := range supportedModels { |
| ids = append(ids, m.ID) |
| } |
| return ids |
| }()) |
| log.Printf("🔑 已配置 %d 个API密钥", len(apiKeys)) |
| log.Println("🔄 支持自动重试和密钥轮换") |
| log.Printf("🔗 服务器正在监听 http://0.0.0.0:%s", port) |
|
|
| envKey := os.Getenv("GEMINI_API_KEY") |
| if envKey != "" { |
| apiKeys = strings.Split(envKey, ",") |
| log.Printf("从环境变量 GEMINI_API_KEY 加载了 %d 个密钥", len(apiKeys)) |
| } |
|
|
| if err := http.ListenAndServe(":"+port, handler); err != nil { |
| log.Fatalf("启动服务器失败: %v", err) |
| } |
| } |