import gradio as gr
from transformers import pipeline
import re
# Custom sentence tokenizer
def sent_tokenize(text):
    """Split *text* into sentences on '.', '!' or '?' followed by whitespace.

    NOTE(review): the original regex literal was truncated in this file
    (it ended at ``r'(?``); reconstructed here as a lookbehind split on
    sentence-ending punctuation — confirm against the original behavior.

    Args:
        text: the raw input string.

    Returns:
        List of non-empty, stripped sentence strings ([] for empty input).
    """
    sentence_endings = re.compile(r'(?<=[.!?])\s+')
    return [s.strip() for s in sentence_endings.split(text) if s.strip()]
def create_analysis_html(sentence_results, global_label, global_confidence):
    """Render an HTML report: a global-prediction banner plus a
    per-sentence analysis table.

    NOTE(review): the HTML markup was stripped from this file when it was
    extracted (only the CSS rules and f-string fields survived); the tags
    below are reconstructed around that residue — confirm against the
    original layout.

    Args:
        sentence_results: list of dicts with 'sentence', 'prediction' and
            'confidence' keys (as built by process_input).
        global_label: label predicted for the whole input text.
        global_confidence: probability of that label, in [0, 1].

    Returns:
        A self-contained HTML string (inline <style> + markup).
    """
    # Inline stylesheet; class names match the elements emitted below.
    html = """
    <style>
    .analysis-table {
        width: 100%;
        border-collapse: collapse;
        margin: 20px 0;
        font-family: Arial, sans-serif;
    }
    .analysis-table th, .analysis-table td {
        padding: 12px;
        border: 1px solid #ddd;
        text-align: left;
    }
    .analysis-table th {
        background-color: #f5f5f5;
    }
    .global-prediction {
        padding: 15px;
        margin: 20px 0;
        border-radius: 5px;
        font-weight: bold;
    }
    .confidence {
        font-size: 0.9em;
        color: #666;
    }
    </style>
    """

    # Add global prediction box with confidence
    html += f"""
    <div class="global-prediction">
        Global Prediction: {global_label}
        <span class="confidence">(Confidence: {global_confidence:.2%})</span>
    </div>
    """

    # Create table (header row)
    html += """
    <table class="analysis-table">
        <tr>
            <th>Sentence</th>
            <th>Prediction</th>
            <th>Confidence</th>
        </tr>
    """

    # Add rows for each sentence
    for result in sentence_results:
        html += f"""
        <tr>
            <td>{result['sentence']}</td>
            <td>{result['prediction']}</td>
            <td>{result['confidence']:.2%}</td>
        </tr>
        """

    html += """
    </table>
    """
    return html
# NOTE(review): the pipeline and example-list definitions below were lost
# when this file was garbled; the model name is reconstructed from the UI
# markdown (tasksource/ModernBERT-base-nli) and the example contents are
# plausible placeholders — confirm both against the original Space.
zero_shot_classifier = pipeline("zero-shot-classification", model="tasksource/ModernBERT-base-nli")
nli_classifier = pipeline("text-classification", model="tasksource/ModernBERT-base-nli")

# (input text, labels-or-hypothesis) pairs used to pre-fill the UI.
zero_shot_examples = [
    ["I absolutely love this product, it exceeded all my expectations!",
     "positive, negative, neutral"],
    ["The new tax policy will take effect next quarter.",
     "politics, sports, technology, economics"],
]
nli_examples = [
    ["A man is playing guitar on stage.", "Someone is performing music."],
    ["The cat is sleeping on the couch.", "The cat is running in the garden."],
]
long_context_examples = [
    ["The company reported record profits this quarter. Revenue grew by twenty percent. "
     "However, the CEO announced layoffs affecting part of the workforce. "
     "Employee morale has reportedly declined since the announcement.",
     "The company is doing well financially."],
]


