File size: 2,010 Bytes
6b7450a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# 🏷️ Zero-Shot Text Classification | CPU-only HF Space

import gradio as gr
from transformers import pipeline

# Build the zero-shot pipeline a single time, when the module is imported,
# so every request handled below reuses the same loaded model.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=-1,  # -1 selects the CPU; no GPU is used
)

def zero_shot(text: str, labels: str, multi_label: bool):
    """Score ``text`` against a comma-separated list of candidate labels.

    Args:
        text: The passage to classify.
        labels: Candidate labels separated by commas, e.g. "Positive, Negative".
            Surrounding whitespace and empty entries are ignored, and a label
            repeated more than once is kept only at its first occurrence.
        multi_label: Forwarded to the zero-shot pipeline. When False the
            scores are normalized across all labels; when True each label is
            scored independently.

    Returns:
        A list of ``[label, score]`` rows (score rounded to 3 decimals) in the
        order the pipeline returns them (most likely label first), or an empty
        list when the text is blank or no usable label was supplied.
    """
    if not text or not text.strip():
        return []

    # dict.fromkeys drops duplicates while keeping first-seen order; a
    # repeated label would otherwise produce duplicate rows and, in
    # single-label mode, split the normalized probability between copies.
    raw_labels = (labels or "").split(",")
    candidate_list = list(
        dict.fromkeys(lbl.strip() for lbl in raw_labels if lbl.strip())
    )
    # Input such as ", ,," is not blank but yields no labels; the pipeline
    # rejects an empty label list, so return early instead of erroring.
    if not candidate_list:
        return []

    res = classifier(text, candidate_list, multi_label=bool(multi_label))
    # Build the table as [label, score] rows for the output Dataframe.
    return [
        [label, round(score, 3)]
        for label, score in zip(res["labels"], res["scores"])
    ]

with gr.Blocks(title="🏷️ Zero-Shot Classifier") as demo:
    gr.Markdown(
        "# 🏷️ Zero-Shot Text Classification\n"
        "Paste any text, list your candidate labels (comma-separated),\n"
        "choose single- or multi-label mode, and see scores instantly."
    )

    # Free-text inputs side by side: the passage and its candidate labels.
    with gr.Row():
        text_in = gr.Textbox(
            label="Input Text",
            lines=4,
            placeholder="e.g. The new conditioner left my hair incredibly soft!"
        )
        labels_in = gr.Textbox(
            label="Candidate Labels",
            lines=2,
            placeholder="e.g. Positive, Negative, Question, Feedback"
        )

    # Both modes return a score for every label. This text describes how the
    # scores differ; the old wording wrongly said single-label mode "picks the
    # top label".
    multi_in = gr.Checkbox(
        label="Multi-label classification",
        info=(
            "If checked, each label is scored independently, so several can "
            "apply; otherwise scores are normalized across labels to sum to 1."
        )
    )

    run_btn = gr.Button("Classify 🏷️", variant="primary")

    # Read-only results table; rows come from zero_shot as [label, score].
    result_df = gr.Dataframe(
        headers=["Label", "Score"],
        datatype=["str", "number"],
        interactive=False,
        wrap=True,
        label="Prediction Scores"
    )

    # Inputs are passed positionally, matching zero_shot(text, labels, multi_label).
    run_btn.click(
        zero_shot,
        inputs=[text_in, labels_in, multi_in],
        outputs=result_df
    )

if __name__ == "__main__":
    # "0.0.0.0" makes the server listen on all network interfaces rather
    # than only on localhost.
    bind_host = "0.0.0.0"
    demo.launch(server_name=bind_host)