|
|
|
|
|
""" |
|
|
HuggingFace Qwen 2.5 Model Client |
|
|
Handles inference for router, main, and complex models with cost tracking |
|
|
""" |
|
|
|
|
|
import asyncio
import logging
import os
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

from huggingface_hub import InferenceClient
from langchain_core.language_models.llms import LLM
from langchain_huggingface import HuggingFaceEndpoint
|
|
|
|
|
|
|
|
# NOTE(review): configuring the root logger at import time is convenient for
# running this file as a script, but it is a side effect for any importer;
# consider moving it under the __main__ guard if this module is used as a library.
logging.basicConfig(level=logging.INFO)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
class ModelTier(Enum):
    """Model complexity tiers for cost optimization.

    Tiers map to progressively larger (and more expensive) Qwen models;
    the client picks a tier per request based on question complexity and
    remaining budget.
    """

    ROUTER = "router"    # cheapest tier: routing / simple questions
    MAIN = "main"        # default tier for medium-complexity work
    COMPLEX = "complex"  # largest model for hard, multi-step questions
|
|
|
|
|
@dataclass
class ModelConfig:
    """Static configuration for one Qwen model tier."""
    # HuggingFace repo id, e.g. "Qwen/Qwen2.5-7B-Instruct".
    name: str
    # Tier this configuration belongs to (router/main/complex).
    tier: ModelTier
    # Default completion-token cap for this model.
    max_tokens: int
    # Sampling temperature used for all requests to this model.
    temperature: float
    # Rough dollar cost per token, used only for budget accounting.
    cost_per_token: float
    # Request timeout in seconds (passed to the LangChain endpoint).
    timeout: int
    # Whether the model requires a token with inference permissions.
    requires_special_auth: bool = False
|
|
|
|
|
@dataclass
class InferenceResult:
    """Result of model inference with metadata."""
    # Generated text; empty string on failure.
    response: str
    # HuggingFace repo id of the model that handled (or would handle) the request.
    model_used: str
    # Estimated token count (prompt + response, whitespace word count); 0 on failure.
    tokens_used: int
    # Estimated dollar cost of this request; 0.0 on failure.
    cost_estimate: float
    # Wall-clock seconds spent on the request.
    response_time: float
    # True when a non-empty response was produced.
    success: bool
    # Error message when success is False, otherwise None.
    error: Optional[str] = None
