|
import os |
|
import re |
|
from dotenv import load_dotenv |
|
|
|
from models_initialization.bart_large_registry import run_zero_shot_classification |
|
|
|
# Load variables from a local .env file (if present) into the process
# environment before anything below runs.
load_dotenv()


# Question-type label -> integer id. Insertion order matters: the keys are
# also handed to the zero-shot classifier as its candidate labels.
QUESTION_TYPES = {
    # News-flow questions
    "recent_update": 1,
    "explainer": 2,
    "timeline": 3,
    "person_in_news": 4,
    # Topical domains
    "policy_or_law": 5,
    "election_poll": 6,
    "business_market": 7,
    "sports_event": 8,
    "pop_culture": 9,
    "science_health": 10,
    # Analytical / conversational
    "fact_check": 11,
    "compare_entities": 12,
    "small_talk": 13,
}

# Integer id -> question-type label (inverse of QUESTION_TYPES).
REVERSE_MAP = {type_id: label for label, type_id in QUESTION_TYPES.items()}
|
|
|
|
|
|
|
def rule_based_classify(prompt: str) -> int:
    """Classify *prompt* with ordered keyword/regex rules.

    The prompt is lower-cased and tested against the rules below in order;
    the first rule that matches wins.

    Args:
        prompt: The user's question text.

    Returns:
        The ``QUESTION_TYPES`` id of the first matching rule, or ``-1`` when
        no rule matches. No rule yields ``"small_talk"``.

    Note:
        Keywords are matched as plain substrings, so inflected forms match
        ("stocks", "updated") but so do unrelated words that merely contain
        a keyword ("law" in "lawn", "real" in "really").
    """
    text = prompt.lower()

    def mentions(*keywords: str) -> bool:
        """Return True if any keyword occurs as a substring of the prompt."""
        return any(keyword in text for keyword in keywords)

    if mentions("latest", "recent", "update", "breaking", "happened today"):
        return QUESTION_TYPES["recent_update"]

    # This pattern must be tested before the explainer keywords: every prompt
    # it matches contains "why", so checking it afterwards (as the code
    # previously did) made the branch unreachable and "why is X in the news"
    # was always classified as an explainer.
    if re.search(r"why .* in the news", text):
        return QUESTION_TYPES["person_in_news"]

    if mentions("explain", "why", "what is", "background", "summary"):
        return QUESTION_TYPES["explainer"]

    if "timeline" in text or re.search(r"how .* changed", text):
        return QUESTION_TYPES["timeline"]

    # "trending" keeps its original priority (after explainer and timeline).
    if "trending" in text:
        return QUESTION_TYPES["person_in_news"]

    if mentions("bill", "policy", "law", "executive order", "passed", "signed"):
        return QUESTION_TYPES["policy_or_law"]

    if mentions("election", "poll", "vote", "candidate", "ballot"):
        return QUESTION_TYPES["election_poll"]

    if mentions("stock", "inflation", "economy", "market", "job report"):
        return QUESTION_TYPES["business_market"]

    if mentions("score", "match", "tournament", "game", "league"):
        return QUESTION_TYPES["sports_event"]

    if mentions("celebrity", "actor", "album", "movie", "music", "show"):
        return QUESTION_TYPES["pop_culture"]

    if mentions("health", "covid", "science", "study", "research", "doctor"):
        return QUESTION_TYPES["science_health"]

    if mentions("true", "false", "hoax", "real", "claim", "fact check"):
        return QUESTION_TYPES["fact_check"]

    if mentions("compare", "vs", "difference between"):
        return QUESTION_TYPES["compare_entities"]

    # Sentinel: no rule matched.
    return -1
|
|
|
|
|
|
|
def classify_question(prompt: str) -> int:
    """Classify *prompt*, trying the keyword rules before the zero-shot model.

    Args:
        prompt: The user's question text.

    Returns:
        The id from ``rule_based_classify`` when a rule matches; otherwise
        whatever ``run_zero_shot_classification`` returns for the prompt.
    """
    type_id = rule_based_classify(prompt)

    if type_id == -1:
        # No rule matched: defer to the zero-shot classifier, offering every
        # known question-type label as a candidate.
        # NOTE(review): the fallback's result is returned unchanged; confirm
        # it is an int id (not a label string) as the annotation promises.
        labels = list(QUESTION_TYPES)
        return run_zero_shot_classification(prompt, candidate_labels=labels)

    return type_id
|
|