Spaces:
Running
Running
| import asyncio | |
| import aiohttp | |
| from typing import List, Dict, Any, Optional, Set | |
| from datetime import datetime | |
| import random | |
| from .openrouter_client import OpenRouterClient | |
| from . import config | |
| from .utils import clean_model_name | |
class ModelTester:
    """Probes OpenRouter models for availability and routes chat requests."""

    def __init__(self):
        """Initialize the API client, concurrency settings and empty caches."""
        self.client = OpenRouterClient()
        self.max_concurrency = config.get_max_concurrency()
        self.test_prompt = config.get_test_prompt()
        # Raw model ids from the last catalogue refresh (all vs. ":free" ones).
        self._all_models: List[str] = []
        self._free_models: List[str] = []
        # Cleaned names of models that answered the last availability scan.
        self._available_models: List[str] = []
        self._available_free_models: List[str] = []
        # Scan bookkeeping: re-entrancy guard and last completion time.
        self._scan_in_progress = False
        self._last_scan_time: Optional[datetime] = None
        # Summary of the most recent scan, in the shape returned to callers.
        self.scan_result: Dict[str, Any] = {
            "available_models": [],
            "available_free_models": [],
            "total_available": 0,
            "free_available": 0,
            "timestamp": None,
        }
def refresh_model_list(self):
    """Fetch the current model catalogue from the API.

    Refreshes the cached id lists (all models and ":free" models) and
    returns the pair (total_count, free_count).
    """
    catalogue = self.client.get_models()
    # Keep only entries that actually carry a non-empty id.
    ids = [entry.get("id", "") for entry in catalogue]
    ids = [model_id for model_id in ids if model_id]
    self._all_models = ids
    self._free_models = [model_id for model_id in ids if ":free" in model_id]
    return len(self._all_models), len(self._free_models)
async def test_single_model_async(
    self,
    session: aiohttp.ClientSession,
    model_id: str,
    api_key: str
) -> tuple[str, bool]:
    """Probe one model with a tiny completion request.

    Returns (model_id, True) when the API answers HTTP 200, otherwise
    (model_id, False); any network/timeout error counts as unavailable.
    Fix: dropped the unused local `is_free` the original computed.
    """
    url = "https://openrouter.ai/api/v1/chat/completions"
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": self.test_prompt}],
        # Keep the probe cheap — we only care whether the model answers.
        "max_tokens": 10
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    try:
        timeout = aiohttp.ClientTimeout(total=config.get_request_timeout())
        async with session.post(url, json=payload, headers=headers, timeout=timeout) as response:
            return model_id, response.status == 200
    except Exception:
        # Broad on purpose: a failing probe must not abort the whole scan.
        return model_id, False
async def scan_all_models_async(self):
    """Concurrently probe every known model and record which ones answer.

    Returns the scan summary dict; if a scan is already running, returns
    an error dict instead of starting a second one.
    Fix: `_scan_in_progress` is now reset in a `finally` block — the
    original left the flag set forever if refresh or the scan raised,
    permanently blocking future scans.
    """
    if self._scan_in_progress:
        return {"error": "Scan already in progress"}
    self._scan_in_progress = True
    try:
        print(f"[{datetime.now()}] Starting model scan...")
        all_count, free_count = self.refresh_model_list()
        print(f"Total models: {all_count}, Free models: {free_count}")
        api_keys = config.get_api_keys()
        api_key = random.choice(api_keys)
        available: Set[str] = set()
        available_free: Set[str] = set()
        async with aiohttp.ClientSession() as session:
            # Bound the number of in-flight probe requests.
            semaphore = asyncio.Semaphore(self.max_concurrency)

            async def test_with_semaphore(model_id: str):
                async with semaphore:
                    return await self.test_single_model_async(session, model_id, api_key)

            tasks = [test_with_semaphore(m) for m in self._all_models]
            results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            # gather(return_exceptions=True) mixes (id, bool) tuples with
            # exception objects; only the tuples are real probe results.
            if isinstance(result, tuple):
                model_id, success = result
                cleaned = clean_model_name(model_id)
                if success:
                    available.add(cleaned)
                    if ":free" in model_id:
                        available_free.add(cleaned)
        self._available_models = sorted(available)
        self._available_free_models = sorted(available_free)
        self._last_scan_time = datetime.now()
        self.scan_result = {
            "available_models": self._available_models,
            "available_free_models": self._available_free_models,
            "total_available": len(self._available_models),
            "free_available": len(self._available_free_models),
            "timestamp": self._last_scan_time.isoformat() if self._last_scan_time else None
        }
        print(f"Scan complete: {len(self._available_free_models)} free, {len(self._available_models)} total available")
        return self.scan_result
    finally:
        # Always release the re-entrancy guard, even on failure.
        self._scan_in_progress = False
