import gradio as gr
import torch
import torch.nn as nn
from transformers import XLMRobertaModel, XLMRobertaTokenizer

# Load the AfriBERTa encoder (XLM-R architecture) and tokenizer via Hugging Face Transformers
bert = XLMRobertaModel.from_pretrained('castorini/afriberta_large')
tokenizer = XLMRobertaTokenizer.from_pretrained('castorini/afriberta_large')

# Define the model architecture for three classes
class BERT_Arch(nn.Module):
    def __init__(self, bert):
        super(BERT_Arch, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)       # Dropout layer
        self.relu = nn.ReLU()                # ReLU activation function
        self.fc1 = nn.Linear(768, 512)       # Dense layer 1 (AfriBERTa hidden size is 768)
        self.fc2 = nn.Linear(512, 3)         # Dense layer 2 (output layer for 3 classes)
        self.softmax = nn.LogSoftmax(dim=1)  # Log-softmax over the class logits

    def forward(self, sent_id, mask):
        # Use the pooled sentence representation as input to the classifier head
        cls_hs = self.bert(sent_id, attention_mask=mask)['pooler_output']
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)        # Output layer
        return self.softmax(x)

# Load the fine-tuned model and set it to evaluation mode.
# The checkpoint is assumed to hold the full serialized model object,
# so BERT_Arch must be defined before loading. On PyTorch >= 2.6 add
# weights_only=False so torch.load is allowed to unpickle a full module.
hate_speech_model_path = "Hate_Speech_model.pt"
hate_speech_model = torch.load(hate_speech_model_path, map_location=torch.device('cpu'))
hate_speech_model.eval()

# Mapping from class indices to labels
LABELS = {0: "Free", 1: "Hate", 2: "Offensive"}

# Classify a piece of text as Free, Hate, or Offensive
def detect_hate_speech(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = hate_speech_model(inputs['input_ids'], inputs['attention_mask'])
    label = torch.argmax(outputs, dim=1).item()
    return LABELS[label]

# Decide whether the text may be posted, based on the detection result
def post_text(text, detection_result):
    if detection_result in ["Hate", "Offensive"]:
        return f"Your message contains {detection_result} speech and cannot be posted.", ""
    return "The text is safe to post.", text

# Gradio interface
interface = gr.Blocks()
with interface:
    gr.Markdown("## Hate Speech Detection")
    with gr.Row():
        text_input = gr.Textbox(label="Enter Text", lines=5)
    with gr.Row():
        detect_button = gr.Button("Detect Hate Speech")
    with gr.Row():
        detection_result_box = gr.Textbox(label="Hate Speech Detection Result", interactive=False)
    with gr.Row():
        post_button = gr.Button("Post Text")
    with gr.Row():
        post_result_box = gr.Textbox(label="Posting Status", interactive=False)
        posted_text_box = gr.Textbox(label="Posted Text", interactive=False)

    detect_button.click(
        fn=detect_hate_speech,
        inputs=text_input,
        outputs=detection_result_box,
    )
    post_button.click(
        fn=post_text,
        inputs=[text_input, detection_result_box],
        outputs=[post_result_box, posted_text_box],
    )

# Launch the app
if __name__ == "__main__":
    interface.launch()
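
# --- Note on the checkpoint format (illustrative only) ---
# A minimal sketch, assuming "Hate_Speech_model.pt" was produced during
# training by serializing the full model object; the training step shown
# here is hypothetical and not part of this app:
#
#   model = BERT_Arch(bert)
#   ...  # fine-tune on labeled hate-speech data
#   torch.save(model, "Hate_Speech_model.pt")   # full object -> load with torch.load(...)
#
# If instead only the weights were saved, i.e.
#
#   torch.save(model.state_dict(), "Hate_Speech_model.pt")
#
# then construct the model first and load the state dict:
#
#   model = BERT_Arch(bert)
#   model.load_state_dict(torch.load("Hate_Speech_model.pt", map_location="cpu"))
#   model.eval()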