Spaces:
Sleeping
Sleeping
package main | |
import (
	"bytes"
	"encoding/json"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"time"
)
// defaultPort is the default TCP port for the relay's HTTP listener.
const defaultPort = "8080"

// geminiOpenAIEndpoint is Google's OpenAI-compatible chat completions URL;
// every proxied chat request is forwarded here.
const geminiOpenAIEndpoint = "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"
// RelayServer 中继服务器 | |
type RelayServer struct { | |
client *http.Client | |
} | |
// NewRelayServer 创建新的中继服务器 | |
func NewRelayServer() *RelayServer { | |
return &RelayServer{ | |
client: &http.Client{}, | |
} | |
} | |
// filterRequest strips OpenAI request parameters that this relay treats as
// unsupported by Gemini, and returns the re-encoded JSON object.
//
// Top-level values are held as json.RawMessage, so fields that are kept are
// passed through verbatim (only compacted). In particular, numbers do not
// round-trip through float64, which would silently corrupt integers above
// 2^53. Nested objects also keep their original key order.
//
// If body does not decode into a JSON object (malformed JSON, an array, a
// scalar), it is returned unchanged with a nil error so the caller can
// forward it as-is. A non-nil error is returned only if re-encoding fails.
func filterRequest(body []byte) ([]byte, error) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		// Not a JSON object: hand back the original bytes untouched.
		return body, nil
	}

	// OpenAI parameters removed before the request is forwarded to Gemini.
	unsupportedParams := [...]string{
		"frequency_penalty",
		"presence_penalty",
		"logit_bias",
		"user",
		"n",
		"stop",
		"suffix",
		"logprobs",
		"echo",
		"best_of",
		"response_format",
		"seed",
		"tools",
		"tool_choice",
		"parallel_tool_calls",
	}
	for _, param := range unsupportedParams {
		delete(fields, param)
	}

	return json.Marshal(fields)
}
// writeRelayError writes a JSON error body with the given HTTP status.
func writeRelayError(w http.ResponseWriter, status int, body string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	w.Write([]byte(body))
}

// isHopByHopHeader reports whether name is a connection-scoped header that a
// proxy must not forward between the client and upstream connections.
func isHopByHopHeader(name string) bool {
	switch http.CanonicalHeaderKey(name) {
	case "Connection", "Proxy-Connection", "Keep-Alive", "Proxy-Authenticate",
		"Proxy-Authorization", "Te", "Trailer", "Transfer-Encoding", "Upgrade":
		return true
	}
	return false
}

// headersForLog returns a copy of h that is safe to log: credential-bearing
// headers are masked so API keys never reach the log output.
func headersForLog(h http.Header) http.Header {
	masked := h.Clone()
	for _, name := range []string{"Authorization", "Cookie", "X-Goog-Api-Key"} {
		if masked.Get(name) != "" {
			masked.Set(name, "[REDACTED]")
		}
	}
	return masked
}

// streamBody copies src to w and flushes after every chunk, so server-sent
// events reach the client as soon as upstream emits them. Read errors other
// than io.EOF and write errors are logged and end the copy.
func streamBody(w http.ResponseWriter, src io.Reader) {
	flusher, _ := w.(http.Flusher)
	buf := make([]byte, 1024)
	for {
		n, err := src.Read(buf)
		if n > 0 {
			if _, writeErr := w.Write(buf[:n]); writeErr != nil {
				log.Printf("Error writing response: %v", writeErr)
				return
			}
			if flusher != nil {
				flusher.Flush()
			}
		}
		if err != nil {
			if err != io.EOF {
				log.Printf("Error reading response: %v", err)
			}
			return
		}
	}
}

