|
import asyncio
|
|
import json
|
|
from datetime import datetime, timezone
|
|
import os
|
|
import base64
|
|
import tempfile
|
|
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional, Dict, Any, Union
|
|
import time
|
|
import uuid
|
|
import logging
|
|
|
|
from gemini_webapi import GeminiClient, set_log_level
|
|
from gemini_webapi.constants import Model
|
|
|
|
|
|
# Module logger; the root logger is configured at INFO just below.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# gemini_webapi exposes its own log-level switch; keep it aligned with ours.
set_log_level("INFO")

# The ASGI application object that the route decorators below attach to.
app = FastAPI(title="Gemini API FastAPI Server")

# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive CORS policy -- confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# Shared Gemini client; stays None until a request creates it.
gemini_client = None

# Credentials handed to GeminiClient, read once from the environment at import
# time. A missing variable becomes an empty string.
SECURE_1PSID = os.environ.get("SECURE_1PSID", "")
SECURE_1PSIDTS = os.environ.get("SECURE_1PSIDTS", "")

# Startup diagnostics: either confirm both credentials are present or explain
# how to provide them. The server keeps starting in both cases.
if SECURE_1PSID and SECURE_1PSIDTS:
    # NOTE(review): the first five characters of each credential are written to
    # the log -- confirm that is acceptable for the deployment.
    logger.info(f"Credentials found. SECURE_1PSID starts with: {SECURE_1PSID[:5]}...")
    logger.info(f"Credentials found. SECURE_1PSIDTS starts with: {SECURE_1PSIDTS[:5]}...")
else:
    logger.warning("⚠️ Gemini API credentials are not set or empty! Please check your environment variables.")
    logger.warning("Make sure SECURE_1PSID and SECURE_1PSIDTS are correctly set in your .env file or environment.")
    logger.warning("If using Docker, ensure the .env file is correctly mounted and formatted.")
    logger.warning("Example format in .env file (no quotes):")
    logger.warning("SECURE_1PSID=your_secure_1psid_value_here")
    logger.warning("SECURE_1PSIDTS=your_secure_1psidts_value_here")
|
|
|
|
|
|
|
|
class ContentItem(BaseModel):
    """One part of a multi-part message body: a text fragment or an image."""

    # Part discriminator; prepare_conversation acts on "text" and "image_url".
    type: str

    # Text payload, read when ``type == "text"``.
    text: Optional[str] = None

    # Image payload, read when ``type == "image_url"``; its "url" entry is
    # looked up and only ``data:image/...`` URLs are processed.
    image_url: Optional[Dict[str, str]] = None
|
|
|
|
|
|
class Message(BaseModel):
    """A single chat message in the request's conversation history."""

    # Speaker of the message; prepare_conversation recognises "system",
    # "user" and "assistant".
    role: str

    # Either plain text or a list of text/image parts.
    content: Union[str, List[ContentItem]]

    # Optional speaker name; not read by the code visible in this file.
    name: Optional[str] = None
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``.

    Only ``model``, ``messages`` and ``stream`` are read by the handler in
    this file; the remaining fields are accepted but not acted upon there.
    """

    # Requested model name; resolved to a gemini_webapi Model by map_model_name.
    model: str

    # Conversation history, oldest first.
    messages: List[Message]

    # Sampling options (accepted, not forwarded by the visible handler).
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1

    # When true, the reply is delivered as a server-sent event stream.
    stream: Optional[bool] = False

    # Further generation options (accepted, not forwarded by the visible handler).
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0
    frequency_penalty: Optional[float] = 0

    # Caller-supplied end-user identifier (accepted, not used by the visible handler).
    user: Optional[str] = None
|
|
|
|
|
|
class Choice(BaseModel):
    """One generated alternative inside a chat completion response."""

    # Position of this choice within the response's ``choices`` list.
    index: int

    # The generated message.
    message: Message

    # Why generation ended; the handler in this file always reports "stop".
    finish_reason: str
|
|
|
|
|
|
class Usage(BaseModel):
    """Token accounting attached to a chat completion response."""

    # Size of the prompt.
    prompt_tokens: int

    # Size of the generated reply.
    completion_tokens: int

    # Prompt plus completion.
    total_tokens: int
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Shape of a non-streaming chat completion response.

    The handler visible in this file returns an equivalent plain dict rather
    than instantiating this model.
    """

    # Completion identifier.
    id: str

    # Object type tag.
    object: str = "chat.completion"

    # Creation time as an integer timestamp.
    created: int

    # Model name echoed back to the caller.
    model: str

    # Generated alternatives.
    choices: List[Choice]

    # Token accounting for the exchange.
    usage: Usage
