import gradio as gr
from transformers import pipeline
import re

# NOTE(review): the definitions of zero_shot_classifier, nli_classifier and the
# zero_shot_examples / nli_examples / long_context_examples lists referenced
# below were lost when this file was mangled (every "<...>" span was stripped).
# They must be restored from the original source, presumably something like:
#   zero_shot_classifier = pipeline("zero-shot-classification", model="tasksource/ModernBERT-base-nli")
#   nli_classifier = pipeline("text-classification", model="tasksource/ModernBERT-base-nli")
# -- TODO confirm model names and example lists against the original Space.


# Custom sentence tokenizer
def sent_tokenize(text):
    """Split *text* into sentences on whitespace that follows '.', '!' or '?'.

    NOTE(review): the original pattern's lookbehind was stripped along with all
    other "<...>" spans; this is a reconstruction — confirm against the source.

    Returns a list of non-empty, stripped sentence strings.
    """
    sentence_endings = re.compile(r'(?<=[.!?])\s+')
    return [s.strip() for s in sentence_endings.split(text) if s.strip()]


def create_analysis_html(sentence_results, global_label, global_confidence):
    """Render the Long Context NLI analysis as an HTML fragment.

    Parameters
    ----------
    sentence_results : list of dicts with keys 'sentence', 'prediction',
        'confidence' (a float probability), one per tokenized sentence.
    global_label : str — the top label predicted over the whole context.
    global_confidence : float — probability of the global label.

    Returns the HTML string shown in the gr.HTML output component.
    NOTE(review): the markup tags below were stripped from the original file;
    they are reconstructed from the surviving CSS class names and cell order.
    """
    # Base CSS for the global-prediction banner and per-sentence table.
    html = """
    <style>
    .analysis-table { width: 100%; border-collapse: collapse; margin: 20px 0; font-family: Arial, sans-serif; }
    .analysis-table th, .analysis-table td { padding: 12px; border: 1px solid #ddd; text-align: left; }
    .analysis-table th { background-color: #f5f5f5; }
    .global-prediction { padding: 15px; margin: 20px 0; border-radius: 5px; font-weight: bold; }
    .confidence { font-size: 0.9em; color: #666; }
    </style>
    """

    # Add global prediction box with confidence
    html += f"""
    <div class="global-prediction">
        Global Prediction: {global_label}
        <span class="confidence">(Confidence: {global_confidence:.2%})</span>
    </div>
    """

    # Create table: one row per sentence with its label and confidence.
    html += """
    <table class="analysis-table">
        <tr>
            <th>Sentence</th>
            <th>Prediction</th>
            <th>Confidence</th>
        </tr>
    """

    # Add rows for each sentence
    for result in sentence_results:
        html += f"""
        <tr>
            <td>{result['sentence']}</td>
            <td>{result['prediction']}</td>
            <td>{result['confidence']:.2%}</td>
        </tr>
        """

    html += "</table>"
    return html
" return html def process_input(text_input, labels_or_premise, mode): if mode == "Zero-Shot Classification": labels = [label.strip() for label in labels_or_premise.split(',')] prediction = zero_shot_classifier(text_input, labels) results = {label: score for label, score in zip(prediction['labels'], prediction['scores'])} return results, '' elif mode == "Natural Language Inference": pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0] results = {pred['label']: pred['score'] for pred in pred} return results, '' else: # Long Context NLI # Global prediction global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0] global_results = {p['label']: p['score'] for p in global_pred} global_label = max(global_results.items(), key=lambda x: x[1])[0] global_confidence = max(global_results.values()) # Sentence-level analysis sentences = sent_tokenize(text_input) sentence_results = [] for sentence in sentences: sent_pred = nli_classifier([{"text": sentence, "text_pair": labels_or_premise}], return_all_scores=True)[0] # Get the prediction and confidence for the sentence pred_scores = [(p['label'], p['score']) for p in sent_pred] max_pred = max(pred_scores, key=lambda x: x[1]) max_label, confidence = max_pred sentence_results.append({ 'sentence': sentence, 'prediction': max_label, 'confidence': confidence }) analysis_html = create_analysis_html(sentence_results, global_label, global_confidence) return global_results, analysis_html def update_interface(mode): if mode == "Zero-Shot Classification": return ( gr.update( label="🏷️ Categories", placeholder="Enter comma-separated categories...", value=zero_shot_examples[0][1] ), gr.update(value=zero_shot_examples[0][0]) ) elif mode == "Natural Language Inference": return ( gr.update( label="🔎 Hypothesis", placeholder="Enter a hypothesis to compare with the premise...", value=nli_examples[0][1] ), gr.update(value=nli_examples[0][0]) ) 
else: # Long Context NLI return ( gr.update( label="🔎 Hypothesis", placeholder="Enter a hypothesis to test against the full context...", value=long_context_examples[0][1] ), gr.update(value=long_context_examples[0][0]) ) def update_visibility(mode): return ( gr.update(visible=(mode == "Zero-Shot Classification")), gr.update(visible=(mode == "Natural Language Inference")), gr.update(visible=(mode == "Long Context NLI")) ) # Now define the Blocks interface with gr.Blocks() as demo: gr.Markdown(""" # tasksource/ModernBERT-nli demonstration This space uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli), fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) on tasksource classification tasks. This NLI model achieves high accuracy on categorization, logical reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL (long-context NLI) and FOLIO (logical reasoning). """) mode = gr.Radio( ["Zero-Shot Classification", "Natural Language Inference", "Long Context NLI"], label="Select Mode", value="Zero-Shot Classification" ) with gr.Column(): text_input = gr.Textbox( label="✍️ Input Text", placeholder="Enter your text...", lines=3, value=zero_shot_examples[0][0] ) labels_or_premise = gr.Textbox( label="🏷️ Categories", placeholder="Enter comma-separated categories...", lines=2, value=zero_shot_examples[0][1] ) submit_btn = gr.Button("Submit") outputs = [ gr.Label(label="📊 Results"), gr.HTML(label="📈 Sentence Analysis") ] with gr.Column(variant="panel") as zero_shot_examples_panel: gr.Examples( examples=zero_shot_examples, inputs=[text_input, labels_or_premise], label="Zero-Shot Classification Examples", ) with gr.Column(variant="panel") as nli_examples_panel: gr.Examples( examples=nli_examples, inputs=[text_input, labels_or_premise], label="Natural Language Inference Examples", ) with gr.Column(variant="panel") as long_context_examples_panel: gr.Examples( 
examples=long_context_examples, inputs=[text_input, labels_or_premise], label="Long Context NLI Examples", ) mode.change( fn=update_interface, inputs=[mode], outputs=[labels_or_premise, text_input] ) mode.change( fn=update_visibility, inputs=[mode], outputs=[zero_shot_examples_panel, nli_examples_panel, long_context_examples_panel] ) submit_btn.click( fn=process_input, inputs=[text_input, labels_or_premise, mode], outputs=outputs ) if __name__ == "__main__": demo.launch()