"""
Model inference and client management for AnyCoder.
Handles different model providers and inference clients.
"""
import os
from typing import Dict, List, Optional, Tuple
import re
from http import HTTPStatus

from huggingface_hub import InferenceClient
from openai import OpenAI
from mistralai import Mistral
import dashscope
from google import genai
from google.genai import types

from .config import HF_TOKEN, AVAILABLE_MODELS

# Type definitions
# Each history entry is indexed as (user_content, assistant_content).
History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]

# (API-key environment variable, base URL) pairs for OpenAI-compatible endpoints.
_POE = ("POE_API_KEY", "https://api.poe.com/v1")
_DASHSCOPE = (
    "DASHSCOPE_API_KEY",
    "https://dashscope.aliyuncs.com/compatible-mode/v1",
)
_OPENROUTER = ("OPENROUTER_API_KEY", "https://openrouter.ai/api/v1")
_GEMINI = (
    "GEMINI_API_KEY",
    "https://generativelanguage.googleapis.com/v1beta/openai/",
)
_STEPFUN = ("STEP_API_KEY", "https://api.stepfun.com/v1")
_MOONSHOT = ("MOONSHOT_API_KEY", "https://api.moonshot.ai/v1")

# Models served through an OpenAI-compatible client, keyed by model id.
_OPENAI_COMPATIBLE_ENDPOINTS: Dict[str, Tuple[str, str]] = {
    # Poe
    "gemini-3.0-pro": _POE,
    "gpt-5": _POE,
    "gpt-5.1": _POE,
    "gpt-5.1-instant": _POE,
    "gpt-5.1-codex": _POE,
    "gpt-5.1-codex-mini": _POE,
    "grok-4": _POE,
    "Grok-Code-Fast-1": _POE,
    "claude-opus-4.1": _POE,
    "claude-sonnet-4.5": _POE,
    "claude-haiku-4.5": _POE,
    # DashScope
    "qwen3-30b-a3b-instruct-2507": _DASHSCOPE,
    "qwen3-30b-a3b-thinking-2507": _DASHSCOPE,
    "qwen3-coder-30b-a3b-instruct": _DASHSCOPE,
    "qwen3-max-preview": _DASHSCOPE,
    # OpenRouter
    "openrouter/sonoma-dusk-alpha": _OPENROUTER,
    "openrouter/sonoma-sky-alpha": _OPENROUTER,
    "x-ai/grok-4.1-fast": _OPENROUTER,
    "openrouter/sherlock-think-alpha": _OPENROUTER,
    # StepFun
    "step-3": _STEPFUN,
    # Google Gemini
    "gemini-2.5-flash": _GEMINI,
    "gemini-2.5-pro": _GEMINI,
    "gemini-flash-latest": _GEMINI,
    "gemini-flash-lite-latest": _GEMINI,
    # Moonshot AI
    "kimi-k2-turbo-preview": _MOONSHOT,
}

# Models served through the native Mistral client.
_MISTRAL_MODELS = frozenset({"codestral-2508", "mistral-medium-2508"})

# Models served through HuggingFace InferenceClient with a fixed provider.
_HF_PROVIDER_OVERRIDES: Dict[str, str] = {
    "MiniMaxAI/MiniMax-M2": "novita",
    "moonshotai/Kimi-K2-Thinking": "together",
    "moonshotai/Kimi-K2-Instruct": "groq",
    "deepseek-ai/DeepSeek-V3.1": "novita",
    "deepseek-ai/DeepSeek-V3.1-Terminus": "novita",
    "deepseek-ai/DeepSeek-V3.2-Exp": "novita",
    "zai-org/GLM-4.5": "fireworks-ai",
    "zai-org/GLM-4.6": "cerebras",
}

# Model ids that must carry a provider suffix in the model string for API calls.
_PROVIDER_SUFFIXED_MODEL_IDS: Dict[str, str] = {
    "zai-org/GLM-4.6": "zai-org/GLM-4.6:cerebras",
    "moonshotai/Kimi-K2-Thinking": "moonshotai/Kimi-K2-Thinking:together",
}

# Lower-case language names that may appear as the first line of a code block.
# Compared against a lower-cased line, so every entry must be lower-case.
_LANGUAGE_MARKERS = frozenset({
    'python', 'html', 'css', 'javascript', 'json', 'c', 'cpp', 'markdown',
    'latex', 'jinja2', 'typescript', 'yaml', 'dockerfile', 'shell', 'r',
    'sql', 'sql-mssql', 'sql-mysql', 'sql-mariadb', 'sql-sqlite',
    'sql-cassandra', 'sql-plsql', 'sql-hive', 'sql-pgsql', 'sql-gql',
    'sql-gpsql', 'sql-sparksql', 'sql-esper',
})

