Spaces:
Sleeping
Sleeping
| """ | |
| SambaNova OpenAI 接口代理 (支持模型列表透传和自动登录) | |
| """ | |
| import os | |
| import uuid | |
| import json | |
| import time | |
| import asyncio | |
| import httpx | |
| import secrets | |
| import urllib.parse | |
| from typing import Optional, Dict, Any | |
| from fastapi import FastAPI, Request, HTTPException, Depends, Header | |
| from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from fake_useragent import UserAgent | |
| # 修复 Pydantic 导入 | |
| try: | |
| # 尝试从 pydantic-settings 导入 (Pydantic v2) | |
| from pydantic_settings import BaseSettings | |
| except ImportError: | |
| # 回退到旧版本 (Pydantic v1) | |
| from pydantic import BaseSettings | |
# ================ Configuration ================
class Settings(BaseSettings):
    """Runtime configuration for the proxy.

    Each field's default is read from the process environment via
    ``os.getenv`` when this class body is executed; the nested ``Config``
    additionally names a ``.env`` file for the settings base class to read.
    """
    # SambaNova account credentials and upstream endpoints
    SAMBA_EMAIL: str = os.getenv("SAMBA_EMAIL", "")
    SAMBA_PASSWORD: str = os.getenv("SAMBA_PASSWORD", "")
    SAMBA_COMPLETION_URL: str = os.getenv("SAMBA_COMPLETION_URL", "https://cloud.sambanova.ai/api/completion")
    SAMBA_MODELS_URL: str = os.getenv("SAMBA_MODELS_URL", "https://api.sambanova.ai/v1/models")
    # Local API key that clients must present. When LOCAL_API_KEY is not set
    # in the environment, the default is a freshly generated random value.
    LOCAL_API_KEY: str = os.getenv("LOCAL_API_KEY", secrets.token_urlsafe(32))
    # Miscellaneous settings
    TOKEN_CACHE_TIME: int = int(os.getenv("TOKEN_CACHE_TIME", 604800))  # cache for 7 days by default (7*24*60*60 = 604800 seconds)
    FINGERPRINT_PREFIX: str = os.getenv("FINGERPRINT_PREFIX", "anon_")

    class Config:
        env_file = ".env"
# Single shared settings instance used by every handler below.
settings = Settings()
# =====================================
app = FastAPI(title="SambaNova OpenAI Proxy with Auto-Login")
# Bearer-token extractor used as a dependency by validate_api_key.
security = HTTPBearer()
# Module-level cache of the SambaNova access token and the time.time()
# value after which it is treated as expired.
access_token = None
token_expiry = 0
# Held by get_samba_token while it checks and refreshes the cached token.
token_lock = asyncio.Lock()
def generate_fingerprint() -> str:
    """Return a random fingerprint: the configured prefix plus 20 hex characters."""
    random_suffix = uuid.uuid4().hex[:20]
    prefix = settings.FINGERPRINT_PREFIX
    return f"{prefix}{random_suffix}"