def scan_all_models(self):
    """Blocking entry point: drive the async scan on a fresh event loop."""
    coroutine = self.scan_all_models_async()
    return asyncio.run(coroutine)
def get_available_models(self, free_only: bool = False) -> List[str]:
    """Return the cached scan results; only free models when free_only is set."""
    return self._available_free_models if free_only else self._available_models
def get_all_free_models(self) -> List[str]:
    """Return every ":free" model id from the last catalogue refresh.

    These come straight from the API listing and have NOT been
    availability-tested.
    """
    return self._free_models
async def try_model_direct_stream(
    self,
    session: aiohttp.ClientSession,
    model_id: str,
    api_key: str,
    messages: List[Dict[str, str]]
):
    """Send a streaming chat request to OpenRouter and yield the raw
    response lines as they arrive (an async iterator over bytes)."""
    endpoint = "https://openrouter.ai/api/v1/chat/completions"
    request_body = {
        "model": model_id,
        "messages": messages,
        "stream": True,
    }
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    async with session.post(
        endpoint,
        json=request_body,
        headers=auth_headers,
        timeout=aiohttp.ClientTimeout(total=120),
    ) as response:
        async for raw_line in response.content:
            yield raw_line
async def try_model_direct(
    self,
    session: aiohttp.ClientSession,
    model_id: str,
    api_key: str,
    prompt: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """Send one non-streaming chat request directly to OpenRouter.

    Uses the configured test prompt when *prompt* is None. Returns a dict
    with "success", "model", "method" and either "response" (parsed JSON)
    or "error" (message); never raises.
    Fix: annotated *prompt* as Optional[str] (PEP 484) instead of `str = None`.
    """
    url = "https://openrouter.ai/api/v1/chat/completions"
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt or self.test_prompt}]
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    try:
        timeout = aiohttp.ClientTimeout(total=config.get_request_timeout())
        async with session.post(url, json=payload, headers=headers, timeout=timeout) as response:
            # Read the body once up front; aiohttp caches it, so the
            # .json() call below reuses the same bytes.
            body = await response.text()
            if response.status == 200:
                data = await response.json()
                return {
                    "success": True,
                    "model": model_id,
                    "response": data,
                    "method": "direct"
                }
            print(f"[try_model_direct] ERROR {model_id}: HTTP {response.status}, body: {body[:200]}")
            return {
                "success": False,
                "model": model_id,
                "error": f"HTTP {response.status}: {body[:100]}",
                "method": "direct"
            }
    except asyncio.TimeoutError:
        return {"success": False, "model": model_id, "error": "timeout", "method": "direct"}
    except Exception as e:
        return {"success": False, "model": model_id, "error": str(e), "method": "direct"}