# Fenced code block patterns, tried in order from most to least specific.
_CODE_BLOCK_PATTERNS = (
    re.compile(r'```(?:html|HTML)\n([\s\S]+?)\n```'),  # ```html or ```HTML
    re.compile(r'```\n([\s\S]+?)\n```'),  # no language marker
    re.compile(r'```([\s\S]+?)```'),  # no line breaks
)

# NOTE(review): the HTML tag literals below were reconstructed; the text this
# module was recovered from had its angle-bracket markup stripped. Confirm
# the exact literals against the upstream project.
_HTML_ROOT_TAGS = ('<!DOCTYPE html', '<html')
_HTML_START_PREFIXES = ('<!DOCTYPE html>', '<html', '<div')

_TOOL_CALL_MARKER_RE = re.compile(r'\[/?TOOL_CALL\]', re.IGNORECASE)
_STANDALONE_BRACES_RE = re.compile(r'^\s*\}\}\s*$', re.MULTILINE)
_THINK_TAG_RE = re.compile(r'</?think>', re.IGNORECASE)
_THINKING_LINE_RE = re.compile(
    r"(?mi)^[\t ]*Thinking\.\.\.(?:\s*\(\d+s elapsed\))?[\t ]*$\n?"
)
_THINKING_ONLY_RE = re.compile(
    r"(?s)(?:\s*Thinking\.\.\.(?:\s*\(\d+s elapsed\))?\s*)+"
)
_THINKING_STATUS_RE = re.compile(r"Thinking\.\.\.(?:\s*\(\d+s elapsed\))?")


def _require_env(name: str) -> str:
    """Return the value of environment variable *name* or raise ValueError."""
    value = os.getenv(name)
    if not value:
        raise ValueError(
            f"{name} environment variable is required for Carrot model"
        )
    return value


def get_inference_client(model_id, provider="auto"):
    """Return an inference client for *model_id*.

    Args:
        model_id: Identifier of the selected model.
        provider: HuggingFace provider to use when the model has no fixed
            provider of its own.

    Returns:
        An ``OpenAI`` client for models behind an OpenAI-compatible endpoint,
        a ``Mistral`` client for the Mistral models, and otherwise a
        HuggingFace ``InferenceClient``.

    Raises:
        ValueError: For ``stealth-model-1`` when ``STEALTH_MODEL_1_API_KEY``
            or ``STEALTH_MODEL_1_BASE_URL`` is unset or empty.
    """
    endpoint = _OPENAI_COMPATIBLE_ENDPOINTS.get(model_id)
    if endpoint is not None:
        key_env, base_url = endpoint
        return OpenAI(api_key=os.getenv(key_env), base_url=base_url)

    if model_id in _MISTRAL_MODELS:
        return Mistral(api_key=os.getenv("MISTRAL_API_KEY"))

    if model_id == "stealth-model-1":
        # Stealth model: endpoint and key come entirely from the environment.
        # The API key is validated before the base URL.
        api_key = _require_env("STEALTH_MODEL_1_API_KEY")
        base_url = _require_env("STEALTH_MODEL_1_BASE_URL")
        return OpenAI(api_key=api_key, base_url=base_url)

    # Everything else goes through HuggingFace; some models pin a provider.
    provider = _HF_PROVIDER_OVERRIDES.get(model_id, provider)
    return InferenceClient(
        provider=provider,
        api_key=HF_TOKEN,
        bill_to="huggingface",
    )


# Helper function to get real model ID for stealth models and special cases
def get_real_model_id(model_id: str) -> str:
    """Return the model id to send to the API for *model_id*.

    ``stealth-model-1`` is resolved from the ``STEALTH_MODEL_1_ID``
    environment variable; a few models get a provider suffix appended; every
    other id is returned unchanged.

    Raises:
        ValueError: For ``stealth-model-1`` when ``STEALTH_MODEL_1_ID`` is
            unset or empty.
    """
    if model_id == "stealth-model-1":
        return _require_env("STEALTH_MODEL_1_ID")
    return _PROVIDER_SUFFIXED_MODEL_IDS.get(model_id, model_id)


def _flatten_user_content(content):
    """Collapse multimodal (list) user content to its text parts.

    Non-list content is returned unchanged. For a list, the ``text`` values of
    all ``{"type": "text"}`` dict items are concatenated; if that yields
    nothing, the list's ``str()`` form is used instead.
    """
    if not isinstance(content, list):
        return content
    text_content = "".join(
        item.get("text", "")
        for item in content
        if isinstance(item, dict) and item.get("type") == "text"
    )
    return text_content if text_content else str(content)


def history_to_messages(history: History, system: str) -> Messages:
    """Convert history entries to role/content messages with a system prompt.

    The result starts with the system message, followed by a user and an
    assistant message for each history entry.
    """
    messages = [{'role': 'system', 'content': system}]
    for h in history:
        messages.append(
            {'role': 'user', 'content': _flatten_user_content(h[0])}
        )
        messages.append({'role': 'assistant', 'content': h[1]})
    return messages