|
|
|
|
|
|
class ModelData(BaseModel):
    """Description of a single model entry in a model listing."""

    # Model identifier.
    id: str

    # Object type tag.
    object: str = "model"

    # Creation time as an integer timestamp.
    created: int

    # Owning organisation label.
    owned_by: str = "google"
|
|
|
|
|
|
class ModelList(BaseModel):
    """Envelope for a list of ``ModelData`` entries."""

    # Object type tag.
    object: str = "list"

    # The model entries.
    data: List[ModelData]
|
|
|
|
|
|
|
|
@app.middleware("http")
async def error_handling(request: Request, call_next):
    """Turn any exception that escapes the downstream handlers into a JSON 500.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the rest of the
            application and returns its response.

    Returns:
        The downstream response, or a ``JSONResponse`` with status 500 and an
        ``{"error": {"message": ..., "type": "internal_server_error"}}`` body
        when the downstream call raised.
    """
    try:
        return await call_next(request)
    except Exception as exc:
        # logger.exception keeps the traceback, which the previous plain
        # error-level message discarded.
        logger.exception("Request failed: %s", exc)
        return JSONResponse(
            status_code=500,
            content={"error": {"message": str(exc), "type": "internal_server_error"}},
        )
|
|
|
|
|
|
|
|
@app.get("/v1/models")
async def list_models():
    """Return the models declared by ``gemini_webapi`` as a model list.

    Returns:
        A dict ``{"object": "list", "data": [...]}`` with one entry per member
        of the ``Model`` enum. Every entry carries the current UTC timestamp as
        its ``created`` value.
    """
    now = int(datetime.now(tz=timezone.utc).timestamp())
    data = [
        {
            "id": m.model_name,
            "object": "model",
            "created": now,
            "owned_by": "google-gemini-web",
        }
        for m in Model
    ]
    # Diagnostic output goes through the logger instead of a stray print().
    logger.debug("Model list: %s", data)
    return {"object": "list", "data": data}
|
|
|
|
|
|
|
|
def map_model_name(openai_model_name: str) -> Model:
    """Resolve a requested model name to a member of the ``Model`` enum.

    Matching is case-insensitive and proceeds in order:

    1. a member whose name equals the requested name;
    2. the first member whose name contains the requested name;
    3. the first member whose name contains every keyword associated with the
       requested name (``["pro"]`` when the name has no keyword entry);
    4. the first member of ``Model``.

    Args:
        openai_model_name: Model name supplied by the API caller.

    Returns:
        The selected ``Model`` member.
    """

    def _name_of(member) -> str:
        # Members normally expose ``model_name``; fall back to str() otherwise.
        return member.model_name if hasattr(member, "model_name") else str(member)

    all_models = [_name_of(m) for m in Model]
    logger.info(f"Available models: {all_models}")

    requested = openai_model_name.lower()
    candidates = [(member, name.lower()) for member, name in zip(Model, all_models)]

    # An exact match must win over a substring match; otherwise a request for
    # one model can resolve to an earlier member whose longer name contains it.
    for member, name in candidates:
        if name == requested:
            return member

    for member, name in candidates:
        if requested in name:
            return member

    model_keywords = {
        "gemini-pro": ["pro", "2.0"],
        "gemini-pro-vision": ["vision", "pro"],
        "gemini-flash": ["flash", "2.0"],
        "gemini-1.5-pro": ["1.5", "pro"],
        "gemini-1.5-flash": ["1.5", "flash"],
    }

    # Look up with the lower-cased name so this step is case-insensitive like
    # the two steps above.
    keywords = model_keywords.get(requested, ["pro"])
    for member, name in candidates:
        if all(kw in name for kw in keywords):
            return member

    # Nothing matched: fall back to the first declared model.
    return next(iter(Model))