async def validate_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Check the caller's bearer key and return a SambaNova access token.

    Args:
        credentials: Bearer credentials extracted by the ``security`` dependency.

    Returns:
        The SambaNova access token obtained from ``get_samba_token``.

    Raises:
        HTTPException: 401 when the presented key does not match
            ``settings.LOCAL_API_KEY``; 500 when no SambaNova token could be
            obtained.
    """
    api_key = credentials.credentials
    expected_key = settings.LOCAL_API_KEY
    # Validation is skipped only when the configured key is empty/whitespace.
    if expected_key and expected_key.strip():
        # Constant-time comparison so response timing does not reveal how many
        # leading characters of the key were guessed correctly. Comparing bytes
        # also avoids compare_digest's TypeError on non-ASCII str input.
        keys_match = secrets.compare_digest(
            api_key.encode("utf-8"),
            expected_key.encode("utf-8"),
        )
        if not keys_match:
            raise HTTPException(
                status_code=401,
                detail="Invalid API key",
                headers={"WWW-Authenticate": "Bearer"},
            )
    else:
        print("[警告] LOCAL_API_KEY未配置或为空,跳过API密钥验证")
    # Fetch the cached SambaNova token, refreshing it if necessary.
    token = await get_samba_token()
    if not token:
        raise HTTPException(
            status_code=500,
            detail="Failed to obtain SambaNova access token. Check server logs for details."
        )
    return token
async def get_samba_token() -> Optional[str]:
    """Return the cached SambaNova access token, logging in again if it expired.

    The whole check-and-refresh runs under ``token_lock`` so that concurrent
    callers do not each start their own login.

    Returns:
        The access token, or ``None`` when credentials are not configured or
        the login attempt failed (details are printed to the server log).
    """
    global access_token, token_expiry
    async with token_lock:
        current_time = time.time()
        # Fast path: the cached token is still within its cache window.
        if access_token and current_time < token_expiry:
            # Only a short prefix is logged; the full token is a credential.
            print(f"[令牌] 使用缓存令牌: {access_token[:10]}...")
            return access_token
        # A login cannot be attempted without credentials.
        if not settings.SAMBA_EMAIL or not settings.SAMBA_PASSWORD:
            print("[错误] 未配置SambaNova凭据,请设置SAMBA_EMAIL和SAMBA_PASSWORD环境变量")
            return None
        try:
            print(f"[令牌] 开始获取新令牌... 邮箱: {settings.SAMBA_EMAIL}")
            auth = SambaAuthAsync(settings.SAMBA_EMAIL, settings.SAMBA_PASSWORD)
            new_token = await auth.login()
        except Exception as e:
            # Best-effort: any failure here is reported to the caller as None.
            print(f"[令牌获取异常] {str(e)}")
            return None
        if not new_token:
            print("[令牌获取失败] 请检查SambaNova凭据是否正确")
            return None
        access_token = new_token
        token_expiry = current_time + settings.TOKEN_CACHE_TIME
        print(f"[令牌更新成功] 令牌: {new_token[:10]}...")
        print(f"[令牌更新成功] 令牌将在 {settings.TOKEN_CACHE_TIME} 秒后过期")
        return access_token
def reset_token_expiry() -> None:
    """Zero the cached token's expiry so the next token lookup logs in again."""
    global token_expiry
    # get_samba_token treats the cache as valid only while time.time() < token_expiry.
    token_expiry = 0
    print("[令牌] 令牌已过期,将在下次请求时重新获取")
async def forward_get_request(url: str, token: str) -> httpx.Response:
    """Send a GET request upstream, authenticating with the access-token cookie.

    Args:
        url: Upstream URL to fetch.
        token: SambaNova access token, sent as the ``access_token`` cookie.

    Returns:
        The upstream response when its status passes ``raise_for_status()``.

    Raises:
        HTTPException: 401 (after invalidating the cached token) when upstream
            answers 401; otherwise the upstream status code with the upstream
            body in the detail when ``raise_for_status()`` fails.
        httpx.RequestError: Transport-level failures are not caught here.
    """
    headers = {
        "accept": "application/json",
        "user-agent": "SambaNova-Proxy/1.0",
        "origin": "https://cloud.sambanova.ai",
        "referer": "https://cloud.sambanova.ai/"
    }
    cookies = {
        "access_token": token
    }
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            url,
            headers=headers,
            cookies=cookies,
            timeout=10.0
        )
        # A 401 means the cached token is no longer accepted: drop it so the
        # next request performs a fresh login.
        if resp.status_code == 401:
            reset_token_expiry()
            raise HTTPException(401, "Token expired, please retry")
        try:
            resp.raise_for_status()
        except httpx.HTTPStatusError as e:
            # 401 was already handled above, so only other statuses reach here.
            raise HTTPException(e.response.status_code, f"Upstream error: {e.response.text}") from e
        return resp
