# hate/app.py
import torch
import torch.nn as nn
import gradio as gr
from transformers import XLMRobertaModel, XLMRobertaTokenizer
# Load the AfriBERTa encoder and tokenizer via Hugging Face Transformers
bert = XLMRobertaModel.from_pretrained('castorini/afriberta_large')
tokenizer = XLMRobertaTokenizer.from_pretrained('castorini/afriberta_large')
# Define the classification head on top of the encoder for three classes
class BERT_Arch(nn.Module):
    def __init__(self, bert):
        super(BERT_Arch, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)        # Dropout layer
        self.relu = nn.ReLU()                 # ReLU activation
        self.fc1 = nn.Linear(768, 512)        # Dense layer 1 (768 = AfriBERTa hidden size)
        self.fc2 = nn.Linear(512, 3)          # Dense layer 2 (output layer for 3 classes)
        self.softmax = nn.LogSoftmax(dim=1)   # Log-softmax over the 3 classes

    def forward(self, sent_id, mask):
        # Pooled sequence representation from the encoder
        cls_hs = self.bert(sent_id, attention_mask=mask)['pooler_output']
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)       # Output layer
        x = self.softmax(x)   # Log-probabilities for the 3 classes
        return x
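# Shape walk-through for a batch of B sequences of length T (an illustrative
# note, assuming AfriBERTa's pooled output has hidden size 768, matching fc1):
#   sent_id, mask          : (B, T)
#   pooler_output          : (B, 768)
#   fc1 -> relu -> dropout : (B, 512)
#   fc2 -> log_softmax     : (B, 3) log-probabilities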
# Load the fine-tuned model and set it to evaluation mode.
# The checkpoint is a fully pickled BERT_Arch instance, so the class above
# must be defined before torch.load can unpickle it.
hate_speech_model_path = "Hate_Speech_model.pt"
hate_speech_model = torch.load(hate_speech_model_path,
                               map_location=torch.device('cpu'),
                               weights_only=False)  # full object, not a state_dict
hate_speech_model.eval()
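# Note: if "Hate_Speech_model.pt" instead held only a state_dict (a common way
# to save fine-tuned weights), the equivalent load would be a sketch like:
#
#   hate_speech_model = BERT_Arch(bert)
#   hate_speech_model.load_state_dict(
#       torch.load(hate_speech_model_path, map_location="cpu"))
#   hate_speech_model.eval()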
# Map class indices to labels
LABELS = {0: "Free", 1: "Hate", 2: "Offensive"}
# Classify a piece of text as Free, Hate, or Offensive
def detect_hate_speech(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = hate_speech_model(inputs['input_ids'], inputs['attention_mask'])
    label = torch.argmax(outputs, dim=1).item()  # argmax over log-probabilities
    return LABELS[label]
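# Optional sketch (a hypothetical helper, not part of the original app):
# because the model ends in LogSoftmax, exponentiating its outputs recovers
# per-class probabilities, which can be more informative than the top label.
def detect_with_scores(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        log_probs = hate_speech_model(inputs['input_ids'], inputs['attention_mask'])
    probs = torch.exp(log_probs).squeeze(0)  # log-probabilities -> probabilities
    return {LABELS[i]: round(p.item(), 4) for i, p in enumerate(probs)}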
# Allow posting only if the text was classified as Free
def post_text(text, detection_result):
    if detection_result in ["Hate", "Offensive"]:
        return f"Your message contains {detection_result} speech and cannot be posted.", ""
    else:
        return "The text is safe to post.", text
# Gradio interface
with gr.Blocks() as interface:
    gr.Markdown("## Hate Speech Detection")
    with gr.Row():
        text_input = gr.Textbox(label="Enter Text", lines=5)
    with gr.Row():
        detect_button = gr.Button("Detect Hate Speech")
    with gr.Row():
        detection_result_box = gr.Textbox(label="Hate Speech Detection Result", interactive=False)
    with gr.Row():
        post_button = gr.Button("Post Text")
    with gr.Row():
        post_result_box = gr.Textbox(label="Posting Status", interactive=False)
        posted_text_box = gr.Textbox(label="Posted Text", interactive=False)

    detect_button.click(
        fn=detect_hate_speech,
        inputs=text_input,
        outputs=detection_result_box,
    )
    post_button.click(
        fn=post_text,
        inputs=[text_input, detection_result_box],
        outputs=[post_result_box, posted_text_box],
    )
# Launch the app
if __name__ == "__main__":
    interface.launch()