"""Hybrid news-question classifier: fast keyword rules with a zero-shot fallback."""

import os
import re

from dotenv import load_dotenv

from models_initialization.bart_large_registry import run_zero_shot_classification

# Load variables from a .env file into the process environment (python-dotenv).
load_dotenv()

# Map readable categories to numeric IDs
QUESTION_TYPES = {
    "recent_update": 1,
    "explainer": 2,
    "timeline": 3,
    "person_in_news": 4,
    "policy_or_law": 5,
    "election_poll": 6,
    "business_market": 7,
    "sports_event": 8,
    "pop_culture": 9,
    "science_health": 10,
    "fact_check": 11,
    "compare_entities": 12,
    "small_talk": 13,
}

# Inverse of QUESTION_TYPES: numeric ID -> category name.
REVERSE_MAP = {v: k for k, v in QUESTION_TYPES.items()}


# ---------- Step 1: Fast Rule-Based Classification ----------

# Optional inflection allowed after a keyword so plurals / past tenses still
# match ("elections", "voted", "matches"). Whole-word anchoring stops keywords
# from firing inside unrelated words ("bill" in "billion", "actor" in
# "factors", "league" in "colleague", "vs" in "canvas").
_INFLECTION = r"(?:s|es|d|ed|ing)?"


def _keyword_pattern(*keywords: str) -> "re.Pattern":
    """Compile a regex matching any of *keywords* as a whole, optionally inflected, word."""
    alternation = "|".join(re.escape(keyword) for keyword in keywords)
    return re.compile(r"\b(?:" + alternation + r")" + _INFLECTION + r"\b")


# Ordered (category, pattern) rules; the first pattern that matches wins.
# The narrow cues (person_in_news, timeline) are checked before the generic
# explainer cues ("why", "what is"). Otherwise those cues shadow them: for
# example, "why is X in the news" could never reach person_in_news.
_RULES = (
    ("recent_update", _keyword_pattern(
        "latest", "recent", "recently", "update", "breaking", "happened today")),
    ("person_in_news", re.compile(r"\bwhy\b.*\bin the news\b|\btrending\b")),
    ("timeline", re.compile(r"\btimelines?\b|\bhow\b.*\bchanged\b")),
    ("explainer", _keyword_pattern(
        "explain", "why", "what is", "background", "summary")),
    ("policy_or_law", _keyword_pattern(
        "bill", "policy", "law", "executive order", "passed", "signed")),
    ("election_poll", _keyword_pattern(
        "election", "poll", "vote", "voter", "voting", "candidate", "ballot")),
    ("business_market", _keyword_pattern(
        "stock", "inflation", "economy", "market", "job report")),
    ("sports_event", _keyword_pattern(
        "score", "match", "tournament", "game", "league")),
    ("pop_culture", _keyword_pattern(
        "celebrity", "actor", "album", "movie", "music", "musician", "show")),
    ("science_health", _keyword_pattern(
        "health", "healthcare", "covid", "science", "study", "studies",
        "research", "researcher", "doctor")),
    ("fact_check", _keyword_pattern(
        "true", "false", "hoax", "real", "claim", "fact check")),
    ("compare_entities", _keyword_pattern(
        "compare", "vs", "difference between")),
)


def rule_based_classify(prompt: str) -> int:
    """Classify *prompt* with ordered keyword/regex rules.

    Matching is case-insensitive (the prompt is lower-cased) and purely
    keyword-based, so it is cheap but coarse.

    Args:
        prompt: The user's question text.

    Returns:
        The QUESTION_TYPES ID of the first rule that matches, or -1 when no
        rule matches. ``small_talk`` has no rule, so it is never returned here.
    """
    text = prompt.lower()
    for category, pattern in _RULES:
        if pattern.search(text):
            return QUESTION_TYPES[category]
    return -1


# ---------- Public Hybrid Classifier ----------
def classify_question(prompt: str) -> int:
    """Classify *prompt*, trying the rules first and the zero-shot model second.

    Args:
        prompt: The user's question text.

    Returns:
        The rule-based category ID when a rule matches. Otherwise the result of
        run_zero_shot_classification: if that result is a category name from
        QUESTION_TYPES it is converted to the numeric ID, and any other result
        is returned unchanged.
    """
    rule_result = rule_based_classify(prompt)
    if rule_result != -1:
        return rule_result

    model_result = run_zero_shot_classification(
        prompt, candidate_labels=list(QUESTION_TYPES)
    )
    # NOTE(review): the helper receives category *names* as candidate labels.
    # If it returns the winning name, map it to its numeric ID so both code
    # paths honour the ``-> int`` contract. Confirm the helper's return type.
    if isinstance(model_result, str) and model_result in QUESTION_TYPES:
        return QUESTION_TYPES[model_result]
    return model_result