|
|
|
|
|
|
|
|
def _save_data_url_image(data_url: str) -> str:
    """Decode a ``data:image/...;base64,`` URL into a temp file and return its path.

    Raises:
        ValueError: If the URL has no comma-separated payload or the payload
            is not valid base64.
        OSError: If the temporary file cannot be created or written.
    """
    import mimetypes

    header, encoded = data_url.split(",", 1)
    image_bytes = base64.b64decode(encoded)

    # Derive the extension from the declared MIME type; unknown types keep the
    # historical ".png" suffix.
    mime_type = header[len("data:"):].split(";", 1)[0]
    suffix = mimetypes.guess_extension(mime_type) or ".png"

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp:
            tmp.write(image_bytes)
    except OSError:
        # The file exists on disk but the caller will never learn its name,
        # so remove it here (best effort) before propagating the failure.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
        raise
    return tmp.name


def prepare_conversation(messages: List[Message]) -> tuple:
    """Flatten chat messages into one prompt string and extract inline images.

    Each message is rendered as ``"<Label>: <text>"`` followed by a blank line,
    where the label is ``System``, ``Human`` or ``Assistant`` for the roles
    ``system``, ``user`` and ``assistant``. A plain-string message with any
    other role is skipped (with a warning). A multi-part message with any
    other role is rendered without a label. The prompt always ends with
    ``"Assistant: "``.

    Image parts whose URL is a ``data:image/`` URL are decoded to temporary
    files. Other image URLs are ignored (with a warning), and a failure to
    decode or store an image is logged and skipped.

    Args:
        messages: Conversation history, oldest first.

    Returns:
        A ``(conversation, temp_files)`` tuple: the prompt text and the list of
        temporary image file paths. The caller is responsible for deleting
        those files.
    """
    role_labels = {"system": "System", "user": "Human", "assistant": "Assistant"}
    parts = []
    temp_files = []

    for msg in messages:
        label = role_labels.get(msg.role)

        if isinstance(msg.content, str):
            if label is None:
                logger.warning("Skipping message with unsupported role %r", msg.role)
            else:
                parts.append(f"{label}: {msg.content}\n\n")
            continue

        # Multi-part content: optional label, then the parts, then a blank line.
        if label is not None:
            parts.append(f"{label}: ")

        for item in msg.content:
            if item.type == "text":
                parts.append(item.text or "")
            elif item.type == "image_url" and item.image_url:
                image_url = item.image_url.get("url", "")
                if not image_url.startswith("data:image/"):
                    logger.warning("Ignoring image URL that is not a base64 data:image/ URL")
                    continue
                try:
                    temp_files.append(_save_data_url_image(image_url))
                except (ValueError, OSError) as exc:
                    logger.error(f"Error processing base64 image: {str(exc)}")

        parts.append("\n\n")

    # Cue the model to answer as the assistant.
    parts.append("Assistant: ")

    return "".join(parts), temp_files
