FastAPI / nuse_modules /classifier.py
raghavNCI
too many changes, hope this works
f00f379
raw
history blame
2.54 kB
import os
import re
from dotenv import load_dotenv
from models_initialization.bart_large_registry import run_zero_shot_classification
# Executed at import time. Presumably loads settings from a local .env file
# into the process environment (python-dotenv). Nothing in this module reads
# environment variables directly; NOTE(review): confirm which imported module
# (e.g. the model registry) depends on this before removing it.
load_dotenv()
# Readable question categories, listed in ID order: the first name gets ID 1,
# each following name the next integer (so "small_talk" is 13).
_CATEGORY_NAMES = (
    "recent_update",
    "explainer",
    "timeline",
    "person_in_news",
    "policy_or_law",
    "election_poll",
    "business_market",
    "sports_event",
    "pop_culture",
    "science_health",
    "fact_check",
    "compare_entities",
    "small_talk",
)

# Readable category name -> numeric ID (1-based, insertion order preserved).
QUESTION_TYPES = {
    name: number for number, name in enumerate(_CATEGORY_NAMES, start=1)
}

# Inverse lookup: numeric ID -> readable category name.
REVERSE_MAP = {number: name for name, number in QUESTION_TYPES.items()}
# ---------- Step 1: Fast Rule-Based Classification ----------
def _keyword_pattern(*phrases: str) -> "re.Pattern[str]":
    """Compile a pattern matching any of *phrases* as whole words.

    Each phrase must begin and end on a word boundary, optionally followed by
    a simple inflection ("s", "es", "d" or "ed"). This keeps "stocks", "polls"
    and "voted" matching while stopping substring false positives such as
    "factor" -> "actor", "billion" -> "bill" or "colleague" -> "league".
    Phrases are expected in lowercase; callers lowercase the text first.
    """
    alternation = "|".join(re.escape(phrase) for phrase in phrases)
    return re.compile(rf"\b(?:{alternation})(?:e?[sd])?\b")


# Ordered (category name, pattern) rules; the first rule that matches wins, so
# earlier categories take precedence. Patterns are compiled once at import.
# person_in_news is deliberately checked before explainer: its
# "why ... in the news" pattern always contains "why", so when it came after
# the explainer rule it could never fire.
_KEYWORD_RULES = (
    ("recent_update", _keyword_pattern("latest", "recent", "update", "breaking", "happened today")),
    ("person_in_news", re.compile(r"\bwhy\b.*\bin the news\b|\btrending\b")),
    ("explainer", _keyword_pattern("explain", "why", "what is", "background", "summary")),
    ("timeline", re.compile(r"\btimelines?\b|\bhow\b.*\bchanged\b")),
    ("policy_or_law", _keyword_pattern("bill", "policy", "law", "executive order", "passed", "signed")),
    ("election_poll", _keyword_pattern("election", "poll", "vote", "candidate", "ballot")),
    ("business_market", _keyword_pattern("stock", "inflation", "economy", "market", "job report")),
    ("sports_event", _keyword_pattern("score", "match", "tournament", "game", "league")),
    ("pop_culture", _keyword_pattern("celebrity", "actor", "album", "movie", "music", "show")),
    ("science_health", _keyword_pattern("health", "covid", "science", "study", "research", "doctor")),
    ("fact_check", _keyword_pattern("true", "false", "hoax", "real", "claim", "fact check")),
    ("compare_entities", _keyword_pattern("compare", "vs", "difference between")),
)


def rule_based_classify(prompt: str) -> int:
    """Classify *prompt* using fast keyword rules.

    The rules in ``_KEYWORD_RULES`` are tried in order against the lowercased
    prompt, and the first one that matches decides the category.

    Args:
        prompt: The user's question, in any letter case.

    Returns:
        The numeric ID from ``QUESTION_TYPES`` for the first matching
        category, or ``-1`` when no rule matches (``classify_question`` treats
        ``-1`` as the signal to fall back to zero-shot classification).
    """
    text = prompt.lower()
    for label, pattern in _KEYWORD_RULES:
        if pattern.search(text):
            return QUESTION_TYPES[label]
    return -1
# ---------- Public Hybrid Classifier ----------
def classify_question(prompt: str) -> int:
    """Classify *prompt* into one of the ``QUESTION_TYPES`` categories.

    The keyword rules in ``rule_based_classify`` are tried first. Only when
    they report no match (``-1``) is ``run_zero_shot_classification``
    consulted, with every category name offered as a candidate label.

    NOTE(review): the fallback path returns whatever
    ``run_zero_shot_classification`` returns unchanged. It is given string
    label names, so confirm it maps its prediction back to a numeric ID;
    otherwise this function may return a label string despite ``-> int``.
    """
    matched_id = rule_based_classify(prompt)
    if matched_id == -1:
        # No keyword rule fired; defer to the zero-shot model. "small_talk"
        # has no keyword rule, so that category can only be produced here.
        return run_zero_shot_classification(
            prompt, candidate_labels=list(QUESTION_TYPES)
        )
    return matched_id