async def try_best_available_model(
    self,
    session: aiohttp.ClientSession,
    keyword: str,
    api_key: str,
    prompt: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """Find a working free model (preferring *keyword* matches) and query it.

    Strategy: refresh the free-model list, build candidates (keyword
    matches first, otherwise a slice of all free models, otherwise a
    direct API fetch), probe them concurrently, and return the first
    successful result; otherwise a failure dict.
    Fixes: *prompt* annotated Optional[str] (PEP 484); the unused
    match-type parameter of the inner worker renamed to `_match_type`.
    """
    # Step 1: refresh the free-model catalogue from the API (best effort).
    print(f"[try_best] Keyword: {keyword}, Refreshing model list...")
    try:
        self.refresh_model_list()
    except Exception as e:
        print(f"[try_best] refresh_model_list failed: {e}")
    # Use the full free catalogue, not just the availability-tested subset.
    available_free = self.get_all_free_models()
    print(f"[try_best] Found {len(available_free)} free models")
    # Step 2: keyword-match on the bare model name only (not the author
    # prefix) to avoid matching incomplete ids.
    candidates = []
    if keyword and available_free:
        matched = []
        for m in available_free:
            model_name = m.replace(":free", "").split("/")[-1]
            if keyword.lower() in model_name.lower():
                matched.append(m)
        print(f"[try_best] Keyword '{keyword}' matched: {matched[:5]}")
        if matched:
            candidates.extend([(m, "matched") for m in matched[:10]])
    # No keyword match (or no keyword): take a slice of the free models.
    if not candidates and available_free:
        candidates = [(m, "random") for m in available_free[:15]]
        print(f"[try_best] Using random models: {candidates[:3]}")
    # Cached list empty: fetch straight from the API and filter.
    if not candidates:
        print("[try_best] No candidates, fetching from API directly...")
        try:
            all_models = self.client.get_models()
            all_free = [m.get("id", "") for m in all_models if ":free" in m.get("id", "")]
            print(f"[try_best] API returned {len(all_free)} free models")
            if all_free:
                candidates = [(m, "api") for m in all_free[:20]]
        except Exception as e:
            print(f"[try_best] API fetch failed: {e}")
    # Step 3: probe the candidates concurrently (bounded fan-out).
    if not candidates:
        return {
            "success": False,
            "model": None,
            "error": "No candidates available",
            "method": "list_empty"
        }
    print(f"[try_best] Testing {len(candidates)} candidates...")
    semaphore = asyncio.Semaphore(5)

    async def try_one(model_id, _match_type):
        async with semaphore:
            # Candidate ids are already complete (":free" suffix included).
            print(f"[try_best] Testing model: {model_id}")
            result = await self.try_model_direct(session, model_id, api_key, prompt)
            if result:
                print(f"[try_best] Result for {model_id}: success={result.get('success')}, error={result.get('error', 'none')}")
            else:
                print(f"[try_best] Result for {model_id}: None")
            return result

    tasks = [try_one(m, t) for m, t in candidates]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for i, result in enumerate(results):
        if isinstance(result, dict) and result.get("success"):
            # Tag how this candidate was selected (matched/random/api).
            result["method"] = f"list_{candidates[i][1]}"
            print(f"[try_best] SUCCESS with {candidates[i][0]}")
            return result
    print(f"[try_best] All candidates failed")
    return {
        "success": False,
        "model": candidates[0][0] if candidates else None,
        "error": "No available model found",
        "method": "list_fallback"
    }
def find_model_in_list(self, keyword: str) -> Optional[str]:
    """Resolve *keyword* to a full free-model id.

    An exact (case-insensitive) match on the bare model name wins over a
    substring match anywhere in the id; returns None when nothing matches.
    """
    free_models = self.get_all_free_models()
    needle = keyword.lower()
    # Pass 1: exact match on the model name with the author prefix stripped.
    for candidate in free_models:
        bare_name = candidate.replace(":free", "").split("/")[-1]
        if bare_name.lower() == needle:
            return candidate
    # Pass 2: fuzzy substring match over the whole id.
    for candidate in free_models:
        if needle in candidate.lower():
            return candidate
    return None
async def chat_completion(self, prompt: str, model_hint: Optional[str] = None) -> Dict[str, Any]:
    """Answer *prompt*, preferring the user-hinted model when one is given.

    Launches the hinted-model request (when a hint is given) and the
    best-available fallback concurrently; the first strategy, in launch
    order, that reports success wins. Returns a dict with "success" plus
    either "response"/"method"/"model" or "error"/"method".
    Fix: the original indexed `results[1]` unconditionally, which raised
    IndexError whenever no model_hint was given and the single fallback
    task failed; results are now iterated, and the last (fallback, most
    detailed) error is reported.
    """
    api_keys = config.get_api_keys()
    api_key = random.choice(api_keys)
    async with aiohttp.ClientSession() as session:
        tasks = []
        # Strategy 1: the user named a model — resolve it to a full id.
        if model_hint:
            full_model_id = self.find_model_in_list(model_hint)
            if not full_model_id:
                # Not in the cached list; assume the hint is (close to) a
                # complete id and ensure it carries the ":free" suffix.
                full_model_id = model_hint if ":free" in model_hint else f"{model_hint}:free"
            tasks.append(asyncio.create_task(
                self.try_model_direct(session, full_model_id, api_key, prompt)
            ))
        # Strategy 2: always race the generic best-available fallback.
        tasks.append(asyncio.create_task(
            self.try_best_available_model(session, model_hint or "", api_key, prompt)
        ))
        results = await asyncio.gather(*tasks, return_exceptions=True)
        # Earlier strategies take priority; return the first success.
        for result in results:
            if isinstance(result, dict) and result.get("success"):
                return {
                    "success": True,
                    "response": result.get("response"),
                    "method": result.get("method"),
                    "model": result.get("model")
                }
        # Everything failed — report the fallback strategy's error.
        last = results[-1]
        return {
            "success": False,
            "error": last.get("error", "Unknown error") if isinstance(last, dict) else "Request failed",
            "method": last.get("method", "both_failed") if isinstance(last, dict) else "both_failed"
        }
def chat_completion_sync(self, prompt: str, model_hint: Optional[str] = None) -> Dict[str, Any]:
    """Synchronous convenience wrapper around chat_completion()."""
    coroutine = self.chat_completion(prompt, model_hint)
    return asyncio.run(coroutine)
def test_single_model(self, model_id: str) -> tuple[str, bool]:
    """Synchronously probe one model; returns (cleaned_name, is_available)."""
    available = self.client.test_model(model_id, self.test_prompt)
    return clean_model_name(model_id), available
| def test_all_models(self) -> Dict[str, Any]: | |
| """Legacy sync method - use scan_all_models instead""" | |
| return self.scan_all_models() | |
async def chat_completion_stream(self, model_hint: Optional[str], messages: List[Dict[str, str]]):
    """Stream chat-completion chunks, falling back across free models.

    Strategy 1 streams from the model resolved from *model_hint* (when it
    resolves); strategy 2 walks the free-model candidates and streams from
    the first one that produces output.
    Fix: the original returned unconditionally after the first candidate,
    and an exception from its stream killed the generator — remaining
    candidates were dead code. Now a candidate that fails before yielding
    anything is skipped and the next one is tried.
    """
    api_keys = config.get_api_keys()
    api_key = random.choice(api_keys)
    # Strategy 1: the user's hinted model, when it resolves to a full id.
    if model_hint:
        full_model_id = self.find_model_in_list(model_hint)
        if full_model_id:
            async with aiohttp.ClientSession() as session:
                async for chunk in self.try_model_direct_stream(session, full_model_id, api_key, messages):
                    yield chunk
                return
    # Strategy 2: pick candidates from the refreshed free-model list.
    self.refresh_model_list()
    available_free = self.get_all_free_models()
    candidates = []
    if model_hint and available_free:
        for m in available_free:
            model_name = m.replace(":free", "").split("/")[-1]
            if model_hint.lower() in model_name.lower():
                candidates.append(m)
    if not candidates and available_free:
        candidates = available_free[:10]
    async with aiohttp.ClientSession() as session:
        for model_id in candidates:
            yielded_any = False
            try:
                async for chunk in self.try_model_direct_stream(session, model_id, api_key, messages):
                    yielded_any = True
                    yield chunk
                # This candidate streamed to completion — we are done.
                return
            except Exception:
                if yielded_any:
                    # Part of the stream already reached the client; we
                    # cannot restart cleanly, so propagate the failure.
                    raise
                # Failed before producing output — try the next candidate.
                continue