# Source: FastAPI / nuse_modules/classifier.py
# Last change by raghavNCI, commit 326a8da ("classifier type changes").
import os
import re
import requests
from dotenv import load_dotenv
# Populate os.environ from a .env file (python-dotenv) before reading settings.
load_dotenv()

# Hugging Face API token; None when HF_TOKEN is not set in the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub model id used by the zero-shot fallback classifier.
HF_ZERO_SHOT_MODEL = "facebook/bart-large-mnli"
# Map readable categories to numeric IDs
QUESTION_TYPES = {
    "recent_update": 1,
    "explainer": 2,
    "timeline": 3,
    "person_in_news": 4,
    "policy_or_law": 5,
    "election_poll": 6,
    "business_market": 7,
    "sports_event": 8,
    "pop_culture": 9,
    "science_health": 10,
    "fact_check": 11,
    "compare_entities": 12,
    "small_talk": 13,
}
# Inverse lookup: numeric ID -> category name.
REVERSE_MAP = {v: k for k, v in QUESTION_TYPES.items()}


# ---------- Step 1: Fast Rule-Based Classification ----------
def _keyword_pattern(*keywords: str) -> "re.Pattern[str]":
    """Compile a regex that matches any of *keywords* as a whole word.

    A trailing "s", "es", "d" or "ed" is tolerated, so simple plurals and past
    tenses ("updates", "voted", "matches") still match. Words that merely
    contain a keyword ("billion", "pollution", "really", "factor") do not.
    """
    alternation = "|".join(re.escape(word) for word in keywords)
    return re.compile(rf"\b(?:{alternation})(?:s|es|d|ed)?\b")


# Ordered (category, pattern) rules: the first pattern found in the prompt wins.
_RULES = (
    ("recent_update", _keyword_pattern(
        "latest", "recent", "recently", "update", "breaking", "happened today")),
    # Must precede "explainer": that rule fires on any "why", which previously
    # made this pattern unreachable.
    ("person_in_news", re.compile(r"\bwhy .* in the news")),
    ("explainer", _keyword_pattern(
        "explain", "why", "what is", "background", "summary")),
    ("timeline", re.compile(r"\btimelines?\b|\bhow .* changed")),
    ("person_in_news", _keyword_pattern("trending")),
    ("policy_or_law", _keyword_pattern(
        "bill", "policy", "law", "executive order", "passed", "signed")),
    ("election_poll", _keyword_pattern(
        "election", "poll", "polling", "vote", "voter", "candidate", "ballot")),
    ("business_market", _keyword_pattern(
        "stock", "inflation", "economy", "market", "job report")),
    ("sports_event", _keyword_pattern(
        "score", "match", "tournament", "game", "league")),
    ("pop_culture", _keyword_pattern(
        "celebrity", "actor", "album", "movie", "music", "show")),
    ("science_health", _keyword_pattern(
        "health", "covid", "science", "study", "research", "doctor")),
    ("fact_check", _keyword_pattern(
        "true", "false", "hoax", "real", "claim", "fact check")),
    ("compare_entities", _keyword_pattern(
        "compare", "vs", "difference between")),
)


def rule_based_classify(prompt: str) -> int:
    """Classify *prompt* with ordered keyword/regex rules.

    Matching is case-insensitive (the prompt is lower-cased), and the first
    matching rule in ``_RULES`` wins, so rule order encodes precedence.
    ``small_talk`` has no rule and is never returned here.

    Returns:
        The QUESTION_TYPES ID of the first matching rule, or -1 when no rule
        matches (including an empty or None prompt).
    """
    text = (prompt or "").lower()
    for category, pattern in _RULES:
        if pattern.search(text):
            return QUESTION_TYPES[category]
    return -1
# ---------- Step 2: HF Zero-Shot Fallback ----------
def _top_zero_shot_label(data):
    """Return the highest-ranked label from a zero-shot API response, or None.

    Two response shapes are accepted:
    * ``{"labels": [...], "scores": [...]}`` with labels ordered most-likely
      first (the shape this module has always parsed);
    * ``[{"label": ..., "score": ...}, ...]`` -- presumably what newer
      Inference API deployments return; verify against the live endpoint.
    """
    if isinstance(data, dict):
        labels = data.get("labels")
        if isinstance(labels, list) and labels:
            return labels[0]
        return None
    if isinstance(data, list):
        scored = [item for item in data if isinstance(item, dict) and "label" in item]
        if scored:
            return max(scored, key=lambda item: item.get("score", 0))["label"]
    return None


def zero_shot_classify(prompt: str) -> int:
    """Classify *prompt* with the Hugging Face zero-shot model.

    Every key of QUESTION_TYPES is sent as a candidate label, and the
    top-ranked label is mapped back to its numeric ID.

    Returns:
        The numeric question-type ID, or -1 when the prompt is blank, the
        request fails, or the response contains no recognised label. Errors
        are printed, not raised, because this is a best-effort fallback.
    """
    if not (prompt and prompt.strip()):
        # Nothing to classify; skip the network round trip.
        return -1

    payload = {
        "inputs": prompt,
        "parameters": {"candidate_labels": list(QUESTION_TYPES)},
    }
    headers = {"Content-Type": "application/json"}
    if HF_TOKEN:
        # Only authenticate when a token is configured; otherwise the header
        # would carry the literal string "Bearer None".
        headers["Authorization"] = f"Bearer {HF_TOKEN}"
    url = f"https://api-inference.huggingface.co/models/{HF_ZERO_SHOT_MODEL}"

    try:
        res = requests.post(url, headers=headers, json=payload, timeout=20)
        res.raise_for_status()
        top_label = _top_zero_shot_label(res.json())
    except Exception as e:
        # Deliberately broad: any network, HTTP, or decoding failure degrades
        # to "unclassified" instead of propagating to the caller.
        print("[HF Classifier Error]", str(e))
        return -1
    # Guard the dict lookup against non-string (possibly unhashable) labels.
    return QUESTION_TYPES.get(top_label, -1) if isinstance(top_label, str) else -1
# ---------- Public Hybrid Classifier ----------
def classify_question(prompt: str) -> int:
    """Return the numeric question-type ID for *prompt*.

    The cheap keyword rules run first. Only when they yield -1 is the
    Hugging Face zero-shot model consulted, and it may also return -1.
    """
    by_rules = rule_based_classify(prompt)
    if by_rules == -1:
        return zero_shot_classify(prompt)
    return by_rules