|
|
|
|
|
class QwenClient:
    """HuggingFace Inference API client for tiered Qwen 2.5 models.

    Maintains three model tiers (router/main/complex), selects a tier per
    request from question complexity and remaining budget, tracks an
    estimated running cost, and exposes both a direct generate API and
    LangChain LLM instances for agent integration.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """Initialize the client with a HuggingFace token for Qwen models only.

        Args:
            hf_token: HuggingFace API token; falls back to the
                HUGGINGFACE_TOKEN or HF_TOKEN environment variables.

        Raises:
            ValueError: if no token is available.
            RuntimeError: if no Qwen model initializes or the connectivity
                test fails (raised by _initialize_clients).
        """
        self.hf_token = hf_token or os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
        if not self.hf_token:
            raise ValueError("HuggingFace token is required for Qwen model access. Please provide HF_TOKEN or login with inference permissions.")

        # Running cost/usage tracking; budget_limit is a soft dollar cap
        # consulted by select_model_tier to throttle expensive models.
        self.total_cost = 0.0
        self.request_count = 0
        self.budget_limit = 0.10

        # Per-tier model configuration. cost_per_token values are rough
        # estimates used only for budget accounting, not real billing.
        self.models = {
            ModelTier.ROUTER: ModelConfig(
                name="Qwen/Qwen2.5-7B-Instruct",
                tier=ModelTier.ROUTER,
                max_tokens=512,
                temperature=0.1,
                cost_per_token=0.0003,
                timeout=15,
                requires_special_auth=True
            ),
            ModelTier.MAIN: ModelConfig(
                name="Qwen/Qwen2.5-32B-Instruct",
                tier=ModelTier.MAIN,
                max_tokens=1024,
                temperature=0.1,
                cost_per_token=0.0008,
                timeout=25,
                requires_special_auth=True
            ),
            ModelTier.COMPLEX: ModelConfig(
                name="Qwen/Qwen2.5-72B-Instruct",
                tier=ModelTier.COMPLEX,
                max_tokens=2048,
                temperature=0.1,
                cost_per_token=0.0015,
                timeout=35,
                requires_special_auth=True
            )
        }

        # tier -> client maps; entries are None (or absent) for tiers that
        # failed initialization, so callers use .get() and fall back.
        self.inference_clients: Dict[ModelTier, Optional[InferenceClient]] = {}
        self.langchain_clients: Dict[ModelTier, Optional[HuggingFaceEndpoint]] = {}
        self._initialize_clients()

    def _initialize_clients(self):
        """Initialize HuggingFace clients for Qwen models and smoke-test them.

        Raises:
            RuntimeError: if no model initializes, or the end-to-end
                connectivity test fails.
        """
        logger.info("🎯 Initializing Qwen models via HuggingFace Inference API...")
        success = self._try_initialize_models(self.models, "Qwen")

        if not success:
            raise RuntimeError("Failed to initialize any Qwen models. Please check your HF_TOKEN has inference permissions and try again.")

        # End-to-end connectivity check through the public generate() path.
        logger.info("🧪 Testing Qwen model connectivity...")
        try:
            test_result = self.generate("Hello", max_tokens=10)
            if test_result.success and test_result.response.strip():
                logger.info(f"✅ Qwen models ready: '{test_result.response.strip()}'")
            else:
                logger.error(f"❌ Qwen model test failed: {test_result}")
                raise RuntimeError("Qwen models failed connectivity test")
        except Exception as e:
            logger.error(f"❌ Qwen model test exception: {e}")
            # Chain the original cause for easier debugging.
            raise RuntimeError(f"Qwen model initialization failed: {e}") from e

    def _try_initialize_models(self, model_configs: Dict, model_type: str) -> bool:
        """Try to initialize each configured model tier.

        Tiers that fail the auth probe or construction are recorded as None
        so callers can fall back to other tiers.

        Args:
            model_configs: tier -> ModelConfig mapping to initialize.
            model_type: label used only in log messages (e.g. "Qwen").

        Returns:
            True if at least one tier initialized successfully.
        """
        success_count = 0

        for tier, config in model_configs.items():
            try:
                client = InferenceClient(
                    model=config.name,
                    token=self.hf_token
                )

                # Cheap auth probe: a 5-token chat completion. Failures here
                # usually mean the token lacks inference permissions.
                try:
                    client.chat_completion(
                        messages=[{"role": "user", "content": "Hello"}],
                        model=config.name,
                        max_tokens=5,
                        temperature=0.1
                    )
                    logger.info(f"✅ {model_type} auth test passed for {config.name}")
                except Exception as auth_error:
                    logger.warning(f"❌ {model_type} auth failed for {config.name}: {auth_error}")
                    continue

                # Reuse the probed client rather than constructing a second,
                # identically-configured InferenceClient (as the original did).
                self.inference_clients[tier] = client

                self.langchain_clients[tier] = HuggingFaceEndpoint(
                    repo_id=config.name,
                    max_new_tokens=config.max_tokens,
                    temperature=config.temperature,
                    huggingfacehub_api_token=self.hf_token,
                    timeout=config.timeout
                )

                logger.info(f"✅ Initialized {model_type} {tier.value} model: {config.name}")
                success_count += 1

            except Exception as e:
                logger.warning(f"❌ Failed to initialize {model_type} {tier.value} model: {e}")
                self.inference_clients[tier] = None
                self.langchain_clients[tier] = None

        return success_count > 0

    def get_model_status(self) -> Dict[str, bool]:
        """Return per-tier availability (both clients initialized) keyed by tier name."""
        return {
            tier.value: (
                self.inference_clients.get(tier) is not None
                and self.langchain_clients.get(tier) is not None
            )
            for tier in ModelTier
        }

    def select_model_tier(self, complexity: str = "medium", budget_conscious: bool = True, question_text: str = "") -> ModelTier:
        """Smart model selection based on task complexity, budget, and question analysis.

        Args:
            complexity: caller hint ("simple" | "medium" | "complex"); may be
                overridden by keyword analysis of question_text.
            budget_conscious: when True, high budget usage forces cheaper tiers.
            question_text: optional raw question used for keyword heuristics.

        Returns:
            The chosen ModelTier (falls back to any available tier if the
            preferred one failed to initialize).

        Raises:
            RuntimeError: if no model tier is available at all.
        """
        budget_used_percent = (self.total_cost / self.budget_limit) * 100

        # Budget guardrails: above 80% always use the cheapest model; above
        # 60% downgrade "complex" requests to "simple".
        if budget_conscious and budget_used_percent > 80:
            logger.warning(f"Budget critical ({budget_used_percent:.1f}% used), forcing router model")
            return ModelTier.ROUTER
        elif budget_conscious and budget_used_percent > 60:
            logger.warning(f"Budget warning ({budget_used_percent:.1f}% used), limiting complex model usage")
            complexity = "simple" if complexity == "complex" else complexity

        # Keyword heuristics: override the caller's complexity hint based on
        # indicator phrases found in the question text.
        if question_text:
            question_lower = question_text.lower()

            complex_indicators = [
                "analyze", "explain why", "reasoning", "logic", "complex", "difficult",
                "multi-step", "calculate and explain", "compare and contrast",
                "what is the relationship", "how does", "why is", "prove that",
                "step by step", "detailed analysis", "comprehensive"
            ]

            simple_indicators = [
                "what is", "who is", "when", "where", "simple", "quick",
                "yes or no", "true or false", "list", "name", "find"
            ]

            math_indicators = [
                "calculate", "compute", "solve", "equation", "formula", "math",
                "number", "total", "sum", "average", "percentage", "code", "program"
            ]

            file_indicators = [
                "image", "picture", "photo", "audio", "sound", "video", "file",
                "document", "excel", "csv", "data", "chart", "graph"
            ]

            complex_score = sum(1 for indicator in complex_indicators if indicator in question_lower)
            simple_score = sum(1 for indicator in simple_indicators if indicator in question_lower)
            math_score = sum(1 for indicator in math_indicators if indicator in question_lower)
            file_score = sum(1 for indicator in file_indicators if indicator in question_lower)

            # Long questions (>200 chars) are treated as complex regardless
            # of keywords; file/math hints land in the middle tier.
            if complex_score >= 2 or len(question_text) > 200:
                complexity = "complex"
            elif file_score >= 1 or math_score >= 2:
                complexity = "medium"
            elif simple_score >= 2 and complex_score == 0:
                complexity = "simple"

        # Map complexity + remaining budget headroom to a tier.
        if complexity == "complex" and budget_used_percent < 70:
            selected_tier = ModelTier.COMPLEX
        elif complexity == "simple" or budget_used_percent > 75:
            selected_tier = ModelTier.ROUTER
        else:
            selected_tier = ModelTier.MAIN

        # Fall back to any initialized tier if the preferred one is down.
        if not self.inference_clients.get(selected_tier):
            logger.warning(f"Selected model {selected_tier.value} unavailable, falling back")
            for fallback in [ModelTier.MAIN, ModelTier.ROUTER, ModelTier.COMPLEX]:
                if self.inference_clients.get(fallback):
                    selected_tier = fallback
                    break
            else:
                raise RuntimeError("No models available")

        logger.info(f"Selected {selected_tier.value} model (complexity: {complexity}, budget: {budget_used_percent:.1f}%)")
        return selected_tier

    async def generate_async(self,
                             prompt: str,
                             tier: Optional[ModelTier] = None,
                             max_tokens: Optional[int] = None) -> InferenceResult:
        """Async text generation with Qwen models via HuggingFace Inference API.

        The blocking SDK call is offloaded with asyncio.to_thread so the
        event loop stays responsive (the original invoked it inline, which
        blocked any concurrently running coroutines).

        Args:
            prompt: user prompt, sent as a single chat message.
            tier: explicit model tier; auto-selected from the prompt when None.
            max_tokens: completion token cap; defaults to the tier's config.

        Returns:
            InferenceResult; on failure success is False and error is set
            (this method does not raise).
        """
        if tier is None:
            tier = self.select_model_tier(question_text=prompt)

        config = self.models[tier]
        client = self.inference_clients.get(tier)

        if not client:
            return InferenceResult(
                response="",
                model_used=config.name,
                tokens_used=0,
                cost_estimate=0.0,
                response_time=0.0,
                success=False,
                error=f"Qwen model {tier.value} not available"
            )

        start_time = time.time()

        try:
            tokens = max_tokens or config.max_tokens
            messages = [{"role": "user", "content": prompt}]

            logger.info(f"🤖 Generating with {config.name}...")
            # Run the blocking HTTP request in a worker thread.
            response = await asyncio.to_thread(
                client.chat_completion,
                messages=messages,
                model=config.name,
                max_tokens=tokens,
                temperature=config.temperature
            )

            if response and response.choices:
                response_text = response.choices[0].message.content
            else:
                raise ValueError(f"No response received from {config.name}")

            response_time = time.time() - start_time

            response_text = str(response_text).strip()

            if not response_text:
                raise ValueError(f"Empty response from {config.name}")

            # Crude token estimate (whitespace word count) for budget tracking.
            estimated_tokens = len(prompt.split()) + len(response_text.split())
            cost_estimate = estimated_tokens * config.cost_per_token

            self.total_cost += cost_estimate
            self.request_count += 1

            logger.info(f"✅ Generated with {tier.value} model in {response_time:.2f}s")

            return InferenceResult(
                response=response_text,
                model_used=config.name,
                tokens_used=estimated_tokens,
                cost_estimate=cost_estimate,
                response_time=response_time,
                success=True
            )

        except Exception as e:
            response_time = time.time() - start_time
            error_msg = str(e)

            logger.error(f"❌ Generation failed with {tier.value} model ({config.name}): {error_msg}")

            return InferenceResult(
                response="",
                model_used=config.name,
                tokens_used=0,
                cost_estimate=0.0,
                response_time=response_time,
                success=False,
                error=error_msg
            )

    def generate(self,
                 prompt: str,
                 tier: Optional[ModelTier] = None,
                 max_tokens: Optional[int] = None) -> InferenceResult:
        """Synchronous text generation (wrapper for generate_async).

        Uses asyncio.run instead of the deprecated get_event_loop /
        run_until_complete pattern, which could leak a manually created
        loop and raised a confusing error when called from async code.

        Raises:
            RuntimeError: if called from inside a running event loop
                (await generate_async directly in that case; the original
                also raised RuntimeError here, from run_until_complete).
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running: own one for the duration of this call.
            return asyncio.run(self.generate_async(prompt, tier, max_tokens))
        raise RuntimeError(
            "generate() cannot be used inside a running event loop; "
            "await generate_async() instead"
        )

    def get_langchain_llm(self, tier: ModelTier) -> Optional[LLM]:
        """Get the LangChain LLM instance for agent integration (None if unavailable)."""
        return self.langchain_clients.get(tier)

    def get_usage_stats(self) -> Dict[str, Any]:
        """Get current usage and cost statistics, including per-tier availability."""
        return {
            "total_cost": self.total_cost,
            "request_count": self.request_count,
            "budget_limit": self.budget_limit,
            "budget_remaining": self.budget_limit - self.total_cost,
            "budget_used_percent": (self.total_cost / self.budget_limit) * 100,
            # max(..., 1) avoids division by zero before the first request.
            "average_cost_per_request": self.total_cost / max(self.request_count, 1),
            "models_available": self.get_model_status()
        }

    def reset_usage_tracking(self):
        """Reset usage statistics (for testing/development)."""
        self.total_cost = 0.0
        self.request_count = 0
        logger.info("Usage tracking reset")
