Spaces:
Running
Running
File size: 5,072 Bytes
6c67d38 b90e13b cd09eef 6c67d38 a389429 6c67d38 42f0920 6c67d38 42f0920 6c67d38 a389429 6c67d38 a389429 dc2875b 6c67d38 b355243 6b838f7 b355243 6b838f7 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 cd09eef 6c67d38 b355243 a389429 8f108e4 cd09eef 8f108e4 cd09eef 8f108e4 6c67d38 a389429 cd09eef 6c67d38 cd09eef b355243 a389429 cd09eef 6c67d38 a389429 6c67d38 cd09eef 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 8f108e4 a389429 6c67d38 cd09eef a389429 8f108e4 a389429 8f108e4 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a36206a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import json, time, csv, os
import gradio as gr
from transformers import pipeline
# ────────────────
# Load taxonomies
# ────────────────
# coarse_labels: passed directly as the Stage 1 candidate labels.
# fine_map: looked up by subject in run_stage2; each value is used as the
# Stage 2 candidate labels for that subject.
# Decode explicitly as UTF-8 (the JSON standard encoding) instead of relying
# on the platform's locale-dependent default, which breaks non-ASCII labels.
with open("coarse_labels.json", encoding="utf-8") as f:
    coarse_labels = json.load(f)
with open("fine_labels.json", encoding="utf-8") as f:
    fine_map = json.load(f)
# ────────────────
# Model choices
# ────────────────
# Hugging Face model ids offered in the UI dropdown; the first entry is the
# dropdown's default value. Trailing commas guard against accidental implicit
# string concatenation when entries are added or reordered.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "valhalla/distilbart-mnli-12-4",
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
    # TODO: placeholder — replace with the real target model.
    "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
]
# Cache of constructed pipelines, keyed by model id, so later requests for the
# same model reuse the already-built object.
PIPELINES = {}


def get_pipeline(name):
    """Return the zero-shot-classification pipeline for model id *name*.

    The pipeline is built on the first request for *name*, stored in
    ``PIPELINES``, and returned from there on subsequent calls.
    """
    if name in PIPELINES:
        return PIPELINES[name]
    clf = pipeline("zero-shot-classification", model=name)
    PIPELINES[name] = clf
    return clf
# ────────────────
# Ensure log files exist
# ────────────────
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

# On first run, create each CSV containing only its header row; a file that
# already exists is left untouched so earlier rows are preserved.
for csv_path, header_row in (
    (LOG_FILE, ["timestamp", "model", "question", "chosen_subject", "top3_topics", "duration"]),
    (FEEDBACK_FILE, ["timestamp", "question", "subject_feedback", "topic_feedback"]),
):
    if os.path.exists(csv_path):
        continue
    with open(csv_path, "w", newline="") as out_file:
        writer = csv.writer(out_file)
        writer.writerow(header_row)
# ────────────────
# Inference functions
# ────────────────
def run_stage1(question, model_name):
    """Stage 1: rank the coarse subjects for *question*.

    Args:
        question: Free-text question from the UI textbox (may be empty/None).
        model_name: Model id handed to ``get_pipeline``.

    Returns:
        A 3-tuple for the outputs ``[subject_out, subj_radio, stage1_time]``:
        a dict of up to three subject labels -> score rounded to 3 decimals,
        a ``gr.update`` setting the radio choices to those labels with the
        first one pre-selected, and an elapsed-time string. For a blank
        question, returns empty values and clears the radio.
    """
    if not question or not question.strip():
        # Clear the choices AND the selected value so a subject picked for an
        # earlier question does not linger in the radio.
        return {}, gr.update(choices=[], value=None), ""

    # perf_counter is monotonic, unlike time.time, so it is the right clock
    # for measuring an interval. Timing includes a first-use model load.
    start = time.perf_counter()
    clf = get_pipeline(model_name)
    out = clf(question, candidate_labels=coarse_labels)
    # Takes the first three labels; assumes the pipeline returns them sorted
    # by descending score.
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.perf_counter() - start, 3)

    subject_dict = {lbl: round(score, 3) for lbl, score in zip(labels, scores)}
    radio_update = gr.update(choices=labels, value=labels[0] if labels else None)
    return subject_dict, radio_update, f"⏱ {duration}s"
