File size: 2,537 Bytes
eafca75
 
 
 
f00f379
eafca75
f00f379
eafca75
 
 
 
 
 
 
 
 
 
 
 
 
 
326a8da
 
eafca75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f00f379
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import re
from dotenv import load_dotenv

from models_initialization.bart_large_registry import run_zero_shot_classification

# Load variables from a .env file into the process environment at import time.
# NOTE(review): nothing visible in this module reads them; presumably they are
# consumed by the model registry. That module is imported above, i.e. before
# this call runs, so any env reads it does at import time would miss these
# values. Confirm.
load_dotenv()

# Readable question category -> stable numeric ID (insertion order = ID order).
QUESTION_TYPES = dict(
    recent_update=1,
    explainer=2,
    timeline=3,
    person_in_news=4,
    policy_or_law=5,
    election_poll=6,
    business_market=7,
    sports_event=8,
    pop_culture=9,
    science_health=10,
    fact_check=11,
    compare_entities=12,
    small_talk=13,
)

# Numeric ID -> readable category; the inverse of QUESTION_TYPES (IDs are unique).
REVERSE_MAP = dict(zip(QUESTION_TYPES.values(), QUESTION_TYPES.keys()))


# ---------- Step 1: Fast Rule-Based Classification ----------
def _word_start_pattern(*keywords: str) -> "re.Pattern[str]":
    """Compile a regex matching any of *keywords* at the start of a word.

    Anchoring each keyword to a word boundary stops it from firing in the
    middle of an unrelated word ("signed" in "designed", "actor" in "factor",
    "claim" in "acclaimed", "law" in "flaw"). Suffixed forms such as
    "updated", "recently" or "matches" are still accepted. Keywords are
    lowercase, so the searched text must already be lowercased.
    """
    alternatives = "|".join(re.escape(keyword) for keyword in keywords)
    return re.compile(r"\b(?:" + alternatives + ")")


# Ordered (category label, pattern) rules. The first pattern that matches wins,
# so the order encodes precedence between overlapping categories. The patterns
# are compiled once at import time rather than on every call.
# NOTE(review): prefix matching still lets e.g. "bill" match "billion" and
# "poll" match "pollution". Add a trailing boundary to individual keywords if
# that matters.
_RULES = (
    ("recent_update", _word_start_pattern(
        "latest", "recent", "update", "breaking", "happened today")),
    # Must run before the "explainer" rule. Every prompt matching this pattern
    # contains "why", so when it was checked after explainer it never fired.
    ("person_in_news", re.compile(r"\bwhy .* in the news")),
    ("explainer", _word_start_pattern(
        "explain", "why", "what is", "background", "summary")),
    ("timeline", re.compile(r"\btimeline|\bhow .* changed")),
    ("person_in_news", _word_start_pattern("trending")),
    ("policy_or_law", _word_start_pattern(
        "bill", "policy", "law", "executive order", "passed", "signed")),
    ("election_poll", _word_start_pattern(
        "election", "poll", "vote", "candidate", "ballot")),
    ("business_market", _word_start_pattern(
        "stock", "inflation", "economy", "market", "job report")),
    ("sports_event", _word_start_pattern(
        "score", "match", "tournament", "game", "league")),
    ("pop_culture", _word_start_pattern(
        "celebrity", "actor", "album", "movie", "music", "show")),
    ("science_health", _word_start_pattern(
        "health", "covid", "science", "study", "research", "doctor")),
    ("fact_check", _word_start_pattern(
        "true", "false", "hoax", "real", "claim", "fact check")),
    ("compare_entities", _word_start_pattern(
        "compare", "vs", "difference between")),
)


def rule_based_classify(prompt: str) -> int:
    """Classify *prompt* with ordered keyword rules.

    Args:
        prompt: The user's question; matching is case-insensitive.

    Returns:
        The ``QUESTION_TYPES`` ID of the first rule in ``_RULES`` that
        matches, or ``-1`` when none does. ``classify_question`` treats
        ``-1`` as "fall back to zero-shot classification".
    """
    text = prompt.lower()
    for label, pattern in _RULES:
        if pattern.search(text):
            return QUESTION_TYPES[label]
    return -1


# ---------- Public Hybrid Classifier ----------
def classify_question(prompt: str) -> int:
    """Return the category ID for *prompt* using a two-stage strategy.

    The cheap keyword rules in ``rule_based_classify`` are tried first. Only
    when they report no match (``-1``) is the zero-shot model consulted, with
    every ``QUESTION_TYPES`` label offered as a candidate.

    NOTE(review): the model receives label *names*, yet this function is
    annotated to return an int ID. Confirm that
    ``run_zero_shot_classification`` maps its winning label back to an ID
    (e.g. via ``QUESTION_TYPES``) rather than returning the label string.
    """
    category = rule_based_classify(prompt)
    if category == -1:
        category = run_zero_shot_classification(
            prompt, candidate_labels=list(QUESTION_TYPES)
        )
    return category