|
|
|
|
|
|
|
|
async def get_gemini_client():
    """Return the shared ``GeminiClient``, creating and initialising it on first use.

    Returns:
        The module-level ``gemini_client`` instance.

    Raises:
        HTTPException: With status 500 when the client cannot be constructed
            or its ``init`` call fails.
    """
    global gemini_client
    if gemini_client is None:
        try:
            client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
            await client.init(timeout=300)
        except Exception as e:
            logger.error(f"Failed to initialize Gemini client: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to initialize Gemini client: {str(e)}",
            ) from e
        # Publish the client only after init() succeeded, so a failed attempt
        # is retried on the next call instead of caching a client that never
        # finished initialising.
        gemini_client = client
    return gemini_client
|
|
|
|
|
|
def _sse_chunk(completion_id: str, created: int, model: str, delta: dict, finish_reason=None) -> str:
    """Serialise one ``chat.completion.chunk`` payload as a server-sent event line."""
    payload = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    return f"data: {json.dumps(payload)}\n\n"


def _remove_temp_files(paths) -> None:
    """Delete the given files, logging (not raising) any failure."""
    for path in paths:
        try:
            os.unlink(path)
        except OSError as e:
            logger.warning(f"Failed to delete temp file {path}: {str(e)}")


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Answer a chat completion request by forwarding it to Gemini.

    The messages are flattened into a single prompt, inline images are passed
    along as files, and the reply is returned either as one JSON document or,
    when ``request.stream`` is true, as a server-sent event stream that emits
    the already complete reply one character at a time.

    Args:
        request: The parsed request body.

    Returns:
        A dict shaped like a ``chat.completion`` object, or a
        ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: With status 500 when client initialisation or content
            generation fails.
    """
    try:
        client = await get_gemini_client()

        conversation, temp_files = prepare_conversation(request.messages)
        logger.info(f"Prepared conversation: {conversation}")
        logger.info(f"Temp files: {temp_files}")

        try:
            model = map_model_name(request.model)
            logger.info(f"Using model: {model}")

            logger.info("Sending request to Gemini...")
            if temp_files:
                response = await client.generate_content(conversation, files=temp_files, model=model)
            else:
                response = await client.generate_content(conversation, model=model)
        finally:
            # Always remove the decoded images, including when the model
            # lookup or the Gemini call raised; previously they leaked then.
            _remove_temp_files(temp_files)

        reply_text = response.text if hasattr(response, "text") else str(response)
        logger.info(f"Response: {reply_text}")

        if not reply_text or reply_text.strip() == "":
            logger.warning("Empty response received from Gemini")
            reply_text = "服务器返回了空响应。请检查 Gemini API 凭据是否有效。"

        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created_time = int(time.time())

        if request.stream:

            async def generate_stream():
                # Opening chunk announces the assistant role.
                yield _sse_chunk(completion_id, created_time, request.model, {"role": "assistant"})

                # The reply is already complete; it is replayed one character
                # per chunk with a short pause between chunks.
                for char in reply_text:
                    yield _sse_chunk(completion_id, created_time, request.model, {"content": char})
                    await asyncio.sleep(0.01)

                # Closing chunk carries the finish reason, then the terminator.
                yield _sse_chunk(completion_id, created_time, request.model, {}, "stop")
                yield "data: [DONE]\n\n"

            return StreamingResponse(generate_stream(), media_type="text/event-stream")

        # Token counts are whitespace-separated word counts, not tokenizer output.
        prompt_tokens = len(conversation.split())
        completion_tokens = len(reply_text.split())
        result = {
            "id": completion_id,
            "object": "chat.completion",
            "created": created_time,
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": reply_text},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

        logger.info(f"Returning response: {result}")
        return result

    except HTTPException:
        # Already a well-formed HTTP error (e.g. from client initialisation).
        raise
    except Exception as e:
        logger.error(f"Error generating completion: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating completion: {str(e)}") from e
|
|
|
|
|
|
@app.get("/")
async def root():
    """Liveness endpoint: report that the server is up."""
    status_payload = {
        "status": "online",
        "message": "Gemini API FastAPI Server is running",
    }
    return status_payload
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn

    # Serve the app object of module "main" on all interfaces, port 8000.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        log_level="info",
    )
|
|
|