def run_stage2(question, model_name, subject):
    """Stage 2: rank the fine-grained topics of *subject* for *question*.

    A successful run appends one row to ``LOG_FILE`` with the columns
    timestamp, model, question, chosen_subject, top3_topics, duration.

    Args:
        question: Free-text question from the UI textbox (may be empty/None).
        model_name: Model id handed to ``get_pipeline``.
        subject: Subject selected in the Stage 1 radio; looked up in
            ``fine_map`` to obtain the candidate topic labels.

    Returns:
        A 2-tuple for the outputs ``[topics_out, stage2_time]``: a dict of up
        to three topic labels -> score rounded to 3 decimals, and either the
        elapsed-time string or a validation message.
    """
    # 1) Validate inputs. Every return path yields exactly two values, one per
    #    output component wired to the Stage 2 button; the validation message
    #    is shown in the time textbox.
    if not question or not question.strip():
        return {}, "No question provided"
    fine_labels = fine_map.get(subject, [])
    if not fine_labels:
        return {}, f"No topics found for '{subject}'"

    # 2) Inference (get_pipeline builds the model on first use, then caches it)
    start = time.perf_counter()
    clf = get_pipeline(model_name)
    out = clf(question, candidate_labels=fine_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.perf_counter() - start, 3)

    # 3) Logging: newlines flattened so the question stays on one line; explicit
    #    UTF-8 so non-ASCII questions don't depend on the platform encoding.
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(labels),
            duration,
        ])

    # 4) Return topics + time
    topic_dict = {lbl: round(score, 3) for lbl, score in zip(labels, scores)}
    return topic_dict, f"⏱ {duration}s"
def submit_feedback(question, subject_fb, topic_fb):
    """Append a user correction to ``FEEDBACK_FILE`` and return a status message.

    Args:
        question: The question the feedback refers to. Newlines are replaced
            with spaces; ``None`` is recorded as an empty string.
        subject_fb: User-entered correct subject (free text).
        topic_fb: User-entered correct topic(s) (free text).

    Returns:
        The confirmation string shown in the feedback status textbox.
    """
    # Explicit UTF-8 so non-ASCII feedback doesn't depend on the platform encoding.
    with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            (question or "").replace("\n", " "),
            subject_fb,
            topic_fb,
        ])
    # Single-line literal: the previous version was split across two lines
    # (an unterminated string, i.e. a SyntaxError) by an encoding mix-up.
    return "✅ Feedback recorded!"
# ────────────────
# Build Gradio UI
# ────────────────
# Layout, top to bottom: question + model picker, Stage 1 (coarse subjects),
# Stage 2 (fine topics for the chosen subject), then a free-text feedback form.
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")
    # NOTE(review): the source's indentation was lost; which components sit
    # inside this Row is a reconstruction — confirm against the intended layout.
    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input = gr.Dropdown(choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model")
    # Stage 1 UI: run_stage1 fills the subject scores, replaces the radio's
    # choices with the top subjects, and reports the elapsed time.
    go_button = gr.Button("Run Stage 1")
    subject_out = gr.Label(num_top_classes=3, label="Top-3 Subjects")
    subj_radio = gr.Radio(choices=[], label="Select Subject for Stage 2")
    stage1_time = gr.Textbox(label="Stage 1 Time")
    go_button.click(
        fn=run_stage1,
        inputs=[question_input, model_input],
        outputs=[subject_out, subj_radio, stage1_time]
    )
    # Stage 2 UI: run_stage2 receives the radio's selected subject and must
    # return exactly one value per component listed in `outputs` (two).
    go2_button = gr.Button("Run Stage 2")
    topics_out = gr.Label(label="Top-3 Topics")
    stage2_time = gr.Textbox(label="Stage 2 Time")
    go2_button.click(
        fn=run_stage2,
        inputs=[question_input, model_input, subj_radio],
        outputs=[topics_out, stage2_time]
    )
    gr.Markdown("---")
    gr.Markdown("### Feedback / Correction")
    # Free-text corrections, recorded by submit_feedback together with the
    # current question; the returned status string goes to fb_status.
    subject_fb = gr.Textbox(label="Correct Subject")
    topic_fb = gr.Textbox(label="Correct Topic(s)")
    fb_button = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="")
    fb_button.click(
        fn=submit_feedback,
        inputs=[question_input, subject_fb, topic_fb],
        outputs=[fb_status]
    )
# share=True requests a public Gradio share link. NOTE(review): this app
# appears to run on Hugging Face Spaces, where share links are presumably not
# supported — confirm and drop the flag if it only produces a warning.
demo.launch(share=True)
|