|
|
|
|
|
import gradio as gr |
|
from transformers import pipeline |
|
|
|
|
|
# Shared zero-shot classification pipeline, built once at import time so every
# request reuses the same loaded model. The model is BART-large fine-tuned on
# MNLI; device=-1 tells the transformers pipeline to run on the CPU.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=-1,
)
|
|
|
def zero_shot(text: str, labels: str, multi_label: bool): |
|
if not text.strip() or not labels.strip(): |
|
return [] |
|
|
|
candidate_list = [lbl.strip() for lbl in labels.split(",") if lbl.strip()] |
|
res = classifier(text, candidate_list, multi_label=multi_label) |
|
|
|
table = [ |
|
[label, round(score, 3)] |
|
for label, score in zip(res["labels"], res["scores"]) |
|
] |
|
return table |
|
|
|
# Gradio UI: text + labels in, a table of per-label scores out.
with gr.Blocks(title="🏷️ Zero-Shot Classifier") as demo:
    gr.Markdown(
        "# 🏷️ Zero-Shot Text Classification\n"
        "Paste any text, list your candidate labels (comma-separated),\n"
        "choose single- or multi-label mode, and see scores instantly."
    )

    # Input text and candidate labels side by side.
    with gr.Row():
        text_in = gr.Textbox(
            label="Input Text",
            lines=4,
            placeholder="e.g. The new conditioner left my hair incredibly soft!",
        )
        labels_in = gr.Textbox(
            label="Candidate Labels",
            lines=2,
            placeholder="e.g. Positive, Negative, Question, Feedback",
        )

    # In both modes every label is scored and shown. The flag only changes
    # how the scores are normalized, which is what the help text describes.
    multi_in = gr.Checkbox(
        label="Multi-label classification",
        value=False,
        info=(
            "If checked, each label is scored independently, so several can "
            "apply; otherwise scores are normalized across labels to sum to 1."
        ),
    )

    run_btn = gr.Button("Classify 🏷️", variant="primary")

    # Read-only results table; rows come from zero_shot as [label, score].
    result_df = gr.Dataframe(
        headers=["Label", "Score"],
        datatype=["str", "number"],
        interactive=False,
        wrap=True,
        label="Prediction Scores",
    )

    run_btn.click(
        zero_shot,
        inputs=[text_in, labels_in, multi_in],
        outputs=result_df,
    )
|
|
|
if __name__ == "__main__":
    # Listen on all network interfaces (not just localhost) so the app is
    # reachable from outside the host/container it runs in.
    launch_options = {"server_name": "0.0.0.0"}
    demo.launch(**launch_options)
|
|