from fastapi import FastAPI, HTTPException, Header
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import chromadb
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from huggingface_hub import login
import requests
import json
from typing import List, Dict, Any
import os
import sys
import torch
import tarfile
app = FastAPI(title="ML Use Cases RAG System")
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variable to store current logs
current_logs = []
def log_to_ui(message):
    """Add a log message that will be sent to UI"""
    current_logs.append(message)
    print(message)  # Still print to console
# Initialize embedding model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
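# Note: all-MiniLM-L6-v2 returns 384-dimensional sentence embeddings and is downloaded from the Hub on first start.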
# BYOK: No server-side API key initialization
# All model access will be done via user-provided API keys
print("🔑 BYOK Mode: No server-side API key configured")
print("✅ Users will provide their own HuggingFace API keys")
generator = None
llm_available = False
# Auto-extract ChromaDB if archive exists and directory is missing/empty
def setup_chromadb():
    """Set up ChromaDB by extracting the bundled archive if needed"""
    if os.path.exists("chroma_db_complete.tar.gz"):
        # Check if chroma_db directory exists and has content
        needs_extraction = False
        if not os.path.exists("chroma_db"):
            print("📦 ChromaDB directory not found, extracting archive...")
            needs_extraction = True
        else:
            # Check if directory is empty or missing key files
            try:
                if not os.path.exists("chroma_db/chroma.sqlite3"):
                    print("📦 ChromaDB missing database file, extracting archive...")
                    needs_extraction = True
                else:
                    # Quick check: try to list collections
                    temp_client = chromadb.PersistentClient(path="./chroma_db")
                    collections = temp_client.list_collections()
                    if len(collections) == 0:
                        print("📦 ChromaDB has no collections, extracting archive...")
                        needs_extraction = True
                    else:
                        print(f"✅ ChromaDB already set up with {len(collections)} collections")
            except Exception as e:
                print(f"📦 ChromaDB check failed ({e}), extracting archive...")
                needs_extraction = True
        if needs_extraction:
            try:
                print("🔧 Extracting ChromaDB archive...")
                with tarfile.open("chroma_db_complete.tar.gz", "r:gz") as tar:
                    tar.extractall()
                print("✅ ChromaDB extracted successfully")
                # Verify extraction
                if os.path.exists("chroma_db/chroma.sqlite3"):
                    print("✅ Database file found after extraction")
                else:
                    print("❌ Database file missing after extraction")
            except Exception as e:
                print(f"❌ Failed to extract ChromaDB: {e}")
    else:
        print("📋 No ChromaDB archive found, using existing directory")
# Set up ChromaDB before initializing the client
setup_chromadb()
# Initialize ChromaDB
chroma_client = chromadb.PersistentClient(path="./chroma_db")
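# PersistentClient keeps the vector index on disk under ./chroma_db, so the extracted collections survive restarts.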
collection = None
class ChatRequest(BaseModel):
    query: str
class ApiKeyRequest(BaseModel):
    api_key: str
class SearchResult(BaseModel):
    company: str
    industry: str
    year: int
    description: str
    summary: str
    similarity_score: float
    url: str
class RecommendedModels(BaseModel):
    fine_tuned: List[Dict[str, Any]]
    general: List[Dict[str, Any]]
class ChatResponse(BaseModel):
    solution_approach: str
    company_examples: List[SearchResult]
    recommended_models: RecommendedModels
    logs: List[str] = []  # declared so response_model filtering does not strip the "logs" key returned by /chat
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}
@app.get("/test-token/{token}")
async def test_token_direct(token: str):
    """Direct token test endpoint"""
    print(f"🧪 Testing token: {token[:10]}...")
    try:
        # Test with models API
        response = requests.get(
            "https://huggingface.co/api/models?limit=1",
            headers={"Authorization": f"Bearer {token}"},
            timeout=10
        )
        print(f"📊 Models API Status: {response.status_code}")
        if response.status_code == 200:
            return {"valid": True, "method": "models_api", "status": response.status_code}
        # Test whoami
        response2 = requests.get(
            "https://huggingface.co/api/whoami",
            headers={"Authorization": f"Bearer {token}"},
            timeout=10
        )
        print(f"📊 WhoAmI Status: {response2.status_code}")
        return {
            "valid": response2.status_code == 200,
            "models_status": response.status_code,
            "whoami_status": response2.status_code,
            "whoami_response": response2.text[:200] if response2.status_code != 200 else "OK"
        }
    except Exception as e:
        return {"error": str(e)}
@app.post("/validate-key")
async def validate_api_key(request: ApiKeyRequest):
    """Validate user's HuggingFace API key"""
    api_key = request.api_key.strip()
    print(f"🔑 Validating API key: {api_key[:10]}...")
    if not api_key or not api_key.startswith('hf_'):
        print(f"❌ Invalid format: {api_key[:10] if api_key else 'empty'}")
        return {"valid": False, "error": "Invalid API key format. Must start with 'hf_'"}
    # Simple format validation - if it looks like a valid HF token, accept it
    if len(api_key) >= 30 and api_key.startswith('hf_') and all(c.isalnum() or c == '_' for c in api_key):
        print("✅ API key format is valid, accepting")
        return {"valid": True, "user": "User"}
    print("❌ Invalid token format or length")
    return {"valid": False, "error": "Invalid API key format"}
@app.get("/logs")
async def get_logs():
    """Get current log messages for UI"""
    try:
        logs_copy = current_logs.copy()
        current_logs.clear()
        return {"logs": logs_copy}
    except Exception as e:
        return {"logs": [], "error": str(e)}
@app.get("/test-logs")
async def test_logs():
    """Test endpoint to verify logging works"""
    log_to_ui("🧪 Test log message 1")
    log_to_ui("🧪 Test log message 2")
    log_to_ui("🧪 Test log message 3")
    return {"message": "Test logs added"}
def initialize_collection():
    """Initialize the ChromaDB collection with debug logging"""
    global collection
    # Debug: Check file system
    print(f"🔍 Current working directory: {os.getcwd()}")
    print(f"🔍 ChromaDB path exists: {os.path.exists('./chroma_db')}")
    if os.path.exists('./chroma_db'):
        try:
            chroma_files = os.listdir('./chroma_db')
            print(f"🔍 ChromaDB directory contents: {chroma_files}")
            # Check for main database file
            if 'chroma.sqlite3' in chroma_files:
                print("✅ Found chroma.sqlite3")
            else:
                print("❌ chroma.sqlite3 NOT found")
            # Check for UUID directories
            uuid_dirs = [f for f in chroma_files if len(f) == 36 and '-' in f]  # UUID format
            if uuid_dirs:
                print(f"✅ Found UUID directories: {uuid_dirs}")
                for uuid_dir in uuid_dirs:
                    uuid_path = os.path.join('./chroma_db', uuid_dir)
                    if os.path.isdir(uuid_path):
                        uuid_files = os.listdir(uuid_path)
                        print(f"🔍 {uuid_dir} contents: {uuid_files}")
            else:
                print("❌ No UUID directories found")
        except Exception as e:
            print(f"❌ Error reading chroma_db directory: {e}")
    else:
        print("❌ chroma_db directory does not exist")
    # Debug: Try to initialize ChromaDB client
    try:
        print("🔍 Attempting to initialize ChromaDB client...")
        print(f"🔍 ChromaDB version: {chromadb.__version__}")
        # List all collections
        collections = chroma_client.list_collections()
        print(f"🔍 Available collections: {[c.name for c in collections]}")
        # Try to get the specific collection
        collection = chroma_client.get_collection("ml_use_cases")
        collection_count = collection.count()
        print(f"✅ Found existing collection 'ml_use_cases' with {collection_count} documents")
    except Exception as e:
        print(f"❌ Collection initialization error: {type(e).__name__}: {e}")
        print("📝 Will attempt to create collection during first use")
        collection = None
# Initialize collection on import
initialize_collection()
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main frontend"""
    with open("static/index.html", "r") as f:
        return HTMLResponse(f.read())
async def search_use_cases_internal(request: ChatRequest):
    """Internal search function with detailed logging"""
    log_to_ui(f"🔍 Search request received: '{request.query}'")
    if collection is None:
        log_to_ui("❌ ChromaDB collection not initialized")
        raise HTTPException(status_code=500, detail="Database not initialized")
    query = request.query.lower()
    log_to_ui(f"📝 Normalized query: '{query}'")
    # Generate query embedding for semantic search
    log_to_ui("🧠 Generating query embedding...")
    query_embedding = embedding_model.encode([request.query]).tolist()[0]
    log_to_ui(f"✅ Embedding generated, dimension: {len(query_embedding)}")
    # Semantic search
    log_to_ui("🔎 Performing semantic search...")
    semantic_results = collection.query(
        query_embeddings=[query_embedding],
        n_results=15,
        include=['metadatas', 'documents', 'distances']
    )
    log_to_ui(f"📊 Semantic search found {len(semantic_results['ids'][0])} results")
    # Second pass: let ChromaDB embed the raw query text itself (query_texts) as a keyword-style search
    keyword_results = None
    try:
        log_to_ui("🔤 Performing keyword search...")
        keyword_results = collection.query(
            query_texts=[request.query],
            n_results=10,
            include=['metadatas', 'documents', 'distances']
        )
        log_to_ui(f"📝 Keyword search found {len(keyword_results['ids'][0])} results")
    except Exception as e:
        log_to_ui(f"⚠️ Keyword search failed: {e}")
    # Combine and rank results
    combined_results = {}
    # Process semantic results
    for i in range(len(semantic_results['ids'][0])):
        doc_id = semantic_results['ids'][0][i]
        metadata = semantic_results['metadatas'][0][i]
        similarity_score = 1 - semantic_results['distances'][0][i]
        # Boost score for keyword matches in metadata
        boost = 0
        query_words = query.split()
        for word in query_words:
            if word in metadata.get('title', '').lower():
                boost += 0.3
            if word in metadata.get('description', '').lower():
                boost += 0.2
            if word in metadata.get('keywords', '').lower():
                boost += 0.4
            if word in metadata.get('industry', '').lower():
                boost += 0.1
        final_score = min(similarity_score + boost, 1.0)
        combined_results[doc_id] = {
            'metadata': metadata,
            'summary': semantic_results['documents'][0][i],
            'score': final_score,
            'source': 'semantic'
        }
    # Process keyword results if available
    if keyword_results:
        for i in range(len(keyword_results['ids'][0])):
            doc_id = keyword_results['ids'][0][i]
            if doc_id not in combined_results:
                metadata = keyword_results['metadatas'][0][i]
                similarity_score = 1 - keyword_results['distances'][0][i]
                combined_results[doc_id] = {
                    'metadata': metadata,
                    'summary': keyword_results['documents'][0][i],
                    'score': similarity_score + 0.1,  # Small boost for keyword matches
                    'source': 'keyword'
                }
    # Sort by score and take top results
    sorted_results = sorted(combined_results.values(), key=lambda x: x['score'], reverse=True)[:10]
    log_to_ui(f"🎯 Combined and ranked results: {len(sorted_results)} final results")
    search_results = []
    for i, result in enumerate(sorted_results):
        metadata = result['metadata']
        search_results.append(SearchResult(
            company=metadata.get('company', ''),
            industry=metadata.get('industry', ''),
            year=metadata.get('year', 2023),
            description=metadata.get('description', ''),
            summary=result['summary'],
            similarity_score=result['score'],
            url=metadata.get('url', '')
        ))
        log_to_ui(f" {i+1}. {metadata.get('company', 'Unknown')} - Score: {result['score']:.3f}")
    log_to_ui(f"✅ Search completed, returning {len(search_results)} results")
    return search_results
@app.post("/search")
async def search_use_cases(request: ChatRequest):
    """Public search endpoint"""
    results = await search_use_cases_internal(request)
    return {"results": results}
async def generate_response_with_user_key(prompt: str, api_key: str, max_length: int = 500) -> str:
    """Generate response using user's HuggingFace API key via Inference API"""
    try:
        # Use HuggingFace Inference API with user's key
        api_url = "https://api-inference.huggingface.co/models/google/gemma-2-2b-it"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_length,
                "temperature": 0.7,
                "do_sample": True,
                "return_full_text": False
            }
        }
        response = requests.post(api_url, headers=headers, json=payload, timeout=30)
        if response.status_code == 200:
            result = response.json()
            if isinstance(result, list) and len(result) > 0:
                generated_text = result[0].get('generated_text', '')
                return generated_text.strip()
            else:
                return "Unable to generate response. Please try again."
        elif response.status_code == 503:
            # Model is loading, try fallback
            return await try_fallback_model(prompt, api_key, max_length)
        else:
            raise Exception(f"API request failed with status {response.status_code}")
    except Exception as e:
        print(f"Error generating response with user API key: {e}")
        return generate_template_response(prompt)
async def try_fallback_model(prompt: str, api_key: str, max_length: int = 500) -> str:
    """Try fallback models when the primary model is unavailable"""
    try:
        # Try more readily available models as fallbacks
        fallback_models = [
            "microsoft/DialoGPT-medium",
            "microsoft/DialoGPT-small",
            "gpt2"
        ]
        for model_name in fallback_models:
            try:
                api_url = f"https://api-inference.huggingface.co/models/{model_name}"
                headers = {
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                }
                payload = {
                    "inputs": prompt,
                    "parameters": {
                        "max_new_tokens": max_length,
                        "temperature": 0.7,
                        "do_sample": True,
                        "return_full_text": False
                    }
                }
                response = requests.post(api_url, headers=headers, json=payload, timeout=20)
                if response.status_code == 200:
                    result = response.json()
                    if isinstance(result, list) and len(result) > 0:
                        generated_text = result[0].get('generated_text', '')
                        return generated_text.strip()
            except Exception:
                continue
        # If all models fail, return template
        return generate_template_response(prompt)
    except Exception:
        return generate_template_response(prompt)
def generate_template_response(prompt: str) -> str:
    """Generate a template response when AI models are not available"""
    return f"""Based on the analysis of similar ML/AI implementations from companies in our database, here are key recommendations for your problem:
**Technical Approach:**
- Consider machine learning classification or prediction models
- Leverage data preprocessing and feature engineering
- Implement proper model validation and testing
**Implementation Strategy:**
- Start with a minimum viable model using existing data
- Iterate based on performance metrics
- Consider scalability and real-time requirements
**Key Considerations:**
- Data quality and availability
- Business metrics alignment
- Technical infrastructure requirements
This analysis is based on patterns from 400+ real-world ML implementations across various industries."""
@app.post("/chat", response_model=ChatResponse)
async def chat_with_rag(request: ChatRequest, x_hf_api_key: str = Header(None, alias="X-HF-API-Key")):
    """Main RAG endpoint with user API key"""
    # Validate user API key
    if not x_hf_api_key or not x_hf_api_key.startswith('hf_'):
        raise HTTPException(status_code=400, detail="Valid HuggingFace API key required")
    # Clear previous logs and start fresh
    current_logs.clear()
    log_to_ui(f"🤖 Chat request received: '{request.query}'")
    # First search for relevant use cases
    log_to_ui("🔍 Getting relevant use cases...")
    relevant_cases = await search_use_cases_internal(request)
    top_cases = relevant_cases[:5]  # Top 5 results
    log_to_ui(f"📚 Using top {len(top_cases)} cases for context")
    # Prepare context for LLM
    log_to_ui("📝 Preparing context for LLM...")
    context = "Here are relevant real-world ML/AI implementations:\n\n"
    for i, case in enumerate(top_cases, 1):
        context += f"Company: {case.company} ({case.industry}, {case.year})\n"
        context += f"Description: {case.description}\n"
        context += f"Implementation: {case.summary[:500]}...\n\n"
        log_to_ui(f" {i}. {case.company} - {case.description}")
    log_to_ui(f"📊 Context length: {len(context)} characters")
    # Create prompt for language model
    prompt = f"""Based on the following real ML/AI implementations from companies, provide advice for this business problem:
{context}
User Problem: {request.query}
Please provide a comprehensive solution approach that considers what has worked for these companies. Focus on:
1. What type of ML/AI solution would address this problem
2. Key technical approaches that have proven successful
3. Implementation considerations based on the examples
Be specific and reference the examples when relevant.
Response:"""
log_to_ui(f"💭 Full prompt length: {len(prompt)} characters")
# Generate response using user's HuggingFace API key
log_to_ui("🚀 Generating AI response with user API key...")
try:
llm_response = await generate_response_with_user_key(prompt, x_hf_api_key, max_length=400)
log_to_ui(f"✅ AI response generated, length: {len(llm_response)} characters")
except Exception as e:
llm_response = f"Error generating AI response: {str(e)}"
log_to_ui(f"❌ AI response error: {e}")
# Get HuggingFace model recommendations using user's API key
log_to_ui("🤗 Getting HuggingFace model recommendations...")
recommended_models = await get_huggingface_models(request.query, top_cases, x_hf_api_key)
total_models = len(recommended_models.get("fine_tuned", [])) + len(recommended_models.get("general", []))
log_to_ui(f"🏷️ Found {total_models} recommended models")
log_to_ui("✅ Chat response complete!")
# Return response with logs included
return {
"solution_approach": llm_response,
"company_examples": [
{
"company": case.company,
"industry": case.industry,
"year": case.year,
"description": case.description,
"summary": case.summary,
"similarity_score": case.similarity_score,
"url": case.url
}
for case in top_cases
],
"recommended_models": {
"fine_tuned": recommended_models.get("fine_tuned", []),
"general": recommended_models.get("general", [])
},
"logs": current_logs.copy() # Include all logs in the response
}
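# Example (sketch; hf_XXXX... is a hypothetical user token, server assumed on localhost:7860):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -H "X-HF-API-Key: hf_XXXX..." \
#        -d '{"query": "predict which subscribers are likely to churn next month"}'
# The response carries solution_approach, company_examples, recommended_models and the progress logs.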
async def get_huggingface_models(query: str, relevant_cases: List = None, api_key: str = None) -> Dict[str, List[Dict[str, Any]]]:
    """Get relevant ML models from HuggingFace based on query and similar use cases"""
    log_to_ui(f"🔍 Analyzing query for ML task mapping: '{query}'")
    try:
        # Enhanced multi-dimensional classification system
        business_domains = {
            # Financial Services
            "finance": ["fraud detection", "risk assessment", "algorithmic trading", "credit scoring"],
            "banking": ["fraud detection", "credit scoring", "customer segmentation", "loan approval"],
            "fintech": ["payment processing", "robo-advisor", "fraud detection", "credit scoring"],
            "insurance": ["risk assessment", "claim processing", "fraud detection", "pricing optimization"],
            # E-commerce & Retail
            "ecommerce": ["recommendation systems", "demand forecasting", "price optimization", "customer segmentation"],
            "retail": ["inventory management", "demand forecasting", "customer analytics", "supply chain"],
            "marketplace": ["search ranking", "recommendation systems", "fraud detection", "seller analytics"],
            # Healthcare & Life Sciences
            "healthcare": ["medical imaging", "drug discovery", "patient risk prediction", "clinical decision support"],
            "medical": ["diagnostic imaging", "treatment optimization", "patient monitoring", "clinical trials"],
            "pharma": ["drug discovery", "clinical trials", "adverse event detection", "molecular analysis"],
            # Technology & Media
            "tech": ["user behavior analysis", "system optimization", "content moderation", "search ranking"],
            "media": ["content recommendation", "audience analytics", "content generation", "sentiment analysis"],
            "gaming": ["player behavior prediction", "game optimization", "content generation", "cheat detection"],
            # Marketing & Advertising
            "marketing": ["customer segmentation", "campaign optimization", "lead scoring", "attribution modeling"],
            "advertising": ["ad targeting", "bid optimization", "creative optimization", "audience analytics"],
            "social": ["sentiment analysis", "trend prediction", "content moderation", "influence analysis"]
        }
        problem_types = {
            # Customer Analytics
            "churn": {
                "domain": "customer_analytics",
                "task_type": "binary_classification",
                "data_types": ["tabular", "behavioral"],
                "complexity": "intermediate",
                "models": ["xgboost", "lightgbm", "catboost", "random_forest"],
                "hf_tasks": ["tabular-classification"],
                "keywords": ["retention", "attrition", "leave", "cancel", "subscription"]
            },
            "segmentation": {
                "domain": "customer_analytics",
                "task_type": "clustering",
                "data_types": ["tabular", "behavioral"],
                "complexity": "intermediate",
                "models": ["kmeans", "dbscan", "hierarchical", "gaussian_mixture"],
                "hf_tasks": ["tabular-classification"],
                "keywords": ["segment", "group", "persona", "cluster", "behavior"]
            },
            # Risk & Fraud
            "fraud": {
                "domain": "risk_management",
                "task_type": "anomaly_detection",
                "data_types": ["tabular", "graph", "time_series"],
                "complexity": "advanced",
                "models": ["isolation_forest", "autoencoder", "one_class_svm", "gnn"],
                "hf_tasks": ["tabular-classification"],
                "keywords": ["suspicious", "anomaly", "unusual", "scam", "fake"]
            },
            "risk": {
                "domain": "risk_management",
                "task_type": "regression",
                "data_types": ["tabular", "time_series"],
                "complexity": "advanced",
                "models": ["ensemble", "deep_learning", "survival_analysis"],
                "hf_tasks": ["tabular-regression"],
                "keywords": ["probability", "likelihood", "exposure", "default", "loss"]
            },
            # Demand & Forecasting
            "forecast": {
                "domain": "demand_planning",
                "task_type": "time_series_forecasting",
                "data_types": ["time_series", "tabular"],
                "complexity": "advanced",
                "models": ["prophet", "lstm", "transformer", "arima"],
                "hf_tasks": ["time-series-forecasting"],
                "keywords": ["predict", "future", "trend", "seasonal", "demand", "sales"]
            },
            "demand": {
                "domain": "demand_planning",
                "task_type": "regression",
                "data_types": ["time_series", "tabular"],
                "complexity": "intermediate",
                "models": ["xgboost", "lstm", "prophet"],
                "hf_tasks": ["tabular-regression", "time-series-forecasting"],
                "keywords": ["inventory", "supply", "planning", "optimization"]
            },
            # Content & NLP
            "sentiment": {
                "domain": "nlp",
                "task_type": "text_classification",
                "data_types": ["text"],
                "complexity": "beginner",
                "models": ["bert", "roberta", "distilbert"],
                "hf_tasks": ["text-classification"],
                "keywords": ["opinion", "emotion", "feeling", "review", "feedback"]
            },
            "recommendation": {
                "domain": "personalization",
                "task_type": "recommendation",
                "data_types": ["tabular", "behavioral", "content"],
                "complexity": "advanced",
                "models": ["collaborative_filtering", "content_based", "deep_learning"],
                "hf_tasks": ["tabular-regression"],
                "keywords": ["suggest", "personalize", "similar", "like", "prefer"]
            },
            # Pricing & Optimization
            "pricing": {
                "domain": "revenue_optimization",
                "task_type": "regression",
                "data_types": ["tabular", "time_series"],
                "complexity": "advanced",
                "models": ["ensemble", "reinforcement_learning", "optimization"],
                "hf_tasks": ["tabular-regression"],
                "keywords": ["price", "cost", "revenue", "profit", "optimize"]
            }
        }
        # Advanced query analysis
        def analyze_query_intent(query_text, cases=None):
            """Analyze query to extract business domain, problem type, and complexity"""
            query_lower = query_text.lower()
            # Extract business domain
            detected_domain = None
            domain_confidence = 0
            for domain, use_cases in business_domains.items():
                if domain in query_lower:
                    detected_domain = domain
                    domain_confidence = 0.9
                    break
                # Check use case matches
                for use_case in use_cases:
                    if use_case.lower() in query_lower:
                        detected_domain = domain
                        domain_confidence = 0.7
                        break
                if detected_domain:
                    break
            # Extract problem type with scoring
            problem_matches = []
            for problem_name, problem_info in problem_types.items():
                score = 0
                # Direct problem name match
                if problem_name in query_lower:
                    score += 50
                # Keyword matches
                for keyword in problem_info["keywords"]:
                    if keyword in query_lower:
                        score += 10
                # Context from relevant cases
                if cases:
                    case_text = " ".join([f"{case.description} {case.summary[:300]}" for case in cases]).lower()
                    if problem_name in case_text:
                        score += 20
                    for keyword in problem_info["keywords"]:
                        if keyword in case_text:
                            score += 5
                if score > 0:
                    problem_matches.append((problem_name, problem_info, score))
            # Sort by score and get best matches
            problem_matches.sort(key=lambda x: x[2], reverse=True)
            return detected_domain, problem_matches[:3], domain_confidence
        # Analyze the query
        query_lower = query.lower()
        detected_domain, top_problems, domain_confidence = analyze_query_intent(query, relevant_cases)
        # Determine primary task and approach
        if top_problems:
            primary_problem = top_problems[0]
            problem_info = primary_problem[1]
            primary_task = problem_info["hf_tasks"][0] if problem_info["hf_tasks"] else "tabular-classification"
            complexity = problem_info["complexity"]
            preferred_models = problem_info["models"]
            log_to_ui(f"🎯 Detected problem: '{primary_problem[0]}' (score: {primary_problem[2]})")
            log_to_ui(f"📊 Domain: {detected_domain or 'general'} | Complexity: {complexity}")
            log_to_ui(f"🔧 Preferred models: {', '.join(preferred_models[:3])}")
        else:
            # Fallback to simple keyword matching
            primary_task = "tabular-classification"
            complexity = "intermediate"
            preferred_models = ["xgboost", "lightgbm"]
            log_to_ui(f"📊 Using fallback classification | Task: {primary_task}")
        matched_keywords = [p[0] for p in top_problems]
        log_to_ui(f"📊 Primary task: '{primary_task}' | Keywords: {matched_keywords}")
        # Search for models with multiple strategies
        all_models = []
        # Strategy 1: Search by primary task
        models_primary = await search_models_by_task(primary_task, api_key)
        all_models.extend(models_primary)
        # Strategy 2: Search by specific keywords for better matches
        if matched_keywords:
            for keyword in matched_keywords[:2]:  # Top 2 keywords
                keyword_models = await search_models_by_keyword(keyword, api_key)
                all_models.extend(keyword_models)
        # Strategy 3: Search for domain-specific models
        domain_searches = []
        if "churn" in query_lower or "retention" in query_lower:
            domain_searches.append("customer-analytics")
        if "fraud" in query_lower:
            domain_searches.append("anomaly-detection")
        if "recommend" in query_lower:
            domain_searches.append("recommendation")
        for domain in domain_searches:
            domain_models = await search_models_by_keyword(domain, api_key)
            all_models.extend(domain_models)
        # Remove duplicates and rank by relevance
        seen_models = set()
        unique_models = []
        for model in all_models:
            model_id = model.get("id") or model.get("name")
            if model_id and model_id not in seen_models:
                seen_models.add(model_id)
                unique_models.append(model)
        # Score models based on enhanced relevance criteria
        scored_models = []
        for model in unique_models:
            score = calculate_model_relevance(
                model, query_lower, matched_keywords,
                complexity, preferred_models if 'preferred_models' in locals() else None
            )
            scored_models.append((model, score))
        # Separate models into fine-tuned/specific vs general base models
        fine_tuned_models = []
        general_models = []
        for model, score in scored_models:
            if is_fine_tuned_model(model, matched_keywords):
                fine_tuned_models.append((model, score))
            elif is_general_suitable_model(model, primary_task):
                general_models.append((model, score))
        # Sort and take top 3 of each type
        fine_tuned_models.sort(key=lambda x: x[1], reverse=True)
        general_models.sort(key=lambda x: x[1], reverse=True)
        top_fine_tuned = [model for model, score in fine_tuned_models[:3]]
        top_general = [model for model, score in general_models[:3]]
        # Add curated high-quality models for specific use cases
        def get_curated_models(problem_type: str, complexity_level: str) -> List[Dict]:
            """Get curated high-quality models for specific use cases"""
            curated = {
                "churn": {
                    "beginner": [
                        {"id": "scikit-learn/RandomForestClassifier", "task": "tabular-classification"},
                        {"id": "xgboost/XGBClassifier", "task": "tabular-classification"}
                    ],
                    "intermediate": [
                        {"id": "microsoft/TabNet", "task": "tabular-classification"},
                        {"id": "AutoML/AutoGluon-Tabular", "task": "tabular-classification"}
                    ],
                    "advanced": [
                        {"id": "microsoft/LightGBM", "task": "tabular-classification"},
                        {"id": "dmlc/xgboost", "task": "tabular-classification"}
                    ]
                },
                "sentiment": {
                    "beginner": [
                        {"id": "cardiffnlp/twitter-roberta-base-sentiment-latest", "task": "text-classification"},
                        {"id": "distilbert-base-uncased-finetuned-sst-2-english", "task": "text-classification"}
                    ],
                    "intermediate": [
                        {"id": "nlptown/bert-base-multilingual-uncased-sentiment", "task": "text-classification"},
                        {"id": "microsoft/DialoGPT-medium", "task": "text-classification"}
                    ],
                    "advanced": [
                        {"id": "roberta-large-mnli", "task": "text-classification"},
                        {"id": "microsoft/deberta-v3-large", "task": "text-classification"}
                    ]
                },
                "fraud": {
                    "intermediate": [
                        {"id": "microsoft/TabNet", "task": "tabular-classification"},
                        {"id": "IsolationForest/AnomalyDetection", "task": "tabular-classification"}
                    ],
                    "advanced": [
                        {"id": "pyod/AutoEncoder", "task": "tabular-classification"},
                        {"id": "GraphNeuralNetworks/FraudDetection", "task": "tabular-classification"}
                    ]
                },
                "forecast": {
                    "intermediate": [
                        {"id": "facebook/prophet", "task": "time-series-forecasting"},
                        {"id": "statsmodels/ARIMA", "task": "time-series-forecasting"}
                    ],
                    "advanced": [
                        {"id": "microsoft/DeepAR", "task": "time-series-forecasting"},
                        {"id": "google/temporal-fusion-transformer", "task": "time-series-forecasting"}
                    ]
                }
            }
            # Get curated models for the specific problem and complexity
            if problem_type in curated and complexity_level in curated[problem_type]:
                return curated[problem_type][complexity_level]
            elif problem_type in curated:
                # Fallback to any complexity level available
                for level in ["beginner", "intermediate", "advanced"]:
                    if level in curated[problem_type]:
                        return curated[problem_type][level]
            return []
        # Add curated models
        if top_problems:
            primary_problem_name = top_problems[0][0]
            curated_models = get_curated_models(primary_problem_name, complexity)
            for curated_model in curated_models:
                if len(top_general) < 3:
                    # Format as HuggingFace model dict
                    formatted_model = {
                        "id": curated_model["id"],
                        "pipeline_tag": curated_model["task"],
                        "downloads": 50000,  # Reasonable default
                        "tags": ["curated", "production-ready"]
                    }
                    top_general.append(formatted_model)
        # Add general foundation models if we still don't have enough
        if len(top_general) < 3:
            foundation_models = await get_foundation_models(primary_task, matched_keywords, api_key)
            top_general.extend(foundation_models[:3 - len(top_general)])
        # Format response with categories
        model_response = {
            "fine_tuned": [],
            "general": []
        }
        # Enhanced model descriptions based on detected problem type
        def get_enhanced_model_description(model: Dict, model_type: str, problem_context: str = None) -> str:
            """Generate context-aware model descriptions"""
            model_name = model.get("id", "").lower()
            if model_type == "fine-tuned":
                if problem_context == "churn":
                    return "Pre-trained model optimized for customer retention prediction"
                elif problem_context == "fraud":
                    return "Specialized anomaly detection model for fraud identification"
                elif problem_context == "sentiment":
                    return "Fine-tuned sentiment analysis model for text classification"
                elif problem_context == "forecast":
                    return "Time series forecasting model for demand prediction"
                else:
                    return f"Specialized model fine-tuned for {get_model_specialty(model, matched_keywords)}"
            else:  # general
                if "curated" in str(model.get("tags", [])):
                    return "Production-ready model recommended for business use cases"
                elif any(term in model_name for term in ["bert", "roberta", "distilbert"]):
                    return "Transformer-based foundation model for fine-tuning"
                elif any(term in model_name for term in ["xgboost", "lightgbm", "catboost"]):
                    return "Gradient boosting model excellent for tabular data"
                elif "prophet" in model_name:
                    return "Facebook's time series forecasting framework"
                else:
                    return f"Foundation model suitable for {primary_task.replace('-', ' ')}"
        # Format fine-tuned models with enhanced descriptions
        primary_problem_name = top_problems[0][0] if top_problems else None
        for model in top_fine_tuned:
            model_info = {
                "name": model.get("id", model.get("name", "Unknown")),
                "downloads": model.get("downloads", 0),
                "task": model.get("pipeline_tag", primary_task),
                "url": f"https://huggingface.co/{model.get('id', '')}",
                "type": "fine-tuned",
                "description": get_enhanced_model_description(model, "fine-tuned", primary_problem_name)
            }
            model_response["fine_tuned"].append(model_info)
        # Format general models with enhanced descriptions
        for model in top_general:
            model_info = {
                "name": model.get("id", model.get("name", "Unknown")),
                "downloads": model.get("downloads", 0),
                "task": model.get("pipeline_tag", primary_task),
                "url": f"https://huggingface.co/{model.get('id', '')}",
                "type": "general",
                "description": get_enhanced_model_description(model, "general", primary_problem_name)
            }
            model_response["general"].append(model_info)
        total_models = len(model_response["fine_tuned"]) + len(model_response["general"])
        log_to_ui(f"📦 Found {len(model_response['fine_tuned'])} fine-tuned + {len(model_response['general'])} general models")
        # Log details
        if model_response["fine_tuned"]:
            log_to_ui("🎯 Fine-tuned/Specialized models:")
            for i, model in enumerate(model_response["fine_tuned"], 1):
                log_to_ui(f" {i}. {model['name']} - {model['downloads']:,} downloads")
        if model_response["general"]:
            log_to_ui("🔧 General/Foundation models:")
            for i, model in enumerate(model_response["general"], 1):
                log_to_ui(f" {i}. {model['name']} - {model['downloads']:,} downloads")
        return model_response
    except Exception as e:
        log_to_ui(f"❌ Error fetching HuggingFace models: {e}")
        return {"fine_tuned": [], "general": []}
async def search_models_by_task(task: str, api_key: str = None) -> List[Dict]:
    """Search models by specific task"""
    try:
        headers = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        response = requests.get(
            f"https://huggingface.co/api/models?pipeline_tag={task}&sort=downloads&limit=10",
            headers=headers,
            timeout=10
        )
        if response.status_code == 200:
            return response.json()
    except Exception:
        pass
    return []
async def search_models_by_keyword(keyword: str, api_key: str = None) -> List[Dict]:
    """Search models by keyword in name/description"""
    try:
        headers = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        response = requests.get(
            f"https://huggingface.co/api/models?search={keyword}&sort=downloads&limit=10",
            headers=headers,
            timeout=10
        )
        if response.status_code == 200:
            return response.json()
    except Exception:
        pass
    return []
def calculate_model_relevance(model: Dict, query: str, keywords: List[str],
                              complexity: str = "intermediate", preferred_models: List[str] = None) -> float:
    """Enhanced multi-criteria model relevance scoring"""
    score = 0
    model_name = model.get("id", "").lower()
    model_task = model.get("pipeline_tag", "").lower()
    downloads = model.get("downloads", 0)
    # 1. Base popularity score (0-15 points)
    if downloads > 10000000:  # 10M+
        score += 15
    elif downloads > 1000000:  # 1M+
        score += 12
    elif downloads > 100000:  # 100K+
        score += 8
    elif downloads > 10000:  # 10K+
        score += 5
    elif downloads > 1000:  # 1K+
        score += 2
    # 2. Direct keyword relevance (0-30 points)
    for keyword in keywords:
        if keyword in model_name:
            score += 25
        # Check in model description/tags if available
        model_tags = model.get("tags", [])
        if any(keyword in str(tag).lower() for tag in model_tags):
            score += 15
    # 3. Query term matches (0-20 points)
    query_words = [w for w in query.lower().split() if len(w) > 3]
    for word in query_words:
        if word in model_name:
            score += 8
        if word in str(model.get("tags", [])).lower():
            score += 5
    # 4. Preferred model architecture bonus (0-25 points)
    if preferred_models:
        for preferred in preferred_models:
            if preferred.lower() in model_name:
                score += 20
                break
        # Partial matches
        for preferred in preferred_models:
            if any(part in model_name for part in preferred.lower().split('_')):
                score += 10
                break
    # 5. Task alignment (0-20 points)
    relevant_tasks = ["tabular-classification", "tabular-regression", "text-classification",
                      "time-series-forecasting", "question-answering"]
    if model_task in relevant_tasks:
        score += 15
    # 6. Complexity alignment (0-15 points)
    complexity_indicators = {
        "beginner": ["base", "simple", "basic", "distil", "small", "mini"],
        "intermediate": ["medium", "standard", "v2", "improved"],
        "advanced": ["large", "xl", "xxl", "advanced", "complex", "ensemble"]
    }
    if complexity in complexity_indicators:
        for indicator in complexity_indicators[complexity]:
            if indicator in model_name:
                score += 12
                break
    # 7. Production readiness indicators (0-10 points)
    production_terms = ["production", "optimized", "efficient", "fast", "deployment"]
    for term in production_terms:
        if term in model_name:
            score += 8
            break
    # 8. Penalties for problematic models (negative points)
    penalty_terms = ["nsfw", "adult", "sexual", "violence", "toxic", "unsafe", "experimental"]
    for term in penalty_terms:
        if term in model_name:
            score -= 30
    # Generic model penalty
    generic_terms = ["general", "random", "test", "example", "demo"]
    for term in generic_terms:
        if term in model_name:
            score -= 10
    # 9. Model quality indicators (0-10 points)
    quality_terms = ["sota", "benchmark", "award", "winner", "best", "top"]
    for term in quality_terms:
        if term in model_name or term in str(model.get("tags", [])).lower():
            score += 8
            break
    # 10. Recency bonus (0-5 points) - prefer newer models
    # This would require model creation date, approximating with model name patterns
    recent_indicators = ["2024", "2023", "v3", "v4", "v5", "latest", "new"]
    for indicator in recent_indicators:
        if indicator in model_name:
            score += 3
            break
    return max(score, 0)
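# Note: the per-section maxima in the comments above are approximate; with several keyword matches the raw
# score can exceed 100. Only the relative ordering matters, since callers sort candidates by this score.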
def is_fine_tuned_model(model: Dict, keywords: List[str]) -> bool:
    """Determine if a model is fine-tuned/specialized for the specific task"""
    model_name = model.get("id", "").lower()
    # Models with specific task keywords in name are likely fine-tuned
    for keyword in keywords:
        if keyword in model_name:
            return True
    # Models with specific fine-tuning indicators
    fine_tuned_indicators = [
        "fine-tuned", "ft", "finetuned", "specialized", "custom",
        "churn", "fraud", "sentiment", "classification", "detection",
        "prediction", "analytics", "recommendation", "recommender"
    ]
    for indicator in fine_tuned_indicators:
        if indicator in model_name:
            return True
    # Models from specific companies/domains (often specialized)
    domain_indicators = ["customer", "business", "financial", "ecommerce", "retail"]
    for domain in domain_indicators:
        if domain in model_name:
            return True
    return False
def is_general_suitable_model(model: Dict, primary_task: str) -> bool:
    """Determine if a model is a general foundation model suitable for the task"""
    model_name = model.get("id", "").lower()
    model_task = model.get("pipeline_tag", "").lower()
    # Check if model task matches primary task
    if model_task == primary_task:
        return True
    # General foundation models (base models good for fine-tuning)
    foundation_indicators = [
        "base", "large", "xlarge", "bert", "roberta", "distilbert",
        "electra", "albert", "transformer", "gpt", "t5", "bart",
        "deberta", "xlnet", "longformer"
    ]
    for indicator in foundation_indicators:
        if indicator in model_name and not any(x in model_name for x in ["nsfw", "safety", "toxicity"]):
            return True
    return False
async def get_foundation_models(primary_task: str, keywords: List[str], api_key: str = None) -> List[Dict]:
    """Get well-known foundation models suitable for the task"""
    foundation_searches = []
    if primary_task in ["text-classification", "token-classification"]:
        foundation_searches = [
            "bert-base-uncased", "roberta-base", "distilbert-base-uncased",
            "microsoft/deberta-v3-base", "google/electra-base-discriminator"
        ]
    elif primary_task in ["tabular-classification", "tabular-regression"]:
        foundation_searches = [
            "scikit-learn", "xgboost", "lightgbm", "catboost", "pytorch-tabular"
        ]
    elif primary_task in ["text-generation", "conversational"]:
        foundation_searches = [
            "gpt2", "microsoft/DialoGPT-medium", "facebook/blenderbot"
        ]
    elif primary_task in ["question-answering"]:
        foundation_searches = [
            "bert-base-uncased", "distilbert-base-uncased", "roberta-base"
        ]
    models = []
    for search_term in foundation_searches[:3]:  # Top 3 foundation searches
        try:
            headers = {}
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"
            response = requests.get(
                f"https://huggingface.co/api/models?search={search_term}&sort=downloads&limit=3",
                headers=headers,
                timeout=10
            )
            if response.status_code == 200:
                models.extend(response.json())
        except Exception:
            continue
    return models[:3]  # Return top 3
def get_model_specialty(model: Dict, keywords: List[str]) -> str:
    """Get human-readable specialty description for a model"""
    model_name = model.get("id", "").lower()
    # Map keywords to descriptions
    specialty_map = {
        "churn": "customer churn prediction",
        "fraud": "fraud detection",
        "sentiment": "sentiment analysis",
        "recommendation": "recommendation systems",
        "classification": "classification tasks",
        "detection": "detection tasks",
        "prediction": "prediction tasks"
    }
    for keyword in keywords:
        if keyword in specialty_map:
            return specialty_map[keyword]
    # Fallback: try to infer from model name
    if "churn" in model_name:
        return "customer churn prediction"
    elif "fraud" in model_name:
        return "fraud detection"
    elif "sentiment" in model_name:
        return "sentiment analysis"
    elif "recommend" in model_name:
        return "recommendation systems"
    else:
        return "specialized ML tasks"
# Serve static files
app.mount("/static", StaticFiles(directory="static"), name="static")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860) # HF Spaces uses port 7860