// Command gemini-relay is a small HTTP relay that accepts OpenAI-style
// chat-completion requests, strips request fields listed in
// unsupportedParams, and forwards the result to Gemini's OpenAI-compatible
// endpoint using the caller's own Authorization header.
package main

import (
	"bytes"
	"encoding/json"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"time"
)

const (
	defaultPort          = "8080"
	geminiOpenAIEndpoint = "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"
)

// unsupportedParams lists OpenAI request fields that filterRequest removes
// before a request is forwarded to Gemini.
var unsupportedParams = []string{
	"frequency_penalty",
	"presence_penalty",
	"logit_bias",
	"user",
	"n",
	"stop",
	"suffix",
	"logprobs",
	"echo",
	"best_of",
	"response_format",
	"seed",
	"tools",
	"tool_choice",
	"parallel_tool_calls",
}

// hopByHopHeaders are connection-scoped headers (RFC 7230 §6.1) that a proxy
// must not forward to the next hop. Keys are in canonical MIME form, which is
// how net/http stores header names.
var hopByHopHeaders = map[string]bool{
	"Connection":          true,
	"Keep-Alive":          true,
	"Proxy-Authenticate":  true,
	"Proxy-Authorization": true,
	"Te":                  true,
	"Trailer":             true,
	"Transfer-Encoding":   true,
	"Upgrade":             true,
}

// RelayServer forwards chat-completion requests to Gemini.
type RelayServer struct {
	client *http.Client
}

// NewRelayServer returns a RelayServer backed by a plain http.Client.
// No overall Client.Timeout is set on purpose: it would also cap the total
// duration of streamed (SSE) responses. Instead, each upstream request is
// bound to the inbound request's context in handleRequest.
func NewRelayServer() *RelayServer {
	return &RelayServer{
		client: &http.Client{},
	}
}

// writeJSONError writes an OpenAI-style error envelope with the given HTTP
// status. message and errType are spliced in verbatim, so callers must pass
// strings that need no JSON escaping (all call sites use constants).
func writeJSONError(w http.ResponseWriter, status int, message, errType string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	w.Write([]byte(`{"error": {"message": "` + message + `", "type": "` + errType + `"}}`))
}

// filterRequest removes the fields listed in unsupportedParams from a JSON
// object request body and returns the re-encoded body.
//
// Each value is kept as json.RawMessage, so retained fields are re-emitted
// as they arrived. A float64 round trip would corrupt integers above 2^53
// and reorder nested object keys. Top-level keys are emitted in sorted
// order because json.Marshal sorts map keys.
//
// If body is not a JSON object (invalid JSON, an array, ...), it is returned
// unchanged with a nil error so the caller can forward it as-is. A non-nil
// error is only possible from re-encoding.
func filterRequest(body []byte) ([]byte, error) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return body, nil
	}

	for _, param := range unsupportedParams {
		delete(fields, param)
	}

	return json.Marshal(fields)
}

// copyRequestHeaders copies end-to-end headers from src into dst. It skips
// Host, Content-Length (recomputed for the filtered body by the transport)
// and hop-by-hop headers.
func copyRequestHeaders(dst, src http.Header) {
	for name, values := range src {
		if name == "Host" || name == "Content-Length" || hopByHopHeaders[name] {
			continue
		}
		for _, value := range values {
			dst.Add(name, value)
		}
	}
}

// redactedHeaders returns a copy of h with credential-bearing values masked,
// so API keys never end up in the log.
func redactedHeaders(h http.Header) http.Header {
	out := h.Clone()
	for _, name := range []string{"Authorization", "X-Goog-Api-Key", "Cookie"} {
		if _, ok := out[name]; ok {
			out[name] = []string{"[REDACTED]"}
		}
	}
	return out
}