def process_input(text_input, labels_or_premise, mode):
    """Run the classifier matching the selected UI mode.

    Args:
        text_input: the text (premise / context) to analyse.
        labels_or_premise: comma-separated labels (zero-shot mode) or a
            hypothesis string (NLI / long-context NLI modes).
        mode: "Zero-Shot Classification", "Natural Language Inference",
            or "Long Context NLI".

    Returns:
        Tuple of (label -> score dict, HTML string). The HTML is empty
        except in long-context mode, where it carries the per-sentence
        analysis table.
    """
    if mode == "Zero-Shot Classification":
        labels = [label.strip() for label in labels_or_premise.split(',')]
        prediction = zero_shot_classifier(text_input, labels)
        results = {label: score for label, score in zip(prediction['labels'], prediction['scores'])}
        return results, ''
    elif mode == "Natural Language Inference":
        preds = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}],
                               return_all_scores=True)[0]
        # BUGFIX: the original comprehension shadowed the list it iterated
        # ({pred[...] for pred in pred}); use a distinct loop variable.
        results = {p['label']: p['score'] for p in preds}
        return results, ''
    else:  # Long Context NLI
        # Global prediction over the full context.
        global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}],
                                     return_all_scores=True)[0]
        global_results = {p['label']: p['score'] for p in global_pred}
        # Pick label and confidence from the same max in one pass.
        global_label = max(global_results, key=global_results.get)
        global_confidence = global_results[global_label]
        # Sentence-level analysis: classify each sentence independently.
        sentence_results = []
        for sentence in sent_tokenize(text_input):
            sent_pred = nli_classifier([{"text": sentence, "text_pair": labels_or_premise}],
                                       return_all_scores=True)[0]
            # Best (label, score) pair for this sentence.
            max_label, confidence = max(((p['label'], p['score']) for p in sent_pred),
                                        key=lambda x: x[1])
            sentence_results.append({
                'sentence': sentence,
                'prediction': max_label,
                'confidence': confidence,
            })
        analysis_html = create_analysis_html(sentence_results, global_label, global_confidence)
        return global_results, analysis_html
def update_interface(mode):
    """Re-label the secondary textbox and reset both inputs to the first
    example of the newly selected mode.

    Returns a pair of gr.update() objects for
    (labels_or_premise, text_input).
    """
    if mode == "Zero-Shot Classification":
        box_label = "🏷️ Categories"
        box_placeholder = "Enter comma-separated categories..."
        mode_examples = zero_shot_examples
    elif mode == "Natural Language Inference":
        box_label = "🔎 Hypothesis"
        box_placeholder = "Enter a hypothesis to compare with the premise..."
        mode_examples = nli_examples
    else:  # Long Context NLI
        box_label = "🔎 Hypothesis"
        box_placeholder = "Enter a hypothesis to test against the full context..."
        mode_examples = long_context_examples
    first_text, first_secondary = mode_examples[0][0], mode_examples[0][1]
    return (
        gr.update(label=box_label, placeholder=box_placeholder, value=first_secondary),
        gr.update(value=first_text),
    )
def update_visibility(mode):
    """Show only the example panel that matches the selected mode.

    Returns one gr.update(visible=...) per panel, in the fixed order
    (zero-shot, NLI, long-context) expected by the mode.change wiring.
    """
    panel_modes = (
        "Zero-Shot Classification",
        "Natural Language Inference",
        "Long Context NLI",
    )
    return tuple(gr.update(visible=(mode == m)) for m in panel_modes)
# Now define the Blocks interface
with gr.Blocks() as demo:
    # Header / model description shown at the top of the page.
    gr.Markdown("""
    # tasksource/ModernBERT-nli demonstration

    This space uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli),
    fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
    on tasksource classification tasks.
    This NLI model achieves high accuracy on categorization, logical reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL (long-context NLI) and FOLIO (logical reasoning).
    """)

    # Mode selector drives both the textbox labels and panel visibility
    # (see the two mode.change wirings below).
    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference", "Long Context NLI"],
        label="Select Mode",
        value="Zero-Shot Classification"
    )

    with gr.Column():
        # Main input text; pre-filled with the first zero-shot example
        # to match the default mode above.
        # NOTE(review): zero_shot_examples is expected to be defined
        # earlier in this file — it is missing from this garbled copy.
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0]
        )

        # Second input: categories in zero-shot mode, hypothesis in the
        # two NLI modes; relabeled by update_interface on mode change.
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1]
        )

        submit_btn = gr.Button("Submit")

        # process_input returns (label scores, analysis HTML) in this order.
        outputs = [
            gr.Label(label="📊 Results"),
            gr.HTML(label="📈 Sentence Analysis")
        ]

    # One example panel per mode; update_visibility shows exactly one.
    with gr.Column(variant="panel") as zero_shot_examples_panel:
        gr.Examples(
            examples=zero_shot_examples,
            inputs=[text_input, labels_or_premise],
            label="Zero-Shot Classification Examples",
        )

    with gr.Column(variant="panel") as nli_examples_panel:
        gr.Examples(
            examples=nli_examples,
            inputs=[text_input, labels_or_premise],
            label="Natural Language Inference Examples",
        )

    with gr.Column(variant="panel") as long_context_examples_panel:
        gr.Examples(
            examples=long_context_examples,
            inputs=[text_input, labels_or_premise],
            label="Long Context NLI Examples",
        )

    # Relabel/reset the two textboxes when the mode changes.
    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input]
    )

    # Toggle which example panel is visible for the selected mode.
    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel, long_context_examples_panel]
    )

    # Run the classifier and fill both output components.
    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()