async def forward_post_request(url: str, payload: dict, token: str) -> httpx.Response:
    """Send a JSON POST request upstream, authenticating with the access-token cookie.

    Args:
        url: Upstream URL to post to.
        payload: Object serialized as the JSON request body.
        token: SambaNova access token, sent as the ``access_token`` cookie.

    Returns:
        The upstream response when its status passes ``raise_for_status()``.

    Raises:
        HTTPException: 401 (after invalidating the cached token) when upstream
            answers 401; otherwise the upstream status code with the upstream
            body in the detail when ``raise_for_status()`` fails.
        httpx.RequestError: Transport-level failures are not caught here.
    """
    headers = {
        "content-type": "application/json",
        "user-agent": "SambaNova-Proxy/1.0",
        "origin": "https://cloud.sambanova.ai",
        "referer": "https://cloud.sambanova.ai/"
    }
    cookies = {
        "access_token": token
    }
    async with httpx.AsyncClient() as client:
        # NOTE(review): client.post() is not a streaming call, so the whole
        # upstream body is presumably buffered (within the 30 s timeout) before
        # this returns — confirm this is acceptable for long completions.
        resp = await client.post(
            url,
            json=payload,
            headers=headers,
            cookies=cookies,
            timeout=30.0
        )
        # A 401 means the cached token is no longer accepted: drop it so the
        # next request performs a fresh login.
        if resp.status_code == 401:
            reset_token_expiry()
            raise HTTPException(401, "Token expired, please retry")
        try:
            resp.raise_for_status()
        except httpx.HTTPStatusError as e:
            # 401 was already handled above, so only other statuses reach here.
            raise HTTPException(e.response.status_code, f"Upstream error: {e.response.text}") from e
        return resp
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered with `app`.
async def list_models(token: str = Depends(validate_api_key)):
    """Pass the upstream model list through to the caller.

    Args:
        token: SambaNova access token supplied by ``validate_api_key``.

    Returns:
        A JSON response carrying the upstream payload with a 5-minute
        ``Cache-Control`` header.

    Raises:
        HTTPException: Propagated unchanged from ``forward_get_request``
            (e.g. 401 on an expired token); 504 on transport errors; 500 on
            any other failure.
    """
    try:
        resp = await forward_get_request(settings.SAMBA_MODELS_URL, token)
        content = resp.json()
        # JSONResponse serializes the body and sets Content-Type/Content-Length
        # itself, so they are not computed by hand here.
        return JSONResponse(
            content=content,
            headers={
                "Cache-Control": "public, max-age=300"
            }
        )
    except HTTPException:
        # Keep the upstream-derived status instead of masking it as a 500.
        raise
    except httpx.RequestError as e:
        raise HTTPException(504, f"Gateway timeout: {str(e)}") from e
    except Exception as e:
        raise HTTPException(500, f"Internal server error: {str(e)}") from e
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered with `app`.
async def chat_completions(
    request: Request,
    token: str = Depends(validate_api_key)
):
    """Translate an OpenAI-style chat request and forward it to SambaNova.

    Args:
        request: Incoming request whose JSON body must be an object with a
            ``messages`` field; ``model``, ``stop``, ``temperature`` and
            ``max_tokens`` are optional.
        token: SambaNova access token supplied by ``validate_api_key``.

    Returns:
        A ``text/event-stream`` response relaying the upstream body.

    Raises:
        HTTPException: 400 for a malformed body; statuses propagated from
            ``forward_post_request``; 504 on transport errors; 500 otherwise.
    """
    try:
        try:
            openai_payload = await request.json()
        except ValueError as e:
            # JSON decode errors are ValueError subclasses: a client error.
            raise HTTPException(400, f"Invalid JSON body: {str(e)}") from e
        if not isinstance(openai_payload, dict) or "messages" not in openai_payload:
            raise HTTPException(400, "Request body must be a JSON object containing 'messages'")
        # An explicit null temperature is treated like an absent one.
        temperature = openai_payload.get("temperature")
        if temperature is None:
            temperature = 0
        elif not isinstance(temperature, (int, float)):
            raise HTTPException(400, "'temperature' must be a number")
        model = openai_payload.get("model", "DeepSeek-R1")
        print(f"[请求] 收到聊天请求,模型: {model}")
        samba_payload = {
            "body": {
                "model": model,
                "messages": openai_payload["messages"],
                # The upstream call always asks for a streamed completion.
                "stream": True,
                "stop": openai_payload.get("stop", ["<|eot_id|>"]),
                "temperature": temperature,
                "max_tokens": openai_payload.get("max_tokens", 2048),
                # Sampling is enabled only for a positive temperature.
                "do_sample": temperature > 0
            },
            "env_type": "text",
            "fingerprint": generate_fingerprint()
        }
        print(f"[转发] 使用令牌 {token[:10]}... 转发请求到 SambaNova")
        resp = await forward_post_request(settings.SAMBA_COMPLETION_URL, samba_payload, token)
        print(f"[响应] 成功获取响应,开始流式传输")
        return StreamingResponse(
            resp.aiter_bytes(),
            media_type="text/event-stream",
            headers={
                "X-Proxy-Version": "1.0",
                "X-Request-ID": str(uuid.uuid4())
            }
        )
    except HTTPException as e:
        print(f"[错误] HTTP异常: {e.detail}")
        raise
    except httpx.RequestError as e:
        print(f"[错误] 请求错误: {str(e)}")
        raise HTTPException(504, f"Gateway timeout: {str(e)}") from e
    except Exception as e:
        print(f"[错误] 未处理异常: {str(e)}")
        raise HTTPException(500, f"Internal server error: {str(e)}") from e
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered with `app`.
async def get_info():
    """Report service status: configuration flags and cached-token state."""
    has_api_key = bool(settings.LOCAL_API_KEY)
    has_credentials = bool(settings.SAMBA_EMAIL and settings.SAMBA_PASSWORD)

    if access_token and time.time() < token_expiry:
        token_state = "active"
    else:
        token_state = "not_available"

    # Remaining lifetime is reported only when a token has been cached.
    if access_token:
        seconds_left = max(0, int(token_expiry - time.time()))
    else:
        seconds_left = 0

    return {
        "status": "running",
        "api_key_configured": has_api_key,
        "samba_credentials_configured": has_credentials,
        "token_status": token_state,
        "token_expires_in": seconds_left
    }
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered with `app`.
async def debug_token():
    """Debug endpoint: describe the cached token without revealing all of it."""
    now = time.time()
    has_token = access_token is not None

    # Only a 10-character prefix of the token is exposed.
    prefix = f"{access_token[:10]}..." if access_token else None

    if access_token:
        seconds_left = max(0, int(token_expiry - now))
    else:
        seconds_left = 0

    return {
        "token_exists": has_token,
        "token_prefix": prefix,
        "token_valid": has_token and now < token_expiry,
        "expires_in_seconds": seconds_left,
        "current_time": now,
        "expiry_time": token_expiry,
    }
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered with `app` (and with an HTML response class).
async def root():
    """Root health-check route: build and return an HTML status page."""
    from datetime import datetime, timedelta, timezone

    now = time.time()
    token_valid = access_token is not None and now < token_expiry

    # Remaining lifetime is computed only when a token has been cached.
    remaining = 0
    if access_token:
        remaining = max(0, int(token_expiry - now))

    # Human-readable remaining lifetime (days / hours / minutes).
    if remaining > 0:
        days, leftover = divmod(remaining, 86400)
        hours = leftover // 3600
        minutes = (leftover % 3600) // 60
        expiry_readable = f"{days}天 {hours}小时 {minutes}分钟"
    else:
        expiry_readable = "已过期"

    # The footer timestamp is shown in UTC+8 (China Standard Time).
    utc_plus_8 = timezone(timedelta(hours=8))
    formatted_time = datetime.now(utc_plus_8).strftime('%Y-%m-%d %H:%M:%S')

    html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>SambaNova OpenAI 代理服务</title>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