// streamBody copies an event-stream body to w, flushing after every chunk so
// events reach the client as soon as Gemini emits them. It returns on EOF,
// on a read error (logged), or when writing to the client fails (logged).
func streamBody(w http.ResponseWriter, body io.Reader) {
	flusher, canFlush := w.(http.Flusher)
	buf := make([]byte, 1024)
	for {
		n, err := body.Read(buf)
		if n > 0 {
			if _, writeErr := w.Write(buf[:n]); writeErr != nil {
				log.Printf("Error writing response: %v", writeErr)
				return
			}
			if canFlush {
				flusher.Flush()
			}
		}
		if err != nil {
			if err != io.EOF {
				log.Printf("Error reading response: %v", err)
			}
			return
		}
	}
}

// handleRequest relays a chat-completion request to Gemini and copies the
// response (status, headers, body) back to the client.
//
// It responds with:
//   - 405 if the method is not POST;
//   - 401 if the Authorization header is missing;
//   - 400 if the request body cannot be read;
//   - 500 if the upstream request cannot be built;
//   - 502 if Gemini cannot be reached.
//
// Event-stream responses are flushed chunk by chunk. Other responses are
// copied in one pass, and upstream error bodies (status >= 400) are also
// logged.
func (s *RelayServer) handleRequest(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.Header().Set("Allow", "POST, OPTIONS")
		writeJSONError(w, http.StatusMethodNotAllowed, "Method not allowed", "invalid_request_error")
		return
	}

	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		writeJSONError(w, http.StatusUnauthorized, "Missing Authorization header", "invalid_request_error")
		return
	}

	bodyBytes, err := io.ReadAll(r.Body)
	if err != nil {
		writeJSONError(w, http.StatusBadRequest, "Failed to read request body", "invalid_request_error")
		return
	}

	// Debug logging: request bodies contain user prompts, not credentials.
	log.Printf("Original request body: %s", bodyBytes)

	filteredBody, err := filterRequest(bodyBytes)
	if err != nil {
		log.Printf("Failed to filter request: %v", err)
		filteredBody = bodyBytes // fall back to forwarding the original body
	}
	log.Printf("Filtered request body: %s", filteredBody)

	// Bind the upstream call to the client's context so a client disconnect
	// cancels the Gemini request, including an in-flight stream.
	proxyReq, err := http.NewRequestWithContext(r.Context(), http.MethodPost, geminiOpenAIEndpoint, bytes.NewReader(filteredBody))
	if err != nil {
		writeJSONError(w, http.StatusInternalServerError, "Failed to create proxy request", "server_error")
		return
	}

	copyRequestHeaders(proxyReq.Header, r.Header)
	proxyReq.Header.Set("Authorization", authHeader)
	proxyReq.Header.Set("Content-Type", "application/json")

	log.Printf("Request headers being sent to Gemini: %v", redactedHeaders(proxyReq.Header))

	resp, err := s.client.Do(proxyReq)
	if err != nil {
		log.Printf("Failed to send request to Gemini: %v", err)
		writeJSONError(w, http.StatusBadGateway, "Failed to connect to Gemini API", "server_error")
		return
	}
	defer resp.Body.Close()

	log.Printf("Response status from Gemini: %d", resp.StatusCode)

	for name, values := range resp.Header {
		// The body may be re-framed on the way out, so Content-Length and
		// connection-scoped headers are not copied.
		if name == "Content-Length" || hopByHopHeaders[name] {
			continue
		}
		for _, value := range values {
			w.Header().Add(name, value)
		}
	}
	w.Header().Set("Access-Control-Allow-Origin", "*")

	streaming := strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
	if streaming {
		// Must be set before WriteHeader: header changes made after the
		// status line has been written are silently ignored.
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")
	}

	w.WriteHeader(resp.StatusCode)

	switch {
	case streaming:
		streamBody(w, resp.Body)
	case resp.StatusCode >= 400:
		errBody, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			log.Printf("Error reading response: %v", readErr)
		}
		log.Printf("Error response from Gemini: %s", errBody)
		if _, writeErr := w.Write(errBody); writeErr != nil {
			log.Printf("Error writing response: %v", writeErr)
		}
	default:
		if _, copyErr := io.Copy(w, resp.Body); copyErr != nil {
			log.Printf("Error copying response: %v", copyErr)
		}
	}
}

