from fastapi import FastAPI, HTTPException, Header from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import chromadb from sentence_transformers import SentenceTransformer from transformers import pipeline from huggingface_hub import login import requests import json from typing import List, Dict, Any import os import sys import torch import tarfile app = FastAPI(title="ML Use Cases RAG System") # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Global variable to store current logs current_logs = [] def log_to_ui(message): """Add a log message that will be sent to UI""" current_logs.append(message) print(message) # Still print to console # Initialize embedding model embedding_model = SentenceTransformer('all-MiniLM-L6-v2') # BYOK: No server-side API key initialization # All model access will be done via user-provided API keys print("๐Ÿ”‘ BYOK Mode: No server-side API key configured") print("โœ… Users will provide their own HuggingFace API keys") generator = None llm_available = False # Auto-extract ChromaDB if archive exists and directory is missing/empty def setup_chromadb(): """Setup ChromaDB by extracting archive if needed""" if os.path.exists("chroma_db_complete.tar.gz"): # Check if chroma_db directory exists and has content needs_extraction = False if not os.path.exists("chroma_db"): print("๐Ÿ“ฆ ChromaDB directory not found, extracting archive...") needs_extraction = True else: # Check if directory is empty or missing key files try: if not os.path.exists("chroma_db/chroma.sqlite3"): print("๐Ÿ“ฆ ChromaDB missing database file, extracting archive...") needs_extraction = True else: # Quick check: try to list collections temp_client = chromadb.PersistentClient(path="./chroma_db") collections = temp_client.list_collections() if len(collections) == 0: print("๐Ÿ“ฆ ChromaDB has no collections, extracting archive...") needs_extraction = True else: print(f"โœ… ChromaDB already setup with {len(collections)} collections") except Exception as e: print(f"๐Ÿ“ฆ ChromaDB check failed ({e}), extracting archive...") needs_extraction = True if needs_extraction: try: print("๐Ÿ”ง Extracting ChromaDB archive...") with tarfile.open("chroma_db_complete.tar.gz", "r:gz") as tar: tar.extractall() print("โœ… ChromaDB extracted successfully") # Verify extraction if os.path.exists("chroma_db/chroma.sqlite3"): print("โœ… Database file found after extraction") else: print("โŒ Database file missing after extraction") except Exception as e: print(f"โŒ Failed to extract ChromaDB: {e}") else: print("๐Ÿ“‹ No ChromaDB archive found, using existing directory") # Setup ChromaDB before initializing client setup_chromadb() # Initialize ChromaDB chroma_client = chromadb.PersistentClient(path="./chroma_db") collection = None class ChatRequest(BaseModel): query: str class ApiKeyRequest(BaseModel): api_key: str class SearchResult(BaseModel): company: str industry: str year: int description: str summary: str similarity_score: float url: str class RecommendedModels(BaseModel): fine_tuned: List[Dict[str, Any]] general: List[Dict[str, Any]] class ChatResponse(BaseModel): solution_approach: str company_examples: List[SearchResult] recommended_models: RecommendedModels @app.get("/health") async def health_check(): """Health check endpoint""" return {"status": "healthy"} @app.get("/test-token/{token}") async def test_token_direct(token: str): """Direct token test endpoint""" print(f"๐Ÿงช Testing token: {token[:10]}...") try: # Test with models API response = requests.get( "https://huggingface.co/api/models?limit=1", headers={"Authorization": f"Bearer {token}"}, timeout=10 ) print(f"๐Ÿ“Š Models API Status: {response.status_code}") if response.status_code == 200: return {"valid": True, "method": "models_api", "status": response.status_code} # Test whoami response2 = requests.get( "https://huggingface.co/api/whoami", headers={"Authorization": f"Bearer {token}"}, timeout=10 ) print(f"๐Ÿ“Š WhoAmI Status: {response2.status_code}") return { "valid": response2.status_code == 200, "models_status": response.status_code, "whoami_status": response2.status_code, "whoami_response": response2.text[:200] if response2.status_code != 200 else "OK" } except Exception as e: return {"error": str(e)} @app.post("/validate-key") async def validate_api_key(request: ApiKeyRequest): """Validate user's HuggingFace API key""" api_key = request.api_key.strip() print(f"๐Ÿ”‘ Validating API key: {api_key[:10]}...") if not api_key or not api_key.startswith('hf_'): print(f"โŒ Invalid format: {api_key[:10] if api_key else 'empty'}") return {"valid": False, "error": "Invalid API key format. Must start with 'hf_'"} # Simple format validation - if it looks like a valid HF token, accept it if len(api_key) >= 30 and api_key.startswith('hf_') and all(c.isalnum() or c == '_' for c in api_key): print("โœ… API key format is valid, accepting") return {"valid": True, "user": "User"} print(f"โŒ Invalid token format or length") return {"valid": False, "error": "Invalid API key format"} @app.get("/logs") async def get_logs(): """Get current log messages for UI""" try: logs_copy = current_logs.copy() current_logs.clear() return {"logs": logs_copy} except Exception as e: return {"logs": [], "error": str(e)} @app.get("/test-logs") async def test_logs(): """Test endpoint to verify logging works""" log_to_ui("๐Ÿงช Test log message 1") log_to_ui("๐Ÿงช Test log message 2") log_to_ui("๐Ÿงช Test log message 3") return {"message": "Test logs added"} def initialize_collection(): """Initialize the ChromaDB collection with debug logging""" global collection # Debug: Check file system print(f"๐Ÿ” Current working directory: {os.getcwd()}") print(f"๐Ÿ” ChromaDB path exists: {os.path.exists('./chroma_db')}") if os.path.exists('./chroma_db'): try: chroma_files = os.listdir('./chroma_db') print(f"๐Ÿ” ChromaDB directory contents: {chroma_files}") # Check for main database file if 'chroma.sqlite3' in chroma_files: print("โœ… Found chroma.sqlite3") else: print("โŒ chroma.sqlite3 NOT found") # Check for UUID directories uuid_dirs = [f for f in chroma_files if len(f) == 36 and '-' in f] # UUID format if uuid_dirs: print(f"โœ… Found UUID directories: {uuid_dirs}") for uuid_dir in uuid_dirs: uuid_path = os.path.join('./chroma_db', uuid_dir) if os.path.isdir(uuid_path): uuid_files = os.listdir(uuid_path) print(f"๐Ÿ” {uuid_dir} contents: {uuid_files}") else: print("โŒ No UUID directories found") except Exception as e: print(f"โŒ Error reading chroma_db directory: {e}") else: print("โŒ chroma_db directory does not exist") # Debug: Try to initialize ChromaDB client try: print("๐Ÿ” Attempting to initialize ChromaDB client...") print(f"๐Ÿ” ChromaDB version: {chromadb.__version__}") # List all collections collections = chroma_client.list_collections() print(f"๐Ÿ” Available collections: {[c.name for c in collections]}") # Try to get the specific collection collection = chroma_client.get_collection("ml_use_cases") collection_count = collection.count() print(f"โœ… Found existing collection 'ml_use_cases' with {collection_count} documents") except Exception as e: print(f"โŒ Collection initialization error: {type(e).__name__}: {e}") print("๐Ÿ“ Will attempt to create collection during first use") collection = None # Initialize collection on import initialize_collection() @app.get("/", response_class=HTMLResponse) async def root(): """Serve the main frontend""" with open("static/index.html", "r") as f: return HTMLResponse(f.read()) async def search_use_cases_internal(request: ChatRequest): """Internal search function with detailed logging""" log_to_ui(f"๐Ÿ” Search request received: '{request.query}'") if not collection: log_to_ui("โŒ ChromaDB collection not initialized") raise HTTPException(status_code=500, detail="Database not initialized") query = request.query.lower() log_to_ui(f"๐Ÿ“ Normalized query: '{query}'") # Generate query embedding for semantic search log_to_ui("๐Ÿง  Generating query embedding...") query_embedding = embedding_model.encode([request.query]).tolist()[0] log_to_ui(f"โœ… Embedding generated, dimension: {len(query_embedding)}") # Semantic search log_to_ui("๐Ÿ”Ž Performing semantic search...") semantic_results = collection.query( query_embeddings=[query_embedding], n_results=15, include=['metadatas', 'documents', 'distances'] ) log_to_ui(f"๐Ÿ“Š Semantic search found {len(semantic_results['ids'][0])} results") # Keyword-based search using where clause for exact matches keyword_results = None try: log_to_ui("๐Ÿ”ค Performing keyword search...") keyword_results = collection.query( query_texts=[request.query], n_results=10, include=['metadatas', 'documents', 'distances'] ) log_to_ui(f"๐Ÿ“ Keyword search found {len(keyword_results['ids'][0])} results") except Exception as e: log_to_ui(f"โš ๏ธ Keyword search failed: {e}") pass # Combine and rank results combined_results = {} # Process semantic results for i in range(len(semantic_results['ids'][0])): doc_id = semantic_results['ids'][0][i] metadata = semantic_results['metadatas'][0][i] similarity_score = 1 - semantic_results['distances'][0][i] # Boost score for keyword matches in metadata boost = 0 query_words = query.split() for word in query_words: if word in metadata.get('title', '').lower(): boost += 0.3 if word in metadata.get('description', '').lower(): boost += 0.2 if word in metadata.get('keywords', '').lower(): boost += 0.4 if word in metadata.get('industry', '').lower(): boost += 0.1 final_score = min(similarity_score + boost, 1.0) combined_results[doc_id] = { 'metadata': metadata, 'summary': semantic_results['documents'][0][i], 'score': final_score, 'source': 'semantic' } # Process keyword results if available if keyword_results: for i in range(len(keyword_results['ids'][0])): doc_id = keyword_results['ids'][0][i] if doc_id not in combined_results: metadata = keyword_results['metadatas'][0][i] similarity_score = 1 - keyword_results['distances'][0][i] combined_results[doc_id] = { 'metadata': metadata, 'summary': keyword_results['documents'][0][i], 'score': similarity_score + 0.1, # Small boost for keyword matches 'source': 'keyword' } # Sort by score and take top results sorted_results = sorted(combined_results.values(), key=lambda x: x['score'], reverse=True)[:10] log_to_ui(f"๐ŸŽฏ Combined and ranked results: {len(sorted_results)} final results") search_results = [] for i, result in enumerate(sorted_results): metadata = result['metadata'] search_results.append(SearchResult( company=metadata.get('company', ''), industry=metadata.get('industry', ''), year=metadata.get('year', 2023), description=metadata.get('description', ''), summary=result['summary'], similarity_score=result['score'], url=metadata.get('url', '') )) log_to_ui(f" {i+1}. {metadata.get('company', 'Unknown')} - Score: {result['score']:.3f}") log_to_ui(f"โœ… Search completed, returning {len(search_results)} results") return search_results @app.post("/search") async def search_use_cases(request: ChatRequest): """Public search endpoint""" results = await search_use_cases_internal(request) return {"results": results} async def generate_response_with_user_key(prompt: str, api_key: str, max_length: int = 500) -> str: """Generate response using user's HuggingFace API key via Inference API""" try: # Use HuggingFace Inference API with user's key api_url = "https://api-inference.huggingface.co/models/google/gemma-2-2b-it" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" } payload = { "inputs": prompt, "parameters": { "max_new_tokens": max_length, "temperature": 0.7, "do_sample": True, "return_full_text": False } } response = requests.post(api_url, headers=headers, json=payload, timeout=30) if response.status_code == 200: result = response.json() if isinstance(result, list) and len(result) > 0: generated_text = result[0].get('generated_text', '') return generated_text.strip() else: return "Unable to generate response. Please try again." elif response.status_code == 503: # Model is loading, try fallback return await try_fallback_model(prompt, api_key, max_length) else: raise Exception(f"API request failed with status {response.status_code}") except Exception as e: print(f"Error generating response with user API key: {e}") return generate_template_response(prompt) async def try_fallback_model(prompt: str, api_key: str, max_length: int = 500) -> str: """Try fallback model when primary model is unavailable""" try: # Try a more readily available model as fallback fallback_models = [ "microsoft/DialoGPT-medium", "microsoft/DialoGPT-small", "gpt2" ] for model_name in fallback_models: try: api_url = f"https://api-inference.huggingface.co/models/{model_name}" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" } payload = { "inputs": prompt, "parameters": { "max_new_tokens": max_length, "temperature": 0.7, "do_sample": True, "return_full_text": False } } response = requests.post(api_url, headers=headers, json=payload, timeout=20) if response.status_code == 200: result = response.json() if isinstance(result, list) and len(result) > 0: generated_text = result[0].get('generated_text', '') return generated_text.strip() except: continue # If all models fail, return template return generate_template_response(prompt) except Exception as e: return generate_template_response(prompt) def generate_template_response(prompt: str) -> str: """Generate a template response when AI models are not available""" return f"""Based on the analysis of similar ML/AI implementations from companies in our database, here are key recommendations for your problem: **Technical Approach:** - Consider machine learning classification or prediction models - Leverage data preprocessing and feature engineering - Implement proper model validation and testing **Implementation Strategy:** - Start with a minimum viable model using existing data - Iterate based on performance metrics - Consider scalability and real-time requirements **Key Considerations:** - Data quality and availability - Business metrics alignment - Technical infrastructure requirements This analysis is based on patterns from 400+ real-world ML implementations across various industries.""" @app.post("/chat", response_model=ChatResponse) async def chat_with_rag(request: ChatRequest, x_hf_api_key: str = Header(None, alias="X-HF-API-Key")): """Main RAG endpoint with user API key""" # Validate user API key if not x_hf_api_key or not x_hf_api_key.startswith('hf_'): raise HTTPException(status_code=400, detail="Valid HuggingFace API key required") # Clear previous logs and start fresh current_logs.clear() log_to_ui(f"๐Ÿค– Chat request received: '{request.query}'") # First search for relevant use cases log_to_ui("๐Ÿ” Getting relevant use cases...") relevant_cases = await search_use_cases_internal(request) top_cases = relevant_cases[:5] # Top 5 results log_to_ui(f"๐Ÿ“š Using top {len(top_cases)} cases for context") # Prepare context for LLM log_to_ui("๐Ÿ“ Preparing context for LLM...") context = "Here are relevant real-world ML/AI implementations:\n\n" for i, case in enumerate(top_cases, 1): context += f"Company: {case.company} ({case.industry}, {case.year})\n" context += f"Description: {case.description}\n" context += f"Implementation: {case.summary[:500]}...\n\n" log_to_ui(f" {i}. {case.company} - {case.description}") log_to_ui(f"๐Ÿ“Š Context length: {len(context)} characters") # Create prompt for language model prompt = f"""Based on the following real ML/AI implementations from companies, provide advice for this business problem: {context} User Problem: {request.query} Please provide a comprehensive solution approach that considers what has worked for these companies. Focus on: 1. What type of ML/AI solution would address this problem 2. Key technical approaches that have proven successful 3. Implementation considerations based on the examples Be specific and reference the examples when relevant. Response:""" log_to_ui(f"๐Ÿ’ญ Full prompt length: {len(prompt)} characters") # Generate response using user's HuggingFace API key log_to_ui("๐Ÿš€ Generating AI response with user API key...") try: llm_response = await generate_response_with_user_key(prompt, x_hf_api_key, max_length=400) log_to_ui(f"โœ… AI response generated, length: {len(llm_response)} characters") except Exception as e: llm_response = f"Error generating AI response: {str(e)}" log_to_ui(f"โŒ AI response error: {e}") # Get HuggingFace model recommendations using user's API key log_to_ui("๐Ÿค— Getting HuggingFace model recommendations...") recommended_models = await get_huggingface_models(request.query, top_cases, x_hf_api_key) total_models = len(recommended_models.get("fine_tuned", [])) + len(recommended_models.get("general", [])) log_to_ui(f"๐Ÿท๏ธ Found {total_models} recommended models") log_to_ui("โœ… Chat response complete!") # Return response with logs included return { "solution_approach": llm_response, "company_examples": [ { "company": case.company, "industry": case.industry, "year": case.year, "description": case.description, "summary": case.summary, "similarity_score": case.similarity_score, "url": case.url } for case in top_cases ], "recommended_models": { "fine_tuned": recommended_models.get("fine_tuned", []), "general": recommended_models.get("general", []) }, "logs": current_logs.copy() # Include all logs in the response } async def get_huggingface_models(query: str, relevant_cases: List = None, api_key: str = None) -> Dict[str, List[Dict[str, Any]]]: """Get relevant ML models from HuggingFace based on query and similar use cases""" log_to_ui(f"๐Ÿ” Analyzing query for ML task mapping: '{query}'") try: # Enhanced multi-dimensional classification system business_domains = { # Financial Services "finance": ["fraud detection", "risk assessment", "algorithmic trading", "credit scoring"], "banking": ["fraud detection", "credit scoring", "customer segmentation", "loan approval"], "fintech": ["payment processing", "robo-advisor", "fraud detection", "credit scoring"], "insurance": ["risk assessment", "claim processing", "fraud detection", "pricing optimization"], # E-commerce & Retail "ecommerce": ["recommendation systems", "demand forecasting", "price optimization", "customer segmentation"], "retail": ["inventory management", "demand forecasting", "customer analytics", "supply chain"], "marketplace": ["search ranking", "recommendation systems", "fraud detection", "seller analytics"], # Healthcare & Life Sciences "healthcare": ["medical imaging", "drug discovery", "patient risk prediction", "clinical decision support"], "medical": ["diagnostic imaging", "treatment optimization", "patient monitoring", "clinical trials"], "pharma": ["drug discovery", "clinical trials", "adverse event detection", "molecular analysis"], # Technology & Media "tech": ["user behavior analysis", "system optimization", "content moderation", "search ranking"], "media": ["content recommendation", "audience analytics", "content generation", "sentiment analysis"], "gaming": ["player behavior prediction", "game optimization", "content generation", "cheat detection"], # Marketing & Advertising "marketing": ["customer segmentation", "campaign optimization", "lead scoring", "attribution modeling"], "advertising": ["ad targeting", "bid optimization", "creative optimization", "audience analytics"], "social": ["sentiment analysis", "trend prediction", "content moderation", "influence analysis"] } problem_types = { # Customer Analytics "churn": { "domain": "customer_analytics", "task_type": "binary_classification", "data_types": ["tabular", "behavioral"], "complexity": "intermediate", "models": ["xgboost", "lightgbm", "catboost", "random_forest"], "hf_tasks": ["tabular-classification"], "keywords": ["retention", "attrition", "leave", "cancel", "subscription"] }, "segmentation": { "domain": "customer_analytics", "task_type": "clustering", "data_types": ["tabular", "behavioral"], "complexity": "intermediate", "models": ["kmeans", "dbscan", "hierarchical", "gaussian_mixture"], "hf_tasks": ["tabular-classification"], "keywords": ["segment", "group", "persona", "cluster", "behavior"] }, # Risk & Fraud "fraud": { "domain": "risk_management", "task_type": "anomaly_detection", "data_types": ["tabular", "graph", "time_series"], "complexity": "advanced", "models": ["isolation_forest", "autoencoder", "one_class_svm", "gnn"], "hf_tasks": ["tabular-classification"], "keywords": ["suspicious", "anomaly", "unusual", "scam", "fake"] }, "risk": { "domain": "risk_management", "task_type": "regression", "data_types": ["tabular", "time_series"], "complexity": "advanced", "models": ["ensemble", "deep_learning", "survival_analysis"], "hf_tasks": ["tabular-regression"], "keywords": ["probability", "likelihood", "exposure", "default", "loss"] }, # Demand & Forecasting "forecast": { "domain": "demand_planning", "task_type": "time_series_forecasting", "data_types": ["time_series", "tabular"], "complexity": "advanced", "models": ["prophet", "lstm", "transformer", "arima"], "hf_tasks": ["time-series-forecasting"], "keywords": ["predict", "future", "trend", "seasonal", "demand", "sales"] }, "demand": { "domain": "demand_planning", "task_type": "regression", "data_types": ["time_series", "tabular"], "complexity": "intermediate", "models": ["xgboost", "lstm", "prophet"], "hf_tasks": ["tabular-regression", "time-series-forecasting"], "keywords": ["inventory", "supply", "planning", "optimization"] }, # Content & NLP "sentiment": { "domain": "nlp", "task_type": "text_classification", "data_types": ["text"], "complexity": "beginner", "models": ["bert", "roberta", "distilbert"], "hf_tasks": ["text-classification"], "keywords": ["opinion", "emotion", "feeling", "review", "feedback"] }, "recommendation": { "domain": "personalization", "task_type": "recommendation", "data_types": ["tabular", "behavioral", "content"], "complexity": "advanced", "models": ["collaborative_filtering", "content_based", "deep_learning"], "hf_tasks": ["tabular-regression"], "keywords": ["suggest", "personalize", "similar", "like", "prefer"] }, # Pricing & Optimization "pricing": { "domain": "revenue_optimization", "task_type": "regression", "data_types": ["tabular", "time_series"], "complexity": "advanced", "models": ["ensemble", "reinforcement_learning", "optimization"], "hf_tasks": ["tabular-regression"], "keywords": ["price", "cost", "revenue", "profit", "optimize"] } } # Advanced query analysis def analyze_query_intent(query_text, cases=None): """Analyze query to extract business domain, problem type, and complexity""" query_lower = query_text.lower() # Extract business domain detected_domain = None domain_confidence = 0 for domain, use_cases in business_domains.items(): if domain in query_lower: detected_domain = domain domain_confidence = 0.9 break # Check use case matches for use_case in use_cases: if use_case.lower() in query_lower: detected_domain = domain domain_confidence = 0.7 break if detected_domain: break # Extract problem type with scoring problem_matches = [] for problem_name, problem_info in problem_types.items(): score = 0 # Direct problem name match if problem_name in query_lower: score += 50 # Keyword matches for keyword in problem_info["keywords"]: if keyword in query_lower: score += 10 # Context from relevant cases if cases: case_text = " ".join([f"{case.description} {case.summary[:300]}" for case in cases]).lower() if problem_name in case_text: score += 20 for keyword in problem_info["keywords"]: if keyword in case_text: score += 5 if score > 0: problem_matches.append((problem_name, problem_info, score)) # Sort by score and get best matches problem_matches.sort(key=lambda x: x[2], reverse=True) return detected_domain, problem_matches[:3], domain_confidence # Analyze the query query_lower = query.lower() detected_domain, top_problems, domain_confidence = analyze_query_intent(query, relevant_cases) # Determine primary task and approach if top_problems: primary_problem = top_problems[0] problem_info = primary_problem[1] primary_task = problem_info["hf_tasks"][0] if problem_info["hf_tasks"] else "tabular-classification" complexity = problem_info["complexity"] preferred_models = problem_info["models"] log_to_ui(f"๐ŸŽฏ Detected problem: '{primary_problem[0]}' (score: {primary_problem[2]})") log_to_ui(f"๐Ÿ“Š Domain: {detected_domain or 'general'} | Complexity: {complexity}") log_to_ui(f"๐Ÿ”ง Preferred models: {', '.join(preferred_models[:3])}") else: # Fallback to simple keyword matching primary_task = "tabular-classification" complexity = "intermediate" preferred_models = ["xgboost", "lightgbm"] log_to_ui(f"๐Ÿ“Š Using fallback classification | Task: {primary_task}") matched_keywords = [p[0] for p in top_problems] log_to_ui(f"๐Ÿ“Š Primary task: '{primary_task}' | Keywords: {matched_keywords}") # Search for models with multiple strategies all_models = [] # Strategy 1: Search by primary task models_primary = await search_models_by_task(primary_task, api_key) all_models.extend(models_primary) # Strategy 2: Search by specific keywords for better matches if matched_keywords: for keyword in matched_keywords[:2]: # Top 2 keywords keyword_models = await search_models_by_keyword(keyword, api_key) all_models.extend(keyword_models) # Strategy 3: Search for domain-specific models domain_searches = [] if "churn" in query_lower or "retention" in query_lower: domain_searches.append("customer-analytics") if "fraud" in query_lower: domain_searches.append("anomaly-detection") if "recommend" in query_lower: domain_searches.append("recommendation") for domain in domain_searches: domain_models = await search_models_by_keyword(domain, api_key) all_models.extend(domain_models) # Remove duplicates and rank by relevance seen_models = set() unique_models = [] for model in all_models: model_id = model.get("id") or model.get("name") if model_id and model_id not in seen_models: seen_models.add(model_id) unique_models.append(model) # Score models based on enhanced relevance criteria scored_models = [] for model in unique_models: score = calculate_model_relevance( model, query_lower, matched_keywords, complexity, preferred_models if 'preferred_models' in locals() else None ) scored_models.append((model, score)) # Separate models into fine-tuned/specific vs general base models fine_tuned_models = [] general_models = [] for model, score in scored_models: if is_fine_tuned_model(model, matched_keywords): fine_tuned_models.append((model, score)) elif is_general_suitable_model(model, primary_task): general_models.append((model, score)) # Sort and take top 3 of each type fine_tuned_models.sort(key=lambda x: x[1], reverse=True) general_models.sort(key=lambda x: x[1], reverse=True) top_fine_tuned = [model for model, score in fine_tuned_models[:3]] top_general = [model for model, score in general_models[:3]] # Add curated high-quality models for specific use cases def get_curated_models(problem_type: str, complexity_level: str) -> List[Dict]: """Get curated high-quality models for specific use cases""" curated = { "churn": { "beginner": [ {"id": "scikit-learn/RandomForestClassifier", "task": "tabular-classification"}, {"id": "xgboost/XGBClassifier", "task": "tabular-classification"} ], "intermediate": [ {"id": "microsoft/TabNet", "task": "tabular-classification"}, {"id": "AutoML/AutoGluon-Tabular", "task": "tabular-classification"} ], "advanced": [ {"id": "microsoft/LightGBM", "task": "tabular-classification"}, {"id": "dmlc/xgboost", "task": "tabular-classification"} ] }, "sentiment": { "beginner": [ {"id": "cardiffnlp/twitter-roberta-base-sentiment-latest", "task": "text-classification"}, {"id": "distilbert-base-uncased-finetuned-sst-2-english", "task": "text-classification"} ], "intermediate": [ {"id": "nlptown/bert-base-multilingual-uncased-sentiment", "task": "text-classification"}, {"id": "microsoft/DialoGPT-medium", "task": "text-classification"} ], "advanced": [ {"id": "roberta-large-mnli", "task": "text-classification"}, {"id": "microsoft/deberta-v3-large", "task": "text-classification"} ] }, "fraud": { "intermediate": [ {"id": "microsoft/TabNet", "task": "tabular-classification"}, {"id": "IsolationForest/AnomalyDetection", "task": "tabular-classification"} ], "advanced": [ {"id": "pyod/AutoEncoder", "task": "tabular-classification"}, {"id": "GraphNeuralNetworks/FraudDetection", "task": "tabular-classification"} ] }, "forecast": { "intermediate": [ {"id": "facebook/prophet", "task": "time-series-forecasting"}, {"id": "statsmodels/ARIMA", "task": "time-series-forecasting"} ], "advanced": [ {"id": "microsoft/DeepAR", "task": "time-series-forecasting"}, {"id": "google/temporal-fusion-transformer", "task": "time-series-forecasting"} ] } } # Get curated models for the specific problem and complexity if problem_type in curated and complexity_level in curated[problem_type]: return curated[problem_type][complexity_level] elif problem_type in curated: # Fallback to any complexity level available for level in ["beginner", "intermediate", "advanced"]: if level in curated[problem_type]: return curated[problem_type][level] return [] # Add curated models if top_problems: primary_problem_name = top_problems[0][0] curated_models = get_curated_models(primary_problem_name, complexity) for curated_model in curated_models: if len(top_general) < 3: # Format as HuggingFace model dict formatted_model = { "id": curated_model["id"], "pipeline_tag": curated_model["task"], "downloads": 50000, # Reasonable default "tags": ["curated", "production-ready"] } top_general.append(formatted_model) # Add general foundation models if we still don't have enough if len(top_general) < 3: foundation_models = await get_foundation_models(primary_task, matched_keywords, api_key) top_general.extend(foundation_models[:3-len(top_general)]) # Format response with categories model_response = { "fine_tuned": [], "general": [] } # Enhanced model descriptions based on detected problem type def get_enhanced_model_description(model: Dict, model_type: str, problem_context: str = None) -> str: """Generate context-aware model descriptions""" model_name = model.get("id", "").lower() if model_type == "fine-tuned": if problem_context == "churn": return "Pre-trained model optimized for customer retention prediction" elif problem_context == "fraud": return "Specialized anomaly detection model for fraud identification" elif problem_context == "sentiment": return "Fine-tuned sentiment analysis model for text classification" elif problem_context == "forecast": return "Time series forecasting model for demand prediction" else: return f"Specialized model fine-tuned for {get_model_specialty(model, matched_keywords)}" else: # general if "curated" in str(model.get("tags", [])): return "Production-ready model recommended for business use cases" elif any(term in model_name for term in ["bert", "roberta", "distilbert"]): return "Transformer-based foundation model for fine-tuning" elif any(term in model_name for term in ["xgboost", "lightgbm", "catboost"]): return "Gradient boosting model excellent for tabular data" elif "prophet" in model_name: return "Facebook's time series forecasting framework" else: return f"Foundation model suitable for {primary_task.replace('-', ' ')}" # Format fine-tuned models with enhanced descriptions primary_problem_name = top_problems[0][0] if top_problems else None for model in top_fine_tuned: model_info = { "name": model.get("id", model.get("name", "Unknown")), "downloads": model.get("downloads", 0), "task": model.get("pipeline_tag", primary_task), "url": f"https://huggingface.co/{model.get('id', '')}", "type": "fine-tuned", "description": get_enhanced_model_description(model, "fine-tuned", primary_problem_name) } model_response["fine_tuned"].append(model_info) # Format general models with enhanced descriptions for model in top_general: model_info = { "name": model.get("id", model.get("name", "Unknown")), "downloads": model.get("downloads", 0), "task": model.get("pipeline_tag", primary_task), "url": f"https://huggingface.co/{model.get('id', '')}", "type": "general", "description": get_enhanced_model_description(model, "general", primary_problem_name) } model_response["general"].append(model_info) total_models = len(model_response["fine_tuned"]) + len(model_response["general"]) log_to_ui(f"๐Ÿ“ฆ Found {len(model_response['fine_tuned'])} fine-tuned + {len(model_response['general'])} general models") # Log details if model_response["fine_tuned"]: log_to_ui("๐ŸŽฏ Fine-tuned/Specialized models:") for i, model in enumerate(model_response["fine_tuned"], 1): log_to_ui(f" {i}. {model['name']} - {model['downloads']:,} downloads") if model_response["general"]: log_to_ui("๐Ÿ”ง General/Foundation models:") for i, model in enumerate(model_response["general"], 1): log_to_ui(f" {i}. {model['name']} - {model['downloads']:,} downloads") return model_response except Exception as e: log_to_ui(f"โŒ Error fetching HuggingFace models: {e}") return {"fine_tuned": [], "general": []} async def search_models_by_task(task: str, api_key: str = None) -> List[Dict]: """Search models by specific task""" try: headers = {} if api_key: headers["Authorization"] = f"Bearer {api_key}" response = requests.get( f"https://huggingface.co/api/models?pipeline_tag={task}&sort=downloads&limit=10", headers=headers, timeout=10 ) if response.status_code == 200: return response.json() except: pass return [] async def search_models_by_keyword(keyword: str, api_key: str = None) -> List[Dict]: """Search models by keyword in name/description""" try: headers = {} if api_key: headers["Authorization"] = f"Bearer {api_key}" response = requests.get( f"https://huggingface.co/api/models?search={keyword}&sort=downloads&limit=10", headers=headers, timeout=10 ) if response.status_code == 200: return response.json() except: pass return [] def calculate_model_relevance(model: Dict, query: str, keywords: List[str], complexity: str = "intermediate", preferred_models: List[str] = None) -> float: """Enhanced multi-criteria model relevance scoring""" score = 0 model_name = model.get("id", "").lower() model_task = model.get("pipeline_tag", "").lower() downloads = model.get("downloads", 0) # 1. Base popularity score (0-15 points) if downloads > 10000000: # 10M+ score += 15 elif downloads > 1000000: # 1M+ score += 12 elif downloads > 100000: # 100K+ score += 8 elif downloads > 10000: # 10K+ score += 5 elif downloads > 1000: # 1K+ score += 2 # 2. Direct keyword relevance (0-30 points) for keyword in keywords: if keyword in model_name: score += 25 # Check in model description/tags if available model_tags = model.get("tags", []) if any(keyword in str(tag).lower() for tag in model_tags): score += 15 # 3. Query term matches (0-20 points) query_words = [w for w in query.lower().split() if len(w) > 3] for word in query_words: if word in model_name: score += 8 if word in str(model.get("tags", [])).lower(): score += 5 # 4. Preferred model architecture bonus (0-25 points) if preferred_models: for preferred in preferred_models: if preferred.lower() in model_name: score += 20 break # Partial matches for preferred in preferred_models: if any(part in model_name for part in preferred.lower().split('_')): score += 10 break # 5. Task alignment (0-20 points) relevant_tasks = ["tabular-classification", "tabular-regression", "text-classification", "time-series-forecasting", "question-answering"] if model_task in relevant_tasks: score += 15 # 6. Complexity alignment (0-15 points) complexity_indicators = { "beginner": ["base", "simple", "basic", "distil", "small", "mini"], "intermediate": ["medium", "standard", "v2", "improved"], "advanced": ["large", "xl", "xxl", "advanced", "complex", "ensemble"] } if complexity in complexity_indicators: for indicator in complexity_indicators[complexity]: if indicator in model_name: score += 12 break # 7. Production readiness indicators (0-10 points) production_terms = ["production", "optimized", "efficient", "fast", "deployment"] for term in production_terms: if term in model_name: score += 8 break # 8. Penalties for problematic models (negative points) penalty_terms = ["nsfw", "adult", "sexual", "violence", "toxic", "unsafe", "experimental"] for term in penalty_terms: if term in model_name: score -= 30 # Generic model penalty generic_terms = ["general", "random", "test", "example", "demo"] for term in generic_terms: if term in model_name: score -= 10 # 9. Model quality indicators (0-10 points) quality_terms = ["sota", "benchmark", "award", "winner", "best", "top"] for term in quality_terms: if term in model_name or term in str(model.get("tags", [])).lower(): score += 8 break # 10. Recency bonus (0-5 points) - prefer newer models # This would require model creation date, approximating with model name patterns recent_indicators = ["2024", "2023", "v3", "v4", "v5", "latest", "new"] for indicator in recent_indicators: if indicator in model_name: score += 3 break return max(score, 0) def is_fine_tuned_model(model: Dict, keywords: List[str]) -> bool: """Determine if a model is fine-tuned/specialized for the specific task""" model_name = model.get("id", "").lower() # Models with specific task keywords in name are likely fine-tuned for keyword in keywords: if keyword in model_name: return True # Models with specific fine-tuning indicators fine_tuned_indicators = [ "fine-tuned", "ft", "finetuned", "specialized", "custom", "churn", "fraud", "sentiment", "classification", "detection", "prediction", "analytics", "recommendation", "recommender" ] for indicator in fine_tuned_indicators: if indicator in model_name: return True # Models from specific companies/domains (often specialized) domain_indicators = ["customer", "business", "financial", "ecommerce", "retail"] for domain in domain_indicators: if domain in model_name: return True return False def is_general_suitable_model(model: Dict, primary_task: str) -> bool: """Determine if a model is a general foundation model suitable for the task""" model_name = model.get("id", "").lower() model_task = model.get("pipeline_tag", "").lower() # Check if model task matches primary task if model_task == primary_task: return True # General foundation models (base models good for fine-tuning) foundation_indicators = [ "base", "large", "xlarge", "bert", "roberta", "distilbert", "electra", "albert", "transformer", "gpt", "t5", "bart", "deberta", "xlnet", "longformer" ] for indicator in foundation_indicators: if indicator in model_name and not any(x in model_name for x in ["nsfw", "safety", "toxicity"]): return True return False async def get_foundation_models(primary_task: str, keywords: List[str], api_key: str = None) -> List[Dict]: """Get well-known foundation models suitable for the task""" foundation_searches = [] if primary_task in ["text-classification", "token-classification"]: foundation_searches = [ "bert-base-uncased", "roberta-base", "distilbert-base-uncased", "microsoft/deberta-v3-base", "google/electra-base-discriminator" ] elif primary_task in ["tabular-classification", "tabular-regression"]: foundation_searches = [ "scikit-learn", "xgboost", "lightgbm", "catboost", "pytorch-tabular" ] elif primary_task in ["text-generation", "conversational"]: foundation_searches = [ "gpt2", "microsoft/DialoGPT-medium", "facebook/blenderbot" ] elif primary_task in ["question-answering"]: foundation_searches = [ "bert-base-uncased", "distilbert-base-uncased", "roberta-base" ] models = [] for search_term in foundation_searches[:3]: # Top 3 foundation searches try: headers = {} if api_key: headers["Authorization"] = f"Bearer {api_key}" response = requests.get( f"https://huggingface.co/api/models?search={search_term}&sort=downloads&limit=3", headers=headers, timeout=10 ) if response.status_code == 200: models.extend(response.json()) except: continue return models[:3] # Return top 3 def get_model_specialty(model: Dict, keywords: List[str]) -> str: """Get human-readable specialty description for a model""" model_name = model.get("id", "").lower() # Map keywords to descriptions specialty_map = { "churn": "customer churn prediction", "fraud": "fraud detection", "sentiment": "sentiment analysis", "recommendation": "recommendation systems", "classification": "classification tasks", "detection": "detection tasks", "prediction": "prediction tasks" } for keyword in keywords: if keyword in specialty_map: return specialty_map[keyword] # Fallback: try to infer from model name if "churn" in model_name: return "customer churn prediction" elif "fraud" in model_name: return "fraud detection" elif "sentiment" in model_name: return "sentiment analysis" elif "recommend" in model_name: return "recommendation systems" else: return "specialized ML tasks" # Serve static files app.mount("/static", StaticFiles(directory="static"), name="static") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860) # HF Spaces uses port 7860