from fastapi import FastAPI, HTTPException, Header
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import chromadb
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from huggingface_hub import login
import requests
import json
from typing import List, Dict, Any
import os
import sys
import torch
import tarfile
app = FastAPI(title="ML Use Cases RAG System")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global buffer for log messages destined for the UI
current_logs = []

def log_to_ui(message):
    """Add a log message that will be sent to the UI."""
    current_logs.append(message)
    print(message)  # Still print to console
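# How these logs reach the UI (per the code below): get_logs() returns a copy
# of this buffer and then clears it, so polling is a destructive read, and
# chat_with_rag() additionally includes a snapshot of the buffer in its
# response body.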
# Initialize the embedding model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# BYOK: no server-side API key initialization.
# All model access is done via user-provided API keys.
print("🔑 BYOK Mode: No server-side API key configured")
print("✅ Users will provide their own HuggingFace API keys")

generator = None
llm_available = False
# Auto-extract ChromaDB if the archive exists and the directory is missing/empty
def setup_chromadb():
    """Set up ChromaDB by extracting the archive if needed."""
    if os.path.exists("chroma_db_complete.tar.gz"):
        # Check whether the chroma_db directory exists and has content
        needs_extraction = False
        if not os.path.exists("chroma_db"):
            print("📦 ChromaDB directory not found, extracting archive...")
            needs_extraction = True
        else:
            # Check if the directory is empty or missing key files
            try:
                if not os.path.exists("chroma_db/chroma.sqlite3"):
                    print("📦 ChromaDB missing database file, extracting archive...")
                    needs_extraction = True
                else:
                    # Quick check: try to list collections
                    temp_client = chromadb.PersistentClient(path="./chroma_db")
                    collections = temp_client.list_collections()
                    if len(collections) == 0:
                        print("📦 ChromaDB has no collections, extracting archive...")
                        needs_extraction = True
                    else:
                        print(f"✅ ChromaDB already set up with {len(collections)} collections")
            except Exception as e:
                print(f"📦 ChromaDB check failed ({e}), extracting archive...")
                needs_extraction = True

        if needs_extraction:
            try:
                print("🔧 Extracting ChromaDB archive...")
                with tarfile.open("chroma_db_complete.tar.gz", "r:gz") as tar:
                    tar.extractall()
                print("✅ ChromaDB extracted successfully")
                # Verify the extraction
                if os.path.exists("chroma_db/chroma.sqlite3"):
                    print("✅ Database file found after extraction")
                else:
                    print("❌ Database file missing after extraction")
            except Exception as e:
                print(f"❌ Failed to extract ChromaDB: {e}")
    else:
        print("📋 No ChromaDB archive found, using existing directory")
# Set up ChromaDB before initializing the client
setup_chromadb()

# Initialize ChromaDB
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = None
class ChatRequest(BaseModel):
    query: str

class ApiKeyRequest(BaseModel):
    api_key: str

class SearchResult(BaseModel):
    company: str
    industry: str
    year: int
    description: str
    summary: str
    similarity_score: float
    url: str

class RecommendedModels(BaseModel):
    fine_tuned: List[Dict[str, Any]]
    general: List[Dict[str, Any]]

class ChatResponse(BaseModel):
    solution_approach: str
    company_examples: List[SearchResult]
    recommended_models: RecommendedModels
# NOTE: route decorators were absent from this listing; the paths used on the
# endpoints below are assumed reconstructions, inferred from each docstring.
@app.get("/health")  # assumed route
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}
@app.get("/test-token/{token}")  # assumed route
async def test_token_direct(token: str):
    """Direct token test endpoint"""
    print(f"🧪 Testing token: {token[:10]}...")
    try:
        # Test with the models API
        response = requests.get(
            "https://huggingface.co/api/models?limit=1",
            headers={"Authorization": f"Bearer {token}"},
            timeout=10
        )
        print(f"📊 Models API Status: {response.status_code}")
        if response.status_code == 200:
            return {"valid": True, "method": "models_api", "status": response.status_code}

        # Test whoami
        response2 = requests.get(
            "https://huggingface.co/api/whoami",
            headers={"Authorization": f"Bearer {token}"},
            timeout=10
        )
        print(f"📊 WhoAmI Status: {response2.status_code}")
        return {
            "valid": response2.status_code == 200,
            "models_status": response.status_code,
            "whoami_status": response2.status_code,
            "whoami_response": response2.text[:200] if response2.status_code != 200 else "OK"
        }
    except Exception as e:
        return {"error": str(e)}
@app.post("/validate-key")  # assumed route
async def validate_api_key(request: ApiKeyRequest):
    """Validate the user's HuggingFace API key (format check only)."""
    api_key = request.api_key.strip()
    print(f"🔑 Validating API key: {api_key[:10]}...")

    if not api_key or not api_key.startswith('hf_'):
        print(f"❌ Invalid format: {api_key[:10] if api_key else 'empty'}")
        return {"valid": False, "error": "Invalid API key format. Must start with 'hf_'"}

    # Simple format validation: if it looks like a valid HF token, accept it
    if len(api_key) >= 30 and api_key.startswith('hf_') and all(c.isalnum() or c == '_' for c in api_key):
        print("✅ API key format is valid, accepting")
        return {"valid": True, "user": "User"}

    print("❌ Invalid token format or length")
    return {"valid": False, "error": "Invalid API key format"}
@app.get("/logs")  # assumed route
async def get_logs():
    """Return and clear the buffered log messages for the UI."""
    try:
        logs_copy = current_logs.copy()
        current_logs.clear()
        return {"logs": logs_copy}
    except Exception as e:
        return {"logs": [], "error": str(e)}
@app.get("/test-logs")  # assumed route
async def test_logs():
    """Test endpoint to verify logging works"""
    log_to_ui("🧪 Test log message 1")
    log_to_ui("🧪 Test log message 2")
    log_to_ui("🧪 Test log message 3")
    return {"message": "Test logs added"}
def initialize_collection():
    """Initialize the ChromaDB collection with debug logging."""
    global collection

    # Debug: check the file system
    print(f"🔍 Current working directory: {os.getcwd()}")
    print(f"🔍 ChromaDB path exists: {os.path.exists('./chroma_db')}")

    if os.path.exists('./chroma_db'):
        try:
            chroma_files = os.listdir('./chroma_db')
            print(f"🔍 ChromaDB directory contents: {chroma_files}")

            # Check for the main database file
            if 'chroma.sqlite3' in chroma_files:
                print("✅ Found chroma.sqlite3")
            else:
                print("❌ chroma.sqlite3 NOT found")

            # Check for UUID-named collection directories
            uuid_dirs = [f for f in chroma_files if len(f) == 36 and '-' in f]  # UUID format
            if uuid_dirs:
                print(f"✅ Found UUID directories: {uuid_dirs}")
                for uuid_dir in uuid_dirs:
                    uuid_path = os.path.join('./chroma_db', uuid_dir)
                    if os.path.isdir(uuid_path):
                        uuid_files = os.listdir(uuid_path)
                        print(f"🔍 {uuid_dir} contents: {uuid_files}")
            else:
                print("❌ No UUID directories found")
        except Exception as e:
            print(f"❌ Error reading chroma_db directory: {e}")
    else:
        print("❌ chroma_db directory does not exist")

    # Debug: try to initialize the ChromaDB client
    try:
        print("🔍 Attempting to initialize ChromaDB client...")
        print(f"🔍 ChromaDB version: {chromadb.__version__}")

        # List all collections
        collections = chroma_client.list_collections()
        print(f"🔍 Available collections: {[c.name for c in collections]}")

        # Try to get the specific collection
        collection = chroma_client.get_collection("ml_use_cases")
        collection_count = collection.count()
        print(f"✅ Found existing collection 'ml_use_cases' with {collection_count} documents")
    except Exception as e:
        print(f"❌ Collection initialization error: {type(e).__name__}: {e}")
        print("📝 Will attempt to create collection during first use")
        collection = None

# Initialize the collection on import
initialize_collection()
@app.get("/", response_class=HTMLResponse)  # assumed route
async def root():
    """Serve the main frontend"""
    with open("static/index.html", "r") as f:
        return HTMLResponse(f.read())
async def search_use_cases_internal(request: ChatRequest):
    """Internal search function with detailed logging."""
    log_to_ui(f"🔍 Search request received: '{request.query}'")

    if not collection:
        log_to_ui("❌ ChromaDB collection not initialized")
        raise HTTPException(status_code=500, detail="Database not initialized")

    query = request.query.lower()
    log_to_ui(f"📝 Normalized query: '{query}'")

    # Generate the query embedding for semantic search
    log_to_ui("🧠 Generating query embedding...")
    query_embedding = embedding_model.encode([request.query]).tolist()[0]
    log_to_ui(f"✅ Embedding generated, dimension: {len(query_embedding)}")

    # Semantic search
    log_to_ui("🔎 Performing semantic search...")
    semantic_results = collection.query(
        query_embeddings=[query_embedding],
        n_results=15,
        include=['metadatas', 'documents', 'distances']
    )
    log_to_ui(f"📊 Semantic search found {len(semantic_results['ids'][0])} results")

    # Keyword-based search using ChromaDB's text query as a second signal
    keyword_results = None
    try:
        log_to_ui("🔤 Performing keyword search...")
        keyword_results = collection.query(
            query_texts=[request.query],
            n_results=10,
            include=['metadatas', 'documents', 'distances']
        )
        log_to_ui(f"📝 Keyword search found {len(keyword_results['ids'][0])} results")
    except Exception as e:
        log_to_ui(f"⚠️ Keyword search failed: {e}")

    # Combine and rank results
    combined_results = {}

    # Process semantic results
    for i in range(len(semantic_results['ids'][0])):
        doc_id = semantic_results['ids'][0][i]
        metadata = semantic_results['metadatas'][0][i]
        similarity_score = 1 - semantic_results['distances'][0][i]

        # Boost the score for keyword matches in metadata
        boost = 0
        query_words = query.split()
        for word in query_words:
            if word in metadata.get('title', '').lower():
                boost += 0.3
            if word in metadata.get('description', '').lower():
                boost += 0.2
            if word in metadata.get('keywords', '').lower():
                boost += 0.4
            if word in metadata.get('industry', '').lower():
                boost += 0.1

        final_score = min(similarity_score + boost, 1.0)
        combined_results[doc_id] = {
            'metadata': metadata,
            'summary': semantic_results['documents'][0][i],
            'score': final_score,
            'source': 'semantic'
        }

    # Process keyword results if available
    if keyword_results:
        for i in range(len(keyword_results['ids'][0])):
            doc_id = keyword_results['ids'][0][i]
            if doc_id not in combined_results:
                metadata = keyword_results['metadatas'][0][i]
                similarity_score = 1 - keyword_results['distances'][0][i]
                combined_results[doc_id] = {
                    'metadata': metadata,
                    'summary': keyword_results['documents'][0][i],
                    'score': similarity_score + 0.1,  # Small boost for keyword matches
                    'source': 'keyword'
                }

    # Sort by score and take the top results
    sorted_results = sorted(combined_results.values(), key=lambda x: x['score'], reverse=True)[:10]
    log_to_ui(f"🎯 Combined and ranked results: {len(sorted_results)} final results")

    search_results = []
    for i, result in enumerate(sorted_results):
        metadata = result['metadata']
        search_results.append(SearchResult(
            company=metadata.get('company', ''),
            industry=metadata.get('industry', ''),
            year=metadata.get('year', 2023),
            description=metadata.get('description', ''),
            summary=result['summary'],
            similarity_score=result['score'],
            url=metadata.get('url', '')
        ))
        log_to_ui(f"  {i+1}. {metadata.get('company', 'Unknown')} - Score: {result['score']:.3f}")

    log_to_ui(f"✅ Search completed, returning {len(search_results)} results")
    return search_results
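# Scoring recap for the hybrid search above: semantic hits are scored as
#   min((1 - distance) + metadata_keyword_boosts, 1.0)
# while documents found only by the keyword query get (1 - distance) + 0.1,
# so a result matched by both signals keeps its (boosted) semantic score.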
@app.post("/search")  # assumed route
async def search_use_cases(request: ChatRequest):
    """Public search endpoint"""
    results = await search_use_cases_internal(request)
    return {"results": results}
async def generate_response_with_user_key(prompt: str, api_key: str, max_length: int = 500) -> str:
    """Generate a response using the user's HuggingFace API key via the Inference API."""
    try:
        # Use the HuggingFace Inference API with the user's key
        api_url = "https://api-inference.huggingface.co/models/google/gemma-2-2b-it"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_length,
                "temperature": 0.7,
                "do_sample": True,
                "return_full_text": False
            }
        }

        response = requests.post(api_url, headers=headers, json=payload, timeout=30)

        if response.status_code == 200:
            result = response.json()
            if isinstance(result, list) and len(result) > 0:
                generated_text = result[0].get('generated_text', '')
                return generated_text.strip()
            else:
                return "Unable to generate response. Please try again."
        elif response.status_code == 503:
            # Model is loading; try a fallback
            return await try_fallback_model(prompt, api_key, max_length)
        else:
            raise Exception(f"API request failed with status {response.status_code}")
    except Exception as e:
        print(f"Error generating response with user API key: {e}")
        return generate_template_response(prompt)
async def try_fallback_model(prompt: str, api_key: str, max_length: int = 500) -> str:
    """Try fallback models when the primary model is unavailable."""
    try:
        # Try more readily available models as fallbacks
        fallback_models = [
            "microsoft/DialoGPT-medium",
            "microsoft/DialoGPT-small",
            "gpt2"
        ]
        for model_name in fallback_models:
            try:
                api_url = f"https://api-inference.huggingface.co/models/{model_name}"
                headers = {
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                }
                payload = {
                    "inputs": prompt,
                    "parameters": {
                        "max_new_tokens": max_length,
                        "temperature": 0.7,
                        "do_sample": True,
                        "return_full_text": False
                    }
                }
                response = requests.post(api_url, headers=headers, json=payload, timeout=20)
                if response.status_code == 200:
                    result = response.json()
                    if isinstance(result, list) and len(result) > 0:
                        generated_text = result[0].get('generated_text', '')
                        return generated_text.strip()
            except Exception:
                continue

        # If all fallback models fail, return the template response
        return generate_template_response(prompt)
    except Exception:
        return generate_template_response(prompt)
def generate_template_response(prompt: str) -> str:
    """Generate a template response when AI models are not available."""
    return """Based on the analysis of similar ML/AI implementations from companies in our database, here are key recommendations for your problem:

**Technical Approach:**
- Consider machine learning classification or prediction models
- Leverage data preprocessing and feature engineering
- Implement proper model validation and testing

**Implementation Strategy:**
- Start with a minimum viable model using existing data
- Iterate based on performance metrics
- Consider scalability and real-time requirements

**Key Considerations:**
- Data quality and availability
- Business metrics alignment
- Technical infrastructure requirements

This analysis is based on patterns from 400+ real-world ML implementations across various industries."""
@app.post("/chat")  # assumed route
async def chat_with_rag(request: ChatRequest, x_hf_api_key: str = Header(None, alias="X-HF-API-Key")):
    """Main RAG endpoint, driven by the user's API key."""
    # Validate the user's API key
    if not x_hf_api_key or not x_hf_api_key.startswith('hf_'):
        raise HTTPException(status_code=400, detail="Valid HuggingFace API key required")

    # Clear previous logs and start fresh
    current_logs.clear()
    log_to_ui(f"🤖 Chat request received: '{request.query}'")

    # First, search for relevant use cases
    log_to_ui("🔍 Getting relevant use cases...")
    relevant_cases = await search_use_cases_internal(request)
    top_cases = relevant_cases[:5]  # Top 5 results
    log_to_ui(f"📚 Using top {len(top_cases)} cases for context")

    # Prepare context for the LLM
    log_to_ui("📝 Preparing context for LLM...")
    context = "Here are relevant real-world ML/AI implementations:\n\n"
    for i, case in enumerate(top_cases, 1):
        context += f"Company: {case.company} ({case.industry}, {case.year})\n"
        context += f"Description: {case.description}\n"
        context += f"Implementation: {case.summary[:500]}...\n\n"
        log_to_ui(f"  {i}. {case.company} - {case.description}")
    log_to_ui(f"📊 Context length: {len(context)} characters")

    # Create the prompt for the language model
    prompt = f"""Based on the following real ML/AI implementations from companies, provide advice for this business problem:

{context}

User Problem: {request.query}

Please provide a comprehensive solution approach that considers what has worked for these companies. Focus on:
1. What type of ML/AI solution would address this problem
2. Key technical approaches that have proven successful
3. Implementation considerations based on the examples

Be specific and reference the examples when relevant.

Response:"""
    log_to_ui(f"💭 Full prompt length: {len(prompt)} characters")

    # Generate a response using the user's HuggingFace API key
    log_to_ui("🚀 Generating AI response with user API key...")
    try:
        llm_response = await generate_response_with_user_key(prompt, x_hf_api_key, max_length=400)
        log_to_ui(f"✅ AI response generated, length: {len(llm_response)} characters")
    except Exception as e:
        llm_response = f"Error generating AI response: {str(e)}"
        log_to_ui(f"❌ AI response error: {e}")

    # Get HuggingFace model recommendations using the user's API key
    log_to_ui("🤗 Getting HuggingFace model recommendations...")
    recommended_models = await get_huggingface_models(request.query, top_cases, x_hf_api_key)
    total_models = len(recommended_models.get("fine_tuned", [])) + len(recommended_models.get("general", []))
    log_to_ui(f"🏷️ Found {total_models} recommended models")

    log_to_ui("✅ Chat response complete!")

    # Return the response with logs included
    return {
        "solution_approach": llm_response,
        "company_examples": [
            {
                "company": case.company,
                "industry": case.industry,
                "year": case.year,
                "description": case.description,
                "summary": case.summary,
                "similarity_score": case.similarity_score,
                "url": case.url
            }
            for case in top_cases
        ],
        "recommended_models": {
            "fine_tuned": recommended_models.get("fine_tuned", []),
            "general": recommended_models.get("general", [])
        },
        "logs": current_logs.copy()  # Include all logs in the response
    }
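# Example request against the endpoint above (a sketch; "/chat" is the assumed
# route from the decorator reconstruction, not confirmed by this listing):
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -H "X-HF-API-Key: hf_..." \
#        -d '{"query": "How can we predict customer churn?"}'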
async def get_huggingface_models(query: str, relevant_cases: List = None, api_key: str = None) -> Dict[str, List[Dict[str, Any]]]:
    """Get relevant ML models from HuggingFace based on the query and similar use cases."""
    log_to_ui(f"🔍 Analyzing query for ML task mapping: '{query}'")
    try:
        # Enhanced multi-dimensional classification system
        business_domains = {
            # Financial Services
            "finance": ["fraud detection", "risk assessment", "algorithmic trading", "credit scoring"],
            "banking": ["fraud detection", "credit scoring", "customer segmentation", "loan approval"],
            "fintech": ["payment processing", "robo-advisor", "fraud detection", "credit scoring"],
            "insurance": ["risk assessment", "claim processing", "fraud detection", "pricing optimization"],
            # E-commerce & Retail
            "ecommerce": ["recommendation systems", "demand forecasting", "price optimization", "customer segmentation"],
            "retail": ["inventory management", "demand forecasting", "customer analytics", "supply chain"],
            "marketplace": ["search ranking", "recommendation systems", "fraud detection", "seller analytics"],
            # Healthcare & Life Sciences
            "healthcare": ["medical imaging", "drug discovery", "patient risk prediction", "clinical decision support"],
            "medical": ["diagnostic imaging", "treatment optimization", "patient monitoring", "clinical trials"],
            "pharma": ["drug discovery", "clinical trials", "adverse event detection", "molecular analysis"],
            # Technology & Media
            "tech": ["user behavior analysis", "system optimization", "content moderation", "search ranking"],
            "media": ["content recommendation", "audience analytics", "content generation", "sentiment analysis"],
            "gaming": ["player behavior prediction", "game optimization", "content generation", "cheat detection"],
            # Marketing & Advertising
            "marketing": ["customer segmentation", "campaign optimization", "lead scoring", "attribution modeling"],
            "advertising": ["ad targeting", "bid optimization", "creative optimization", "audience analytics"],
            "social": ["sentiment analysis", "trend prediction", "content moderation", "influence analysis"]
        }
        problem_types = {
            # Customer Analytics
            "churn": {
                "domain": "customer_analytics",
                "task_type": "binary_classification",
                "data_types": ["tabular", "behavioral"],
                "complexity": "intermediate",
                "models": ["xgboost", "lightgbm", "catboost", "random_forest"],
                "hf_tasks": ["tabular-classification"],
                "keywords": ["retention", "attrition", "leave", "cancel", "subscription"]
            },
            "segmentation": {
                "domain": "customer_analytics",
                "task_type": "clustering",
                "data_types": ["tabular", "behavioral"],
                "complexity": "intermediate",
                "models": ["kmeans", "dbscan", "hierarchical", "gaussian_mixture"],
                "hf_tasks": ["tabular-classification"],
                "keywords": ["segment", "group", "persona", "cluster", "behavior"]
            },
            # Risk & Fraud
            "fraud": {
                "domain": "risk_management",
                "task_type": "anomaly_detection",
                "data_types": ["tabular", "graph", "time_series"],
                "complexity": "advanced",
                "models": ["isolation_forest", "autoencoder", "one_class_svm", "gnn"],
                "hf_tasks": ["tabular-classification"],
                "keywords": ["suspicious", "anomaly", "unusual", "scam", "fake"]
            },
            "risk": {
                "domain": "risk_management",
                "task_type": "regression",
                "data_types": ["tabular", "time_series"],
                "complexity": "advanced",
                "models": ["ensemble", "deep_learning", "survival_analysis"],
                "hf_tasks": ["tabular-regression"],
                "keywords": ["probability", "likelihood", "exposure", "default", "loss"]
            },
            # Demand & Forecasting
            "forecast": {
                "domain": "demand_planning",
                "task_type": "time_series_forecasting",
                "data_types": ["time_series", "tabular"],
                "complexity": "advanced",
                "models": ["prophet", "lstm", "transformer", "arima"],
                "hf_tasks": ["time-series-forecasting"],
                "keywords": ["predict", "future", "trend", "seasonal", "demand", "sales"]
            },
            "demand": {
                "domain": "demand_planning",
                "task_type": "regression",
                "data_types": ["time_series", "tabular"],
                "complexity": "intermediate",
                "models": ["xgboost", "lstm", "prophet"],
                "hf_tasks": ["tabular-regression", "time-series-forecasting"],
                "keywords": ["inventory", "supply", "planning", "optimization"]
            },
            # Content & NLP
            "sentiment": {
                "domain": "nlp",
                "task_type": "text_classification",
                "data_types": ["text"],
                "complexity": "beginner",
                "models": ["bert", "roberta", "distilbert"],
                "hf_tasks": ["text-classification"],
                "keywords": ["opinion", "emotion", "feeling", "review", "feedback"]
            },
            "recommendation": {
                "domain": "personalization",
                "task_type": "recommendation",
                "data_types": ["tabular", "behavioral", "content"],
                "complexity": "advanced",
                "models": ["collaborative_filtering", "content_based", "deep_learning"],
                "hf_tasks": ["tabular-regression"],
                "keywords": ["suggest", "personalize", "similar", "like", "prefer"]
            },
            # Pricing & Optimization
            "pricing": {
                "domain": "revenue_optimization",
                "task_type": "regression",
                "data_types": ["tabular", "time_series"],
                "complexity": "advanced",
                "models": ["ensemble", "reinforcement_learning", "optimization"],
                "hf_tasks": ["tabular-regression"],
                "keywords": ["price", "cost", "revenue", "profit", "optimize"]
            }
        }
        # Advanced query analysis
        def analyze_query_intent(query_text, cases=None):
            """Analyze the query to extract business domain, problem type, and complexity."""
            query_lower = query_text.lower()

            # Extract the business domain
            detected_domain = None
            domain_confidence = 0
            for domain, use_cases in business_domains.items():
                if domain in query_lower:
                    detected_domain = domain
                    domain_confidence = 0.9
                    break
                # Check use-case matches
                for use_case in use_cases:
                    if use_case.lower() in query_lower:
                        detected_domain = domain
                        domain_confidence = 0.7
                        break
                if detected_domain:
                    break

            # Extract the problem type with scoring
            problem_matches = []
            for problem_name, problem_info in problem_types.items():
                score = 0
                # Direct problem name match
                if problem_name in query_lower:
                    score += 50
                # Keyword matches
                for keyword in problem_info["keywords"]:
                    if keyword in query_lower:
                        score += 10
                # Context from relevant cases
                if cases:
                    case_text = " ".join([f"{case.description} {case.summary[:300]}" for case in cases]).lower()
                    if problem_name in case_text:
                        score += 20
                    for keyword in problem_info["keywords"]:
                        if keyword in case_text:
                            score += 5
                if score > 0:
                    problem_matches.append((problem_name, problem_info, score))

            # Sort by score and keep the best matches
            problem_matches.sort(key=lambda x: x[2], reverse=True)
            return detected_domain, problem_matches[:3], domain_confidence

        # Analyze the query
        query_lower = query.lower()
        detected_domain, top_problems, domain_confidence = analyze_query_intent(query, relevant_cases)

        # Determine the primary task and approach
        if top_problems:
            primary_problem = top_problems[0]
            problem_info = primary_problem[1]
            primary_task = problem_info["hf_tasks"][0] if problem_info["hf_tasks"] else "tabular-classification"
            complexity = problem_info["complexity"]
            preferred_models = problem_info["models"]
            log_to_ui(f"🎯 Detected problem: '{primary_problem[0]}' (score: {primary_problem[2]})")
            log_to_ui(f"📊 Domain: {detected_domain or 'general'} | Complexity: {complexity}")
            log_to_ui(f"🔧 Preferred models: {', '.join(preferred_models[:3])}")
        else:
            # No problem type detected: fall back to generic defaults
            primary_task = "tabular-classification"
            complexity = "intermediate"
            preferred_models = ["xgboost", "lightgbm"]
            log_to_ui(f"📊 Using fallback classification | Task: {primary_task}")

        matched_keywords = [p[0] for p in top_problems]
        log_to_ui(f"📊 Primary task: '{primary_task}' | Keywords: {matched_keywords}")
        # Search for models with multiple strategies
        all_models = []

        # Strategy 1: search by the primary task
        models_primary = await search_models_by_task(primary_task, api_key)
        all_models.extend(models_primary)

        # Strategy 2: search by specific keywords for better matches
        if matched_keywords:
            for keyword in matched_keywords[:2]:  # Top 2 keywords
                keyword_models = await search_models_by_keyword(keyword, api_key)
                all_models.extend(keyword_models)

        # Strategy 3: search for domain-specific models
        domain_searches = []
        if "churn" in query_lower or "retention" in query_lower:
            domain_searches.append("customer-analytics")
        if "fraud" in query_lower:
            domain_searches.append("anomaly-detection")
        if "recommend" in query_lower:
            domain_searches.append("recommendation")
        for domain in domain_searches:
            domain_models = await search_models_by_keyword(domain, api_key)
            all_models.extend(domain_models)

        # Remove duplicates and rank by relevance
        seen_models = set()
        unique_models = []
        for model in all_models:
            model_id = model.get("id") or model.get("name")
            if model_id and model_id not in seen_models:
                seen_models.add(model_id)
                unique_models.append(model)

        # Score models against the enhanced relevance criteria
        # (preferred_models is always bound by the if/else above)
        scored_models = []
        for model in unique_models:
            score = calculate_model_relevance(
                model, query_lower, matched_keywords,
                complexity, preferred_models
            )
            scored_models.append((model, score))

        # Separate models into fine-tuned/specific vs. general base models
        fine_tuned_models = []
        general_models = []
        for model, score in scored_models:
            if is_fine_tuned_model(model, matched_keywords):
                fine_tuned_models.append((model, score))
            elif is_general_suitable_model(model, primary_task):
                general_models.append((model, score))

        # Sort and take the top 3 of each type
        fine_tuned_models.sort(key=lambda x: x[1], reverse=True)
        general_models.sort(key=lambda x: x[1], reverse=True)
        top_fine_tuned = [model for model, score in fine_tuned_models[:3]]
        top_general = [model for model, score in general_models[:3]]
        # Add curated high-quality models for specific use cases
        def get_curated_models(problem_type: str, complexity_level: str) -> List[Dict]:
            """Get curated high-quality models for specific use cases."""
            curated = {
                "churn": {
                    "beginner": [
                        {"id": "scikit-learn/RandomForestClassifier", "task": "tabular-classification"},
                        {"id": "xgboost/XGBClassifier", "task": "tabular-classification"}
                    ],
                    "intermediate": [
                        {"id": "microsoft/TabNet", "task": "tabular-classification"},
                        {"id": "AutoML/AutoGluon-Tabular", "task": "tabular-classification"}
                    ],
                    "advanced": [
                        {"id": "microsoft/LightGBM", "task": "tabular-classification"},
                        {"id": "dmlc/xgboost", "task": "tabular-classification"}
                    ]
                },
                "sentiment": {
                    "beginner": [
                        {"id": "cardiffnlp/twitter-roberta-base-sentiment-latest", "task": "text-classification"},
                        {"id": "distilbert-base-uncased-finetuned-sst-2-english", "task": "text-classification"}
                    ],
                    "intermediate": [
                        {"id": "nlptown/bert-base-multilingual-uncased-sentiment", "task": "text-classification"},
                        {"id": "microsoft/DialoGPT-medium", "task": "text-classification"}
                    ],
                    "advanced": [
                        {"id": "roberta-large-mnli", "task": "text-classification"},
                        {"id": "microsoft/deberta-v3-large", "task": "text-classification"}
                    ]
                },
                "fraud": {
                    "intermediate": [
                        {"id": "microsoft/TabNet", "task": "tabular-classification"},
                        {"id": "IsolationForest/AnomalyDetection", "task": "tabular-classification"}
                    ],
                    "advanced": [
                        {"id": "pyod/AutoEncoder", "task": "tabular-classification"},
                        {"id": "GraphNeuralNetworks/FraudDetection", "task": "tabular-classification"}
                    ]
                },
                "forecast": {
                    "intermediate": [
                        {"id": "facebook/prophet", "task": "time-series-forecasting"},
                        {"id": "statsmodels/ARIMA", "task": "time-series-forecasting"}
                    ],
                    "advanced": [
                        {"id": "microsoft/DeepAR", "task": "time-series-forecasting"},
                        {"id": "google/temporal-fusion-transformer", "task": "time-series-forecasting"}
                    ]
                }
            }

            # Get curated models for the specific problem and complexity
            if problem_type in curated and complexity_level in curated[problem_type]:
                return curated[problem_type][complexity_level]
            elif problem_type in curated:
                # Fall back to any complexity level that is available
                for level in ["beginner", "intermediate", "advanced"]:
                    if level in curated[problem_type]:
                        return curated[problem_type][level]
            return []

        # Add curated models
        if top_problems:
            primary_problem_name = top_problems[0][0]
            curated_models = get_curated_models(primary_problem_name, complexity)
            for curated_model in curated_models:
                if len(top_general) < 3:
                    # Format as a HuggingFace model dict
                    formatted_model = {
                        "id": curated_model["id"],
                        "pipeline_tag": curated_model["task"],
                        "downloads": 50000,  # Reasonable default
                        "tags": ["curated", "production-ready"]
                    }
                    top_general.append(formatted_model)

        # Add general foundation models if we still don't have enough
        if len(top_general) < 3:
            foundation_models = await get_foundation_models(primary_task, matched_keywords, api_key)
            top_general.extend(foundation_models[:3 - len(top_general)])
        # Format the response with categories
        model_response = {
            "fine_tuned": [],
            "general": []
        }

        # Enhanced model descriptions based on the detected problem type
        def get_enhanced_model_description(model: Dict, model_type: str, problem_context: str = None) -> str:
            """Generate context-aware model descriptions."""
            model_name = model.get("id", "").lower()
            if model_type == "fine-tuned":
                if problem_context == "churn":
                    return "Pre-trained model optimized for customer retention prediction"
                elif problem_context == "fraud":
                    return "Specialized anomaly detection model for fraud identification"
                elif problem_context == "sentiment":
                    return "Fine-tuned sentiment analysis model for text classification"
                elif problem_context == "forecast":
                    return "Time series forecasting model for demand prediction"
                else:
                    return f"Specialized model fine-tuned for {get_model_specialty(model, matched_keywords)}"
            else:  # general
                if "curated" in str(model.get("tags", [])):
                    return "Production-ready model recommended for business use cases"
                elif any(term in model_name for term in ["bert", "roberta", "distilbert"]):
                    return "Transformer-based foundation model for fine-tuning"
                elif any(term in model_name for term in ["xgboost", "lightgbm", "catboost"]):
                    return "Gradient boosting model excellent for tabular data"
                elif "prophet" in model_name:
                    return "Facebook's time series forecasting framework"
                else:
                    return f"Foundation model suitable for {primary_task.replace('-', ' ')}"

        # Format fine-tuned models with enhanced descriptions
        primary_problem_name = top_problems[0][0] if top_problems else None
        for model in top_fine_tuned:
            model_info = {
                "name": model.get("id", model.get("name", "Unknown")),
                "downloads": model.get("downloads", 0),
                "task": model.get("pipeline_tag", primary_task),
                "url": f"https://huggingface.co/{model.get('id', '')}",
                "type": "fine-tuned",
                "description": get_enhanced_model_description(model, "fine-tuned", primary_problem_name)
            }
            model_response["fine_tuned"].append(model_info)

        # Format general models with enhanced descriptions
        for model in top_general:
            model_info = {
                "name": model.get("id", model.get("name", "Unknown")),
                "downloads": model.get("downloads", 0),
                "task": model.get("pipeline_tag", primary_task),
                "url": f"https://huggingface.co/{model.get('id', '')}",
                "type": "general",
                "description": get_enhanced_model_description(model, "general", primary_problem_name)
            }
            model_response["general"].append(model_info)

        log_to_ui(f"📦 Found {len(model_response['fine_tuned'])} fine-tuned + {len(model_response['general'])} general models")

        # Log the details
        if model_response["fine_tuned"]:
            log_to_ui("🎯 Fine-tuned/Specialized models:")
            for i, model in enumerate(model_response["fine_tuned"], 1):
                log_to_ui(f"  {i}. {model['name']} - {model['downloads']:,} downloads")
        if model_response["general"]:
            log_to_ui("🔧 General/Foundation models:")
            for i, model in enumerate(model_response["general"], 1):
                log_to_ui(f"  {i}. {model['name']} - {model['downloads']:,} downloads")

        return model_response
    except Exception as e:
        log_to_ui(f"❌ Error fetching HuggingFace models: {e}")
        return {"fine_tuned": [], "general": []}
async def search_models_by_task(task: str, api_key: str = None) -> List[Dict]:
    """Search HuggingFace Hub models by pipeline task."""
    try:
        headers = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        response = requests.get(
            f"https://huggingface.co/api/models?pipeline_tag={task}&sort=downloads&limit=10",
            headers=headers,
            timeout=10
        )
        if response.status_code == 200:
            return response.json()
    except Exception:
        pass
    return []
async def search_models_by_keyword(keyword: str, api_key: str = None) -> List[Dict]:
    """Search HuggingFace Hub models by keyword in name/description."""
    try:
        headers = {}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        response = requests.get(
            f"https://huggingface.co/api/models?search={keyword}&sort=downloads&limit=10",
            headers=headers,
            timeout=10
        )
        if response.status_code == 200:
            return response.json()
    except Exception:
        pass
    return []
def calculate_model_relevance(model: Dict, query: str, keywords: List[str],
                              complexity: str = "intermediate", preferred_models: List[str] = None) -> float:
    """Enhanced multi-criteria model relevance scoring."""
    score = 0
    model_name = model.get("id", "").lower()
    model_task = model.get("pipeline_tag", "").lower()
    downloads = model.get("downloads", 0)

    # 1. Base popularity score (0-15 points)
    if downloads > 10000000:  # 10M+
        score += 15
    elif downloads > 1000000:  # 1M+
        score += 12
    elif downloads > 100000:  # 100K+
        score += 8
    elif downloads > 10000:  # 10K+
        score += 5
    elif downloads > 1000:  # 1K+
        score += 2

    # 2. Direct keyword relevance (0-30 points)
    for keyword in keywords:
        if keyword in model_name:
            score += 25
        # Check in model description/tags if available
        model_tags = model.get("tags", [])
        if any(keyword in str(tag).lower() for tag in model_tags):
            score += 15

    # 3. Query term matches (0-20 points)
    query_words = [w for w in query.lower().split() if len(w) > 3]
    for word in query_words:
        if word in model_name:
            score += 8
        if word in str(model.get("tags", [])).lower():
            score += 5

    # 4. Preferred model architecture bonus (0-25 points)
    if preferred_models:
        for preferred in preferred_models:
            if preferred.lower() in model_name:
                score += 20
                break
        # Partial matches
        for preferred in preferred_models:
            if any(part in model_name for part in preferred.lower().split('_')):
                score += 10
                break

    # 5. Task alignment (0-20 points)
    relevant_tasks = ["tabular-classification", "tabular-regression", "text-classification",
                      "time-series-forecasting", "question-answering"]
    if model_task in relevant_tasks:
        score += 15

    # 6. Complexity alignment (0-15 points)
    complexity_indicators = {
        "beginner": ["base", "simple", "basic", "distil", "small", "mini"],
        "intermediate": ["medium", "standard", "v2", "improved"],
        "advanced": ["large", "xl", "xxl", "advanced", "complex", "ensemble"]
    }
    if complexity in complexity_indicators:
        for indicator in complexity_indicators[complexity]:
            if indicator in model_name:
                score += 12
                break

    # 7. Production readiness indicators (0-10 points)
    production_terms = ["production", "optimized", "efficient", "fast", "deployment"]
    for term in production_terms:
        if term in model_name:
            score += 8
            break

    # 8. Penalties for problematic models (negative points)
    penalty_terms = ["nsfw", "adult", "sexual", "violence", "toxic", "unsafe", "experimental"]
    for term in penalty_terms:
        if term in model_name:
            score -= 30

    # Generic model penalty
    generic_terms = ["general", "random", "test", "example", "demo"]
    for term in generic_terms:
        if term in model_name:
            score -= 10

    # 9. Model quality indicators (0-10 points)
    quality_terms = ["sota", "benchmark", "award", "winner", "best", "top"]
    for term in quality_terms:
        if term in model_name or term in str(model.get("tags", [])).lower():
            score += 8
            break

    # 10. Recency bonus (0-5 points): prefer newer models. Creation dates are
    # not available here, so approximate with version/year patterns in the name.
    recent_indicators = ["2024", "2023", "v3", "v4", "v5", "latest", "new"]
    for indicator in recent_indicators:
        if indicator in model_name:
            score += 3
            break

    return max(score, 0)
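# Worked example for the scorer above (a hypothetical model, purely to
# illustrate the arithmetic; ignoring query-term overlap): id
# "acme/churn-xgboost-large", 2M downloads, pipeline_tag
# "tabular-classification", keywords=["churn"], complexity="advanced",
# preferred_models=["xgboost"] scores 12 (downloads) + 25 (keyword in name)
# + 20 + 10 (exact and partial architecture matches; note both loops run)
# + 15 (task alignment) + 12 ("large" complexity indicator) = 94.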
def is_fine_tuned_model(model: Dict, keywords: List[str]) -> bool:
    """Determine whether a model is fine-tuned/specialized for the specific task."""
    model_name = model.get("id", "").lower()

    # Models with task keywords in the name are likely fine-tuned
    for keyword in keywords:
        if keyword in model_name:
            return True

    # Models with explicit fine-tuning indicators
    fine_tuned_indicators = [
        "fine-tuned", "ft", "finetuned", "specialized", "custom",
        "churn", "fraud", "sentiment", "classification", "detection",
        "prediction", "analytics", "recommendation", "recommender"
    ]
    for indicator in fine_tuned_indicators:
        if indicator in model_name:
            return True

    # Models named after specific business domains (often specialized)
    domain_indicators = ["customer", "business", "financial", "ecommerce", "retail"]
    for domain in domain_indicators:
        if domain in model_name:
            return True

    return False
def is_general_suitable_model(model: Dict, primary_task: str) -> bool:
    """Determine whether a model is a general foundation model suitable for the task."""
    model_name = model.get("id", "").lower()
    model_task = model.get("pipeline_tag", "").lower()

    # Check whether the model's task matches the primary task
    if model_task == primary_task:
        return True

    # General foundation models (base models good for fine-tuning)
    foundation_indicators = [
        "base", "large", "xlarge", "bert", "roberta", "distilbert",
        "electra", "albert", "transformer", "gpt", "t5", "bart",
        "deberta", "xlnet", "longformer"
    ]
    for indicator in foundation_indicators:
        if indicator in model_name and not any(x in model_name for x in ["nsfw", "safety", "toxicity"]):
            return True

    return False
async def get_foundation_models(primary_task: str, keywords: List[str], api_key: str = None) -> List[Dict]:
    """Get well-known foundation models suitable for the task."""
    foundation_searches = []
    if primary_task in ["text-classification", "token-classification"]:
        foundation_searches = [
            "bert-base-uncased", "roberta-base", "distilbert-base-uncased",
            "microsoft/deberta-v3-base", "google/electra-base-discriminator"
        ]
    elif primary_task in ["tabular-classification", "tabular-regression"]:
        foundation_searches = [
            "scikit-learn", "xgboost", "lightgbm", "catboost", "pytorch-tabular"
        ]
    elif primary_task in ["text-generation", "conversational"]:
        foundation_searches = [
            "gpt2", "microsoft/DialoGPT-medium", "facebook/blenderbot"
        ]
    elif primary_task in ["question-answering"]:
        foundation_searches = [
            "bert-base-uncased", "distilbert-base-uncased", "roberta-base"
        ]

    models = []
    for search_term in foundation_searches[:3]:  # Top 3 foundation searches
        try:
            headers = {}
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"
            response = requests.get(
                f"https://huggingface.co/api/models?search={search_term}&sort=downloads&limit=3",
                headers=headers,
                timeout=10
            )
            if response.status_code == 200:
                models.extend(response.json())
        except Exception:
            continue
    return models[:3]  # Return the top 3
def get_model_specialty(model: Dict, keywords: List[str]) -> str:
    """Get a human-readable specialty description for a model."""
    model_name = model.get("id", "").lower()

    # Map keywords to descriptions
    specialty_map = {
        "churn": "customer churn prediction",
        "fraud": "fraud detection",
        "sentiment": "sentiment analysis",
        "recommendation": "recommendation systems",
        "classification": "classification tasks",
        "detection": "detection tasks",
        "prediction": "prediction tasks"
    }
    for keyword in keywords:
        if keyword in specialty_map:
            return specialty_map[keyword]

    # Fallback: try to infer from the model name
    if "churn" in model_name:
        return "customer churn prediction"
    elif "fraud" in model_name:
        return "fraud detection"
    elif "sentiment" in model_name:
        return "sentiment analysis"
    elif "recommend" in model_name:
        return "recommendation systems"
    else:
        return "specialized ML tasks"
# Serve static files
app.mount("/static", StaticFiles(directory="static"), name="static")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)  # HF Spaces uses port 7860
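# Running locally (a sketch, assuming this file is saved as app.py and that
# static/index.html and chroma_db_complete.tar.gz sit alongside it; the
# filename is an assumption, since on HF Spaces the entrypoint is whatever
# the Space's configuration names):
#
#   pip install fastapi uvicorn pydantic chromadb sentence-transformers \
#       transformers huggingface_hub requests torch
#   python app.py   # serves on http://localhost:7860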