File size: 5,072 Bytes
6c67d38
 
b90e13b
 
cd09eef
6c67d38
a389429
6c67d38
42f0920
6c67d38
42f0920
6c67d38
 
 
a389429
6c67d38
 
 
 
 
a389429
dc2875b
 
6c67d38
 
b355243
6b838f7
b355243
 
 
 
6b838f7
6c67d38
a389429
6c67d38
 
 
a389429
 
 
 
6c67d38
 
a389429
6c67d38
 
a389429
6c67d38
a389429
cd09eef
 
6c67d38
b355243
a389429
 
8f108e4
 
cd09eef
 
8f108e4
cd09eef
 
 
8f108e4
6c67d38
a389429
cd09eef
 
 
6c67d38
cd09eef
 
 
 
 
b355243
a389429
 
cd09eef
 
 
6c67d38
 
 
 
 
 
a389429
6c67d38
 
cd09eef
 
 
 
 
6c67d38
a389429
6c67d38
a389429
 
 
 
 
 
 
6c67d38
 
a389429
6c67d38
 
a389429
6c67d38
 
 
8f108e4
a389429
6c67d38
cd09eef
 
 
 
a389429
8f108e4
a389429
8f108e4
a389429
 
 
 
6c67d38
a389429
6c67d38
a389429
 
 
 
6c67d38
 
 
a389429
6c67d38
a389429
 
 
 
6c67d38
a389429
 
 
 
6c67d38
 
a36206a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import json, time, csv, os
import gradio as gr
from transformers import pipeline


# ————————————————
# Load taxonomies
# ————————————————
# Taxonomy files are read once at import time. The encoding is pinned to
# UTF-8 so that non-ASCII label text is decoded the same way on every
# platform instead of depending on the locale's default encoding.

# coarse_labels: the candidate labels handed to the classifier in stage 1.
with open("coarse_labels.json", encoding="utf-8") as f:
    coarse_labels = json.load(f)
# fine_map: looked up by the chosen subject to get the stage 2 candidate labels.
with open("fine_labels.json", encoding="utf-8") as f:
    fine_map = json.load(f)

# ————————————————
# Model choices (5 models + 1 placeholder entry)
# ————————————————
# Checkpoints offered in the "Choose model" dropdown; the first entry is
# used as the dropdown's initial value.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "valhalla/distilbart-mnli-12-4",
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
    # placeholder — replace with real phantom model
    "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
]

# Cache of already-built classifiers, keyed by model name.
PIPELINES = {}


def get_pipeline(name):
    """Return the zero-shot classification pipeline for model *name*.

    A pipeline is built on the first request for a given name and stored in
    ``PIPELINES``; later requests for the same name return the stored object.
    """
    try:
        return PIPELINES[name]
    except KeyError:
        pass  # not built yet; fall through and create it

    classifier = pipeline("zero-shot-classification", model=name)
    PIPELINES[name] = classifier
    return classifier

# ————————————————
# Ensure log files exist
# ————————————————
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

# Column names written as the first row of each CSV when it is created.
_LOG_COLUMNS = [
    "timestamp", "model", "question", "chosen_subject", "top3_topics", "duration",
]
_FEEDBACK_COLUMNS = [
    "timestamp", "question", "subject_feedback", "topic_feedback",
]

for csv_path, columns in ((LOG_FILE, _LOG_COLUMNS), (FEEDBACK_FILE, _FEEDBACK_COLUMNS)):
    if os.path.exists(csv_path):
        continue  # an existing file is left untouched
    with open(csv_path, "w", newline="") as handle:
        csv.writer(handle).writerow(columns)

# ————————————————
# Inference functions
# ————————————————
def run_stage1(question, model_name):
    """Classify *question* against the coarse subject labels.

    Returns a 3-tuple:
      * a ``{label: score}`` dict built from the first three labels/scores in
        the pipeline output (scores rounded to 3 decimals),
      * a ``gr.update`` that sets the radio choices to those labels and
        selects the first one,
      * a string with the elapsed inference time in seconds.

    A missing or whitespace-only question skips inference and returns an
    empty dict, an update with no choices, and an empty string.
    """
    if not (question and question.strip()):
        return {}, gr.update(choices=[]), ""

    started = time.time()
    result = get_pipeline(model_name)(question, candidate_labels=coarse_labels)
    top_labels = result["labels"][:3]
    top_scores = result["scores"][:3]
    elapsed = round(time.time() - started, 3)

    return (
        dict(zip(top_labels, (round(value, 3) for value in top_scores))),
        gr.update(choices=top_labels, value=top_labels[0]),
        f"⏱ {elapsed}s",
    )