|
|
|
|
|
|
|
|
def test_model_connection(client: QwenClient, tier: ModelTier):
    """Smoke-test one model tier and log timing/cost details.

    Args:
        client: initialized QwenClient to exercise.
        tier: which model tier to test.

    Returns:
        True when the tier produced a successful response.
    """
    test_prompt = "Hello! Please respond with 'Connection successful' if you can read this."

    logger.info(f"Testing {tier.value} model...")
    result = client.generate(test_prompt, tier=tier, max_tokens=50)

    # Guard clause: report the failure and bail out early.
    if not result.success:
        logger.error(f"❌ {tier.value} model test failed: {result.error}")
        return result.success

    logger.info(f"✅ {tier.value} model test successful: {result.response[:50]}...")
    logger.info(f" Response time: {result.response_time:.2f}s")
    logger.info(f" Cost estimate: ${result.cost_estimate:.6f}")
    return result.success
|
|
|
|
|
def test_all_models():
    """Run the connectivity smoke test on every tier and log a summary.

    Returns:
        Mapping of ModelTier -> bool (pass/fail per tier).
    """
    logger.info("🧪 Testing all Qwen models...")

    qwen_client = QwenClient()

    # One smoke test per tier, collected into a single mapping.
    results = {model_tier: test_model_connection(qwen_client, model_tier)
               for model_tier in ModelTier}

    logger.info("📊 Test Results Summary:")
    for model_tier, passed in results.items():
        status = "✅ PASS" if passed else "❌ FAIL"
        logger.info(f" {model_tier.value:8}: {status}")

    # Dump usage/cost counters (skip the nested availability map).
    logger.info("💰 Usage Statistics:")
    for stat_name, stat_value in qwen_client.get_usage_stats().items():
        if stat_name != "models_available":
            logger.info(f" {stat_name}: {stat_value}")

    return results
|
|
|
|
|
if __name__ == "__main__":

    # Load HF_TOKEN / HUGGINGFACE_TOKEN from a local .env file (if present)
    # before exercising the live models.
    from dotenv import load_dotenv

    load_dotenv()

    # Run the connectivity smoke test across all model tiers.
    test_all_models()