// handleHealth is the health-check endpoint; it always reports 200 OK.
func (s *RelayServer) handleHealth(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	w.Write([]byte(`{"status": "ok", "service": "gemini-relay"}`))
}

// corsMiddleware adds permissive CORS headers to every response. It answers
// OPTIONS preflight requests itself with 200 and does not call next for them.
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
		w.Header().Set("Access-Control-Max-Age", "86400")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}

		next(w, r)
	}
}

// loggingMiddleware logs the remote address, method, path and user agent of
// each request before calling next.
func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		log.Printf("[%s] %s %s %s", r.RemoteAddr, r.Method, r.URL.Path, r.UserAgent())
		next(w, r)
	}
}

// handleModels returns a static, OpenAI-style model list. An Authorization
// header must be present (401 otherwise), but its value is not validated
// here.
func (s *RelayServer) handleModels(w http.ResponseWriter, r *http.Request) {
	if r.Header.Get("Authorization") == "" {
		writeJSONError(w, http.StatusUnauthorized, "Missing Authorization header", "invalid_request_error")
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	models := `{ "object": "list", "data": [ { "id": "gemini-1.5-pro", "object": "model", "created": 1686935002, "owned_by": "google" }, { "id": "gemini-1.5-flash", "object": "model", "created": 1686935002, "owned_by": "google" }, { "id": "gemini-1.5-flash-8b", "object": "model", "created": 1686935002, "owned_by": "google" }, { "id": "gemini-2.0-flash-exp", "object": "model", "created": 1686935002, "owned_by": "google" } ] }`
	w.Write([]byte(models))
}

// main wires up the routes and starts the relay. The listen port comes from
// the PORT environment variable, falling back to defaultPort.
func main() {
	server := NewRelayServer()

	mux := http.NewServeMux()

	// OpenAI-compatible endpoints.
	mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(server.handleRequest)))
	mux.HandleFunc("/chat/completions", corsMiddleware(loggingMiddleware(server.handleRequest)))
	mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(server.handleModels)))
	mux.HandleFunc("/models", corsMiddleware(loggingMiddleware(server.handleModels)))

	// Health check.
	mux.HandleFunc("/health", corsMiddleware(server.handleHealth))

	// The "/" pattern matches every unregistered path, so only the exact root
	// gets the service description; anything else is a 404.
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/" {
			writeJSONError(w, http.StatusNotFound, "Not found", "invalid_request_error")
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{ "service": "Gemini API Relay", "version": "1.0.0", "endpoints": { "chat": "/v1/chat/completions", "models": "/v1/models", "health": "/health" }, "supported_models": [ "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b", "gemini-2.0-flash-exp" ], "note": "Use Authorization header with 'Bearer YOUR_GEMINI_API_KEY'" }`))
	})

	port := os.Getenv("PORT")
	if port == "" {
		port = defaultPort
	}

	log.Printf("========================================")
	log.Printf("Gemini API Relay Server")
	log.Printf("Port: %s", port)
	log.Printf("Endpoint: %s", geminiOpenAIEndpoint)
	log.Printf("========================================")
	log.Printf("Usage:")
	log.Printf("  Authorization: Bearer YOUR_GEMINI_API_KEY")
	log.Printf("========================================")

	httpServer := &http.Server{
		Addr:    ":" + port,
		Handler: mux,
		// Bound header reads to defend against slow-header (Slowloris)
		// clients. No WriteTimeout: it would cut off long streamed responses.
		ReadHeaderTimeout: 10 * time.Second,
	}
	if err := httpServer.ListenAndServe(); err != nil {
		log.Fatalf("Server failed to start: %v", err)
	}
}