def run_stage2(question, model_name, subject):
    """Classify *question* against the fine-grained topics of *subject*.

    Args:
        question: Text to classify.
        model_name: Model identifier passed to ``get_pipeline``.
        subject: Key into ``fine_map`` selecting the candidate topic labels.

    Returns:
        A 2-tuple ``(topics, status)``. The Stage 2 click handler is wired to
        exactly two outputs (topics label, timing textbox), so every return
        path yields two values. On success ``topics`` maps the first three
        labels in the pipeline output to their scores (rounded to 3 decimals)
        and ``status`` is the elapsed time. On invalid input ``topics`` is
        empty and ``status`` carries the explanatory message.

    Side effects:
        Appends one row to ``LOG_FILE`` for each successful inference.
    """
    # 1) Validate inputs. These early exits must have the same arity as the
    #    success path; returning a third value here breaks the handler.
    if not question or not question.strip():
        return {}, "No question provided"
    fine_labels = fine_map.get(subject, [])
    if not fine_labels:
        return {}, f"No topics found for '{subject}'"

    # 2) Inference (the pipeline is cached per model by get_pipeline)
    start = time.time()
    clf = get_pipeline(model_name)
    out = clf(question, candidate_labels=fine_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.time() - start, 3)

    # 3) Logging. UTF-8 is pinned so arbitrary user text can always be
    #    written, whatever the platform's default encoding is.
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),  # keep each log entry on one line
            subject,
            ";".join(labels),
            duration,
        ])

    # 4) Return topics + time
    topic_dict = {lbl: round(score, 3) for lbl, score in zip(labels, scores)}
    return topic_dict, f"⏱ {duration}s"


def submit_feedback(question, subject_fb, topic_fb):
    """Append one feedback row to ``FEEDBACK_FILE`` and return a status message.

    Args:
        question: The question the feedback refers to; ``None`` is treated as
            an empty string.
        subject_fb: Corrected subject entered by the user.
        topic_fb: Corrected topic(s) entered by the user.

    Returns:
        A confirmation string for display in the UI.
    """
    # Guard against None so that .replace cannot raise; newlines are
    # flattened to keep the question on a single line in the CSV.
    flat_question = (question or "").replace("\n", " ")

    # UTF-8 is pinned so arbitrary user text can always be written,
    # whatever the platform's default encoding is.
    with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            flat_question,
            subject_fb,
            topic_fb,
        ])
    # The previous literal was a mis-decoded check-mark emoji.
    return "✅ Feedback recorded!"

# ————————————————
# Build Gradio UI
# ————————————————
# Assemble the UI top to bottom: shared inputs, Stage 1 (subjects),
# Stage 2 (topics), then the feedback form. Components are declared in
# display order, so statement order here matters.
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")

    # Shared inputs: the question text and the model choice feed both stages.
    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input    = gr.Dropdown(choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model")
        go_button      = gr.Button("Run Stage 1")

    # Stage 1 outputs. subj_radio starts with no choices; run_stage1 returns
    # an update that fills it with the predicted subjects.
    subject_out = gr.Label(num_top_classes=3, label="Top-3 Subjects")
    subj_radio  = gr.Radio(choices=[], label="Select Subject for Stage 2")
    stage1_time = gr.Textbox(label="Stage 1 Time")
    
    # run_stage1's three return values map onto these three outputs, in order.
    go_button.click(
        fn=run_stage1,
        inputs=[question_input, model_input],
        outputs=[subject_out, subj_radio, stage1_time]
    )

    # Stage 2 UI
    go2_button  = gr.Button("Run Stage 2")
    topics_out  = gr.Label(label="Top-3 Topics")
    stage2_time = gr.Textbox(label="Stage 2 Time")

    # run_stage2 receives the question, the model and the subject selected in
    # subj_radio; it must return exactly two values for these two outputs.
    go2_button.click(
        fn=run_stage2,
        inputs=[question_input, model_input, subj_radio],
        outputs=[topics_out, stage2_time]
    )

    gr.Markdown("---")
    gr.Markdown("### Feedback / Correction")

    # Free-text corrections; fb_status displays the string returned by
    # submit_feedback.
    subject_fb = gr.Textbox(label="Correct Subject")
    topic_fb   = gr.Textbox(label="Correct Topic(s)")
    fb_button  = gr.Button("Submit Feedback")
    fb_status  = gr.Textbox(label="")

    fb_button.click(
        fn=submit_feedback,
        inputs=[question_input, subject_fb, topic_fb],
        outputs=[fb_status]
    )

# Launch runs unconditionally when this module is executed or imported.
# NOTE(review): share=True asks Gradio for a public share link; confirm that
# this is intended for the deployment target.
demo.launch(share=True)