# Gradio demo that classifies text with the ConfliBERT MME sequence-classification
# model ('sperkins2116/ConfliBERT-BC-MMEs') and shows the verdict in an HTML output.
import html
import re

import torch
import tensorflow as tf
from tf_keras import models, layers
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    TFAutoModelForQuestionAnswering,
)
import gradio as gr

# NOTE(review): tf, models, layers, AutoModelForTokenClassification and
# TFAutoModelForQuestionAnswering are not used by the code below; kept in case
# other code depends on them.

# Check if GPU is available and use it if possible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the model and tokenizer
mme_model_name = 'sperkins2116/ConfliBERT-BC-MMEs'
mme_model = AutoModelForSequenceClassification.from_pretrained(mme_model_name).to(device)
mme_tokenizer = AutoTokenizer.from_pretrained(mme_model_name)

# Define the class names for text classification
class_names = ['Negative', 'Positive']

# PyTorch shape-mismatch message; this app treats it as "input longer than the
# model accepts" and reports both sizes to the user.
_TENSOR_SIZE_PATTERN = re.compile(
    r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)"
)


def handle_error_message(e, default_limit=512):
    """Turn an exception raised during classification into a user-facing message.

    Args:
        e: The exception caught by the caller.
        default_limit: Token limit quoted when the error looks like an
            over-length input but does not itself carry the sizes.

    Returns:
        A string starting with ``"Error: "`` that is safe to render as HTML.
    """
    error_message = str(e)

    match = _TENSOR_SIZE_PATTERN.search(error_message)
    if match:
        number_1, number_2 = match.groups()
        return f"Error: Text Input is over limit where inserted text size {number_1} is larger than model limits of {number_2}"

    # Presumably an out-of-range embedding lookup, the other usual symptom of an
    # over-long input (PyTorch reports "index out of range in self").
    if isinstance(e, IndexError) and "index out of range" in error_message:
        return f"Error: Text Input is over limit where inserted text size is larger than model limits of {default_limit}"

    # Anything else is not a length problem: report it instead of mislabelling it.
    # Escaped because the result is rendered by a gr.HTML component.
    return f"Error: {html.escape(error_message)}"


def mme_classification(text):
    """Classify whether *text* contains evidence of a multinational military exercise.

    Args:
        text: The input passage from the UI textbox.

    Returns:
        A "Positive: ..." or "Negative: ..." verdict with the model's softmax
        confidence as a percentage, or an "Error: ..." string if tokenization or
        inference fails.
    """
    try:
        inputs = mme_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = mme_model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()
        confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100
    except Exception as e:
        # UI boundary: show the failure to the user rather than crash the handler.
        return handle_error_message(e)

    if predicted_class == 1:  # Positive class
        return f"Positive: The text contains evidence of a multinational military exercise. (Confidence: {confidence:.2f}%)"
    # Negative class
    return f"Negative: The text does not contain evidence of a multinational military exercise. (Confidence: {confidence:.2f}%)"


# Define the Gradio interface
def chatbot(text):
    """Gradio click handler: forward the textbox contents to the classifier."""
    return mme_classification(text)


# Components below are tagged with elem_id, which Gradio renders as an HTML id,
# so each class selector is paired with the matching #id selector.
css = """
body {
    background-color: #f0f8ff;
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
    color: black; /* Ensure text is visible in dark mode */
}
h1 {
    color: #2e8b57;
    text-align: center;
    font-size: 2em;
}
h2 {
    color: #ff8c00;
    text-align: center;
    font-size: 1.5em;
}
.gradio-container {
    max-width: 100%;
    margin: 10px auto;
    padding: 10px;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.gr-input, .gr-output {
    background-color: #ffffff;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
    font-size: 1em;
    color: black; /* Ensure text is visible in dark mode */
}
.gr-title {
    font-size: 1.5em;
    font-weight: bold;
    color: #2e8b57;
    margin-bottom: 10px;
    text-align: center;
}
.gr-description {
    font-size: 1.2em;
    color: #ff8c00;
    margin-bottom: 10px;
    text-align: center;
}
.header, #header {
    display: flex;
    justify-content: center;
    align-items: center;
    padding: 10px;
    flex-wrap: wrap;
}
.header-title-center a, #header-title-center a {
    font-size: 4em; /* Increased font size */
    font-weight: bold; /* Made text bold */
    color: darkorange; /* Darker orange color */
    text-align: center;
    display: block;
}
.gr-button, #gr-button {
    background-color: #ff8c00;
    color: white;
    border: none;
    padding: 10px 20px;
    font-size: 1em;
    border-radius: 5px;
    cursor: pointer;
}
.gr-button:hover, #gr-button:hover {
    background-color: #ff4500;
}
.footer {
    text-align: center;
    margin-top: 10px;
    font-size: 0.9em; /* Updated font size */
    color: black; /* Ensure text is visible in dark mode */
    width: 100%;
}
.footer a {
    color: #2e8b57;
    font-weight: bold;
    text-decoration: none;
}
.footer a:hover {
    text-decoration: underline;
}
.footer .inline {
    display: inline;
    color: black; /* Ensure text is visible in dark mode */
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Row(elem_id="header"):
        # NOTE(review): the original header markup appears to have been lost
        # (the CSS styles a link inside this element) — restore if known.
        gr.Markdown("ConfliBERT-MME", elem_id="header-title-center")

    gr.Markdown("Provide the text for MME Classification.")

    text_input = gr.Textbox(lines=5, placeholder="Enter the text here...", label="Text")
    output = gr.HTML(label="Output")

    submit_button = gr.Button("Submit", elem_id="gr-button")
    submit_button.click(fn=chatbot, inputs=text_input, outputs=output)

    # NOTE(review): these render nothing as written; their content (likely a
    # footer, given the .footer CSS) appears to have been lost — confirm.
    gr.Markdown("")
    gr.Markdown("")

demo.launch(share=True)