def history_to_chatbot_messages(history: History) -> List[Dict[str, str]]:
    """Convert history tuples to chatbot message format"""
    messages = []
    for user_msg, assistant_msg in history:
        messages.append(
            {"role": "user", "content": _flatten_user_content(user_msg)}
        )
        messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def strip_tool_call_markers(text):
    """Remove TOOL_CALL markers that some LLMs (like Qwen) add to their output.

    Falsy input is returned unchanged; otherwise the cleaned text is returned
    with surrounding whitespace stripped.
    """
    if not text:
        return text
    # Remove [TOOL_CALL] and [/TOOL_CALL] markers
    text = _TOOL_CALL_MARKER_RE.sub('', text)
    # Remove standalone }} that appears with tool calls.
    # Only removed when it is alone on its line.
    text = _STANDALONE_BRACES_RE.sub('', text)
    return text.strip()


def _split_language_marker(text: str) -> Optional[str]:
    """Return *text* without a leading language-marker line.

    Returns None when the first line is not a known language name.
    """
    first_line, newline, rest = text.partition('\n')
    if first_line.strip().lower() in _LANGUAGE_MARKERS:
        return rest if newline else ''
    return None


def _trim_to_html_root(text: str) -> str:
    """Drop any preface before the earliest HTML root tag in *text*.

    Text that has no root tag, or already starts with one, is returned as is.
    """
    positions = [
        idx for idx in (text.find(tag) for tag in _HTML_ROOT_TAGS) if idx != -1
    ]
    if positions and min(positions) > 0:
        return text[min(positions):].strip()
    return text


def remove_code_block(text):
    """Extract the code payload from an LLM response.

    Tool-call markers are stripped first. If the text contains a fenced code
    block, the content of the first match (trying the most specific fence
    pattern first) is returned, minus a leading language-marker line or any
    preface before an HTML root tag. Without a fenced block, the function
    falls back to: raw HTML, an unterminated ```` ```python ```` fence, a bare
    leading language-marker line, and finally the stripped text itself.

    Returns an empty string for empty or ``None`` input.
    """
    # First strip any tool call markers
    text = strip_tool_call_markers(text)
    if not text:
        return ''

    # Try to match code blocks with language markers
    for pattern in _CODE_BLOCK_PATTERNS:
        match = pattern.search(text)
        if match is None:
            continue
        extracted = match.group(1).strip()
        # Remove a leading language marker line (e.g., 'python') if present
        without_marker = _split_language_marker(extracted)
        if without_marker is not None:
            return without_marker
        # If HTML markup starts later in the block (e.g., Poe injected
        # preface), trim to first HTML root
        return _trim_to_html_root(extracted)

    # If no code block is found, check if the entire text is HTML
    stripped = text.strip()
    if stripped.startswith(_HTML_START_PREFIXES):
        return _trim_to_html_root(stripped)

    # Special handling for python: remove python marker. This is only reached
    # when no closing fence was matched above, so the trailing fence is
    # removed only if it is really there instead of blindly cutting three
    # characters of code.
    if stripped.startswith('```python'):
        body = stripped[len('```python'):]
        if body.endswith('```'):
            body = body[:-3]
        return body.strip()

    # Remove a leading language marker line if present (fallback)
    without_marker = _split_language_marker(stripped)
    if without_marker is not None:
        return without_marker
    return stripped


## React CDN compatibility fixer removed per user preference

def strip_thinking_tags(text: str) -> str:
    """Strip <think> tags and [TOOL_CALL] markers from streaming output.

    Only the tags and markers are removed; the text between them and all
    surrounding whitespace are kept.
    """
    if not text:
        return text
    # Remove <think> opening and </think> closing tags
    text = _THINK_TAG_RE.sub('', text)
    # Remove [TOOL_CALL] markers
    text = _TOOL_CALL_MARKER_RE.sub('', text)
    return text


def strip_placeholder_thinking(text: str) -> str:
    """Remove placeholder 'Thinking...' status lines from streamed text."""
    if not text:
        return text
    # Matches lines like: "Thinking..." or "Thinking... (12s elapsed)"
    return _THINKING_LINE_RE.sub("", text)


def is_placeholder_thinking_only(text: str) -> bool:
    """Return True if text contains only 'Thinking...' placeholder lines (with optional elapsed)."""
    if not text:
        return False
    stripped = text.strip()
    if not stripped:
        return False
    return _THINKING_ONLY_RE.fullmatch(stripped) is not None


def extract_last_thinking_line(text: str) -> str:
    """Extract the last 'Thinking...' line to display as status.

    Returns the plain ``"Thinking..."`` placeholder when *text* is empty or
    contains no status line.
    """
    if not text:
        return "Thinking..."
    matches = _THINKING_STATUS_RE.findall(text)
    return matches[-1] if matches else "Thinking..."