line-height: 1.6;
color: #333;
max-width: 800px;
margin: 0 auto;
padding: 20px;
}}
h1 {{
color: #2c3e50;
border-bottom: 1px solid #eee;
padding-bottom: 10px;
}}
.status-card {{
background-color: #f8f9fa;
border-radius: 8px;
padding: 20px;
margin-bottom: 20px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}}
.status-item {{
margin-bottom: 10px;
display: flex;
justify-content: space-between;
}}
.status-label {{
font-weight: bold;
color: #555;
}}
.status-value {{
text-align: right;
}}
.status-healthy {{
color: #28a745;
font-weight: bold;
}}
.status-warning {{
color: #ffc107;
font-weight: bold;
}}
.status-error {{
color: #dc3545;
font-weight: bold;
}}
.code-block {{
background-color: #f1f1f1;
padding: 15px;
border-radius: 5px;
font-family: monospace;
overflow-x: auto;
}}
.footer {{
margin-top: 30px;
font-size: 0.9em;
color: #6c757d;
text-align: center;
}}
</style>
</head>
<body>
<h1>SambaNova OpenAI 代理服务</h1>
<div class="status-card">
<h2>服务状态</h2>
<div class="status-item">
<span class="status-label">状态:</span>
<span class="status-value status-healthy">运行中</span>
</div>
<div class="status-item">
<span class="status-label">版本:</span>
<span class="status-value">1.0.0</span>
</div>
<div class="status-item">
<span class="status-label">令牌状态:</span>
<span class="status-value {('status-healthy' if token_valid else 'status-error')}">
{('有效' if token_valid else '无效')}
</span>
</div>
<div class="status-item">
<span class="status-label">令牌过期时间:</span>
<span class="status-value">{expiry_readable}</span>
</div>
<div class="status-item">
<span class="status-label">SambaNova 凭据:</span>
<span class="status-value {('status-healthy' if settings.SAMBA_EMAIL and settings.SAMBA_PASSWORD else 'status-error')}">
{('已配置' if settings.SAMBA_EMAIL and settings.SAMBA_PASSWORD else '未配置')}
</span>
</div>
<div class="status-item">
<span class="status-label">本地API密钥:</span>
<span class="status-value {('status-healthy' if settings.LOCAL_API_KEY else 'status-warning')}">
{('已配置' if settings.LOCAL_API_KEY else '未配置')}
</span>
</div>
</div>
<div class="footer">
<p>当前时间: {formatted_time} (中国标准时间)</p>
</div>
</body>
</html>
"""
    return html_content
class SambaAuthAsync:
    """One-shot async client for the SambaNova Cloud (Auth0-based) login flow.

    ``login()`` runs the whole flow and closes the underlying HTTP client when
    it finishes, so an instance is intended for a single login attempt.
    """

    def __init__(self, email, password):
        """Store the credentials and prepare the HTTP client and base headers."""
        self.email = email
        self.password = password
        self.client = httpx.AsyncClient()
        self.ua = UserAgent()
        self.base_headers = {
            "accept": "*/*",
            "accept-language": "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7",
            "origin": "https://cloud.sambanova.ai",
            "referer": "https://cloud.sambanova.ai/",
            "user-agent": self.ua.random
        }
        self.config = None
        self.nonce = None  # set by _get_auth_code, read by _exchange_token

    @staticmethod
    def _mask(secret) -> str:
        """Return a log-safe preview of a secret: a short prefix, never the whole value."""
        if not secret:
            return "<empty>"
        text = str(secret)
        # Short values are hidden entirely so the prefix cannot reveal them.
        preview = text[:10] if len(text) > 20 else ""
        return f"{preview}..."

    async def _get_config(self):
        """Fetch the dynamic client configuration and store it on ``self.config``."""
        config_url = "https://cloud.sambanova.ai/api/config"
        response = await self.client.get(config_url, headers=self.base_headers)
        response.raise_for_status()
        self.config = response.json()
        print(f"[配置获取成功] ClientID: {self.config['clientId']}")

    async def _get_login_ticket(self):
        """Authenticate with email/password and return the ``login_ticket``."""
        auth_url = f"https://{self.config['issuerBaseUrl']}/co/authenticate"
        payload = {
            "client_id": self.config["clientId"],
            "username": self.email,
            "password": self.password,
            "realm": "Username-Password-Authentication",
            "credential_type": "http://auth0.com/oauth/grant-type/password-realm"
        }
        headers = {**self.base_headers, "content-type": "application/json"}
        response = await self.client.post(auth_url, headers=headers, json=payload)
        response.raise_for_status()
        return response.json()["login_ticket"]

    async def _get_auth_code(self, login_ticket: str):
        """Exchange a login ticket for an authorization code.

        Returns:
            A ``(code, state)`` tuple; ``code`` is ``None`` when the redirect
            location carries no ``code`` query parameter.

        Raises:
            Exception: When the authorize endpoint does not answer with a 302.
        """
        state = secrets.token_urlsafe(32)
        self.nonce = secrets.token_urlsafe(32)  # kept for the later token exchange
        params = {
            "client_id": self.config["clientId"],
            "response_type": "code",
            "redirect_uri": self.config["redirectURL"],
            "scope": "openid profile email",
            "nonce": self.nonce,
            "state": state,
            "login_ticket": login_ticket,
            "realm": "Username-Password-Authentication",
            "auth0Client": "eyJuYW1lIjoibG9jay5qcyIsInZlcnNpb24iOiIxMi4zLjAiLCJlbnYiOnsiYXV0aDAuanMiOiI5LjIyLjEiLCJhdXRoMC5qcy11bHAiOiI5LjIyLjEifX0="
        }
        auth_url = f"https://{self.config['issuerBaseUrl']}/authorize"
        # Redirects are not followed: the code is read from the Location header.
        response = await self.client.get(
            auth_url,
            params=params,
            follow_redirects=False
        )
        if response.status_code == 302:
            location = response.headers["location"]
            parsed = urllib.parse.urlparse(location)
            query = urllib.parse.parse_qs(parsed.query)
            return query.get("code", [None])[0], state
        raise Exception(f"未收到302重定向,实际状态码:{response.status_code}")

    async def _exchange_token(self, code: str, state: str):
        """Visit the callback URL and return the ``access_token`` cookie value.

        Raises:
            Exception: When no ``access_token`` cookie for a sambanova.ai
                domain is present after the callback.
        """
        # The callback expects the nonce as a cookie.
        self.client.cookies.set("nonce", self.nonce, domain="cloud.sambanova.ai")
        # `code` was percent-decoded by parse_qs, so it is re-encoded here
        # rather than interpolated raw into the URL.
        callback_query = urllib.parse.urlencode({"code": code, "state": state})
        callback_url = f"{self.config['redirectURL']}?{callback_query}"
        response = await self.client.get(
            callback_url,
            headers={
                **self.base_headers,
                "sec-fetch-site": "same-site",
                "sec-fetch-mode": "navigate",
                "sec-fetch-user": "?1",
                "sec-fetch-dest": "document"
            },
            follow_redirects=True
        )
        # The token is taken from the client's cookie jar, not the response body.
        for cookie in self.client.cookies.jar:
            if cookie.name == "access_token" and "sambanova.ai" in cookie.domain:
                return cookie.value
        raise Exception("未找到access_token")

    async def login(self):
        """Run the full login flow.

        Returns:
            The access token, or ``None`` if any step failed (the error is
            printed). The HTTP client is closed in either case.
        """
        try:
            await self._get_config()
            login_ticket = await self._get_login_ticket()
            # Tickets, codes and tokens are credentials: log a prefix only.
            print(f"[登录票据获取成功] 票据: {self._mask(login_ticket)}")
            auth_code, state = await self._get_auth_code(login_ticket)
            if not auth_code:
                raise Exception("授权码获取失败")
            print(f"[授权码获取成功] 授权码: {self._mask(auth_code)}")
            print(f"[授权状态] state: {self._mask(state)}")
            token = await self._exchange_token(auth_code, state)
            print(f"[令牌获取成功] 令牌: {self._mask(token)}")
            return token
        except Exception as e:
            print(f"[登录失败] 详细错误: {str(e)}")
            return None
        finally:
            await self.client.aclose()
# NOTE(review): no startup-hook decorator is visible on this function in this
# view — confirm it is registered with `app`.
async def startup_event():
    """Pre-fetch a SambaNova token at application start.

    Prints which settings are present, then attempts a login and, on success,
    stores the token and its expiry in the module-level cache. Failures are
    logged and never raised, so startup always completes.
    """
    global access_token, token_expiry
    print("\n" + "="*50)
    print("[启动] SambaNova OpenAI 代理服务启动")
    print("="*50)
    # Report which settings are present (values are never printed).
    print(f"[环境] SAMBA_EMAIL: {'已设置' if settings.SAMBA_EMAIL else '未设置'}")
    print(f"[环境] SAMBA_PASSWORD: {'已设置' if settings.SAMBA_PASSWORD else '未设置'}")
    print(f"[环境] LOCAL_API_KEY: {'已设置' if settings.LOCAL_API_KEY else '未设置'}")
    if not settings.SAMBA_EMAIL or not settings.SAMBA_PASSWORD:
        # Same precondition as get_samba_token: skip a login that cannot succeed.
        print("[错误] 未配置SambaNova凭据,请设置SAMBA_EMAIL和SAMBA_PASSWORD环境变量")
        print("="*50 + "\n")
        return
    print("[登录] 开始尝试登录...")
    try:
        auth = SambaAuthAsync(settings.SAMBA_EMAIL, settings.SAMBA_PASSWORD)
        token = await auth.login()
        if token:
            access_token = token
            token_expiry = time.time() + settings.TOKEN_CACHE_TIME
            # Only a short prefix is logged; the full token is a credential.
            print(f"[登录] 登录成功! 令牌: {token[:10]}...")
            print(f"[登录] 令牌将在 {settings.TOKEN_CACHE_TIME} 秒后过期")
        else:
            print("[登录] 登录失败,未获取到令牌")
    except Exception as e:
        print(f"[登录] 登录过程发生异常: {str(e)}")
    print("="*50 + "\n")
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on all
    # interfaces at port 7860.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)