// handleRequest relays an OpenAI-style chat completions request to the
// Gemini OpenAI-compatible endpoint and pipes the response back.
//
// Flow:
//   - Only POST is accepted (405 otherwise).
//   - An Authorization header is required (401 otherwise).
//   - The body is read and stripped via filterRequest.
//   - It is forwarded as a POST to geminiOpenAIEndpoint with the client's
//     headers, minus Host, Content-Length and hop-by-hop headers.
//   - The upstream status, headers and body are then copied back.
//     Event-stream responses are flushed chunk by chunk.
//
// The upstream request is bound to r.Context(), so it is cancelled when the
// client goes away instead of running to completion unobserved.
func (s *RelayServer) handleRequest(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.Header().Set("Allow", http.MethodPost)
		writeRelayError(w, http.StatusMethodNotAllowed, `{"error": {"message": "Method not allowed", "type": "invalid_request_error"}}`)
		return
	}

	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		writeRelayError(w, http.StatusUnauthorized, `{"error": {"message": "Missing Authorization header", "type": "invalid_request_error"}}`)
		return
	}

	// net/http closes r.Body itself once the handler returns.
	bodyBytes, err := io.ReadAll(r.Body)
	if err != nil {
		writeRelayError(w, http.StatusBadRequest, `{"error": {"message": "Failed to read request body", "type": "invalid_request_error"}}`)
		return
	}

	// Debug logging. NOTE(review): this records full prompts; consider
	// disabling it in production.
	log.Printf("Original request body: %s", string(bodyBytes))

	filteredBody, err := filterRequest(bodyBytes)
	if err != nil {
		log.Printf("Failed to filter request: %v", err)
		filteredBody = bodyBytes // fall back to the unmodified body
	}
	log.Printf("Filtered request body: %s", string(filteredBody))

	proxyReq, err := http.NewRequestWithContext(r.Context(), http.MethodPost, geminiOpenAIEndpoint, bytes.NewReader(filteredBody))
	if err != nil {
		writeRelayError(w, http.StatusInternalServerError, `{"error": {"message": "Failed to create proxy request", "type": "server_error"}}`)
		return
	}

	// Host and Content-Length are derived from proxyReq by the transport;
	// hop-by-hop headers only describe the client<->relay connection.
	for name, values := range r.Header {
		if name == "Host" || name == "Content-Length" || isHopByHopHeader(name) {
			continue
		}
		for _, value := range values {
			proxyReq.Header.Add(name, value)
		}
	}
	proxyReq.Header.Set("Authorization", authHeader)
	proxyReq.Header.Set("Content-Type", "application/json")
	// Never log credentials: the Authorization header carries the API key.
	log.Printf("Request headers being sent to Gemini: %v", headersForLog(proxyReq.Header))

	resp, err := s.client.Do(proxyReq)
	if err != nil {
		log.Printf("Failed to send request to Gemini: %v", err)
		writeRelayError(w, http.StatusBadGateway, `{"error": {"message": "Failed to connect to Gemini API", "type": "server_error"}}`)
		return
	}
	defer resp.Body.Close()

	log.Printf("Response status from Gemini: %d", resp.StatusCode)

	// Content-Length is dropped so the server frames the relayed body itself;
	// hop-by-hop headers describe the relay<->Gemini connection only.
	for name, values := range resp.Header {
		if name == "Content-Length" || isHopByHopHeader(name) {
			continue
		}
		for _, value := range values {
			w.Header().Add(name, value)
		}
	}
	w.Header().Set("Access-Control-Allow-Origin", "*")

	isStream := strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
	if isStream {
		// Headers must be set BEFORE WriteHeader; later changes are ignored.
		// Connection is hop-by-hop and managed by net/http, so it is not set.
		w.Header().Set("Cache-Control", "no-cache")
	}
	w.WriteHeader(resp.StatusCode)

	switch {
	case isStream:
		streamBody(w, resp.Body)
	case resp.StatusCode >= 400:
		// Buffer error bodies so they can be logged as well as relayed.
		errBody, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			log.Printf("Error reading error response: %v", readErr)
		}
		log.Printf("Error response from Gemini: %s", string(errBody))
		w.Write(errBody)
	default:
		if _, copyErr := io.Copy(w, resp.Body); copyErr != nil {
			log.Printf("Error copying response: %v", copyErr)
		}
	}
}
// handleHealth 健康检查端点 | |
func (s *RelayServer) handleHealth(w http.ResponseWriter, r *http.Request) { | |
w.Header().Set("Content-Type", "application/json") | |
w.WriteHeader(http.StatusOK) | |
w.Write([]byte(`{"status": "ok", "service": "gemini-relay"}`)) | |
} | |
// corsMiddleware CORS中间件 | |
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc { | |
return func(w http.ResponseWriter, r *http.Request) { | |
w.Header().Set("Access-Control-Allow-Origin", "*") | |
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") | |
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") | |
w.Header().Set("Access-Control-Max-Age", "86400") | |
if r.Method == "OPTIONS" { | |
w.WriteHeader(http.StatusOK) | |
return | |
} | |
next(w, r) | |
} | |
} | |
// loggingMiddleware writes one log line per request with the remote address,
// method, URL path and User-Agent. It then hands the request to next
// unchanged.
func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		remote, method, path, agent := r.RemoteAddr, r.Method, r.URL.Path, r.UserAgent()
		log.Printf("[%s] %s %s %s", remote, method, path, agent)
		next.ServeHTTP(w, r)
	})
}
// 模型映射 | |
func (s *RelayServer) handleModels(w http.ResponseWriter, r *http.Request) { | |
// 检查Authorization | |
if r.Header.Get("Authorization") == "" { | |
w.Header().Set("Content-Type", "application/json") | |
w.WriteHeader(http.StatusUnauthorized) | |
w.Write([]byte(`{"error": {"message": "Missing Authorization header", "type": "invalid_request_error"}}`)) | |
return | |
} | |
w.Header().Set("Content-Type", "application/json") | |
w.WriteHeader(http.StatusOK) | |
models := `{ | |
"object": "list", | |
"data": [ | |
{ | |
"id": "gemini-1.5-pro", | |
"object": "model", | |
"created": 1686935002, | |
"owned_by": "google" | |
}, | |
{ | |
"id": "gemini-1.5-flash", | |
"object": "model", | |
"created": 1686935002, | |
"owned_by": "google" | |
}, | |
{ | |
"id": "gemini-1.5-flash-8b", | |
"object": "model", | |
"created": 1686935002, | |
"owned_by": "google" | |
}, | |
{ | |
"id": "gemini-2.0-flash-exp", | |
"object": "model", | |
"created": 1686935002, | |
"owned_by": "google" | |
} | |
] | |
}` | |
w.Write([]byte(models)) | |
} | |
// main wires up the HTTP routes and starts the relay server.
//
// The listen port is taken from the PORT environment variable when it is set
// (for hosting platforms that assign the port that way), otherwise
// defaultPort is used. The server bounds header reads and idle keep-alive
// time. It deliberately sets no WriteTimeout, because streamed completions
// can legitimately stay open for a long time.
func main() {
	server := NewRelayServer()

	mux := http.NewServeMux()
	// OpenAI-compatible endpoints, with and without the /v1 prefix.
	mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(server.handleRequest)))
	mux.HandleFunc("/chat/completions", corsMiddleware(loggingMiddleware(server.handleRequest)))
	mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(server.handleModels)))
	mux.HandleFunc("/models", corsMiddleware(loggingMiddleware(server.handleModels)))
	mux.HandleFunc("/health", corsMiddleware(server.handleHealth))
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// "/" is ServeMux's catch-all pattern. Only the exact root serves the
		// service description; any other unknown path is reported as 404
		// instead of a misleading 200.
		if r.URL.Path != "/" {
			w.WriteHeader(http.StatusNotFound)
			w.Write([]byte(`{"error": {"message": "Not found", "type": "invalid_request_error"}}`))
			return
		}
		w.Write([]byte(`{
"service": "Gemini API Relay",
"version": "1.0.0",
"endpoints": {
"chat": "/v1/chat/completions",
"models": "/v1/models",
"health": "/health"
},
"supported_models": [
"gemini-1.5-pro",
"gemini-1.5-flash",
"gemini-1.5-flash-8b",
"gemini-2.0-flash-exp"
],
"note": "Use Authorization header with 'Bearer YOUR_GEMINI_API_KEY'"
}`))
	})

	port := os.Getenv("PORT")
	if port == "" {
		port = defaultPort
	}

	log.Printf("========================================")
	log.Printf("Gemini API Relay Server")
	log.Printf("Port: %s", port)
	log.Printf("Endpoint: %s", geminiOpenAIEndpoint)
	log.Printf("========================================")
	log.Printf("Usage:")
	log.Printf("  Authorization: Bearer YOUR_GEMINI_API_KEY")
	log.Printf("========================================")

	srv := &http.Server{
		Addr:    ":" + port,
		Handler: mux,
		// Guards against clients that trickle request headers (Slowloris).
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("Server failed to start: %v", err)
	}
}