import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    XLMRobertaModel,
    XLMRobertaTokenizer,
)


# Load the pretrained AfriBERTa encoder and its matching tokenizer.
bert = XLMRobertaModel.from_pretrained('castorini/afriberta_large')
tokenizer = XLMRobertaTokenizer.from_pretrained('castorini/afriberta_large')
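
# Note: the classification head in BERT_Arch below assumes the encoder's pooled output is
# 768-dimensional; if a different checkpoint is substituted, check bert.config.hidden_size
# and adjust fc1 accordingly.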


class BERT_Arch(nn.Module):
    """AfriBERTa encoder with a small feed-forward head for 3-way classification."""

    def __init__(self, bert):
        super(BERT_Arch, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768, 512)  # 768 = encoder hidden size
        self.fc2 = nn.Linear(512, 3)    # 3 classes: Free / Hate / Offensive
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        # Use the pooled [CLS] representation as the sentence embedding.
        cls_hs = self.bert(sent_id, attention_mask=mask)['pooler_output']
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x
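
# Note: the checkpoint below is assumed to be a full pickled model object (saved with
# torch.save(model, ...)), so BERT_Arch must be defined in this script, with the same
# attribute names as at training time, before torch.load can unpickle it.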
model = BERT_Arch(bert)  # fresh, untrained instance of the architecture (not used directly below)

# Load the trained hate speech classifier onto the CPU and switch it to inference mode.
hate_speech_model_path = "Hate_Speech_model.pt"
hate_speech_model = torch.load(hate_speech_model_path, map_location=torch.device('cpu'))
hate_speech_model.eval()
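
# If the .pt file instead stores only a state_dict (a common alternative), it could be
# loaded into the fresh instance above, e.g.:
#     model.load_state_dict(torch.load(hate_speech_model_path, map_location="cpu"))
#     model.eval()
# Also note that newer PyTorch releases (2.6+) default torch.load to weights_only=True,
# which rejects fully pickled models; pass weights_only=False there if the file is trusted.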


# Mapping from the model's class ids to human-readable labels.
LABELS = {0: "Free", 1: "Hate", 2: "Offensive"}
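# These ids are assumed to match the label encoding used when the classifier was trained;
# if the training data used a different ordering, the mapping above must be updated.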


def detect_hate_speech(text):
    """Classify a single piece of text as Free, Hate, or Offensive speech."""
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = hate_speech_model(inputs['input_ids'], inputs['attention_mask'])
    label = torch.argmax(outputs, dim=1).item()
    return LABELS[label]
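
# Usage sketch: detect_hate_speech("...") returns one of the LABELS values. The call to
# torch.argmax(...).item() assumes a single input example per call, so batched inputs
# would need a small change here.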


def post_text(text, hate_speech_result):
    """Allow posting only when the detection result is neither Hate nor Offensive."""
    if hate_speech_result in ["Hate", "Offensive"]:
        return f"Your message contains {hate_speech_result} speech and cannot be posted.", ""
    else:
        return "The text is safe to post.", text


# Gradio UI: enter text, run detection, and post the text only if it passes the check.
interface = gr.Blocks()

with interface:
    gr.Markdown("## Hate Speech Detection")
    with gr.Row():
        text_input = gr.Textbox(label="Enter Text", lines=5)
    with gr.Row():
        detect_button = gr.Button("Detect Hate Speech")
    with gr.Row():
        hate_speech_result_box = gr.Textbox(label="Hate Speech Detection Result", interactive=False)
    with gr.Row():
        post_button = gr.Button("Post Text")
    with gr.Row():
        post_result_box = gr.Textbox(label="Posting Status", interactive=False)
        posted_text_box = gr.Textbox(label="Posted Text", interactive=False)

    # Run detection on the entered text and display the predicted label.
    detect_button.click(
        fn=detect_hate_speech,
        inputs=text_input,
        outputs=hate_speech_result_box,
    )

    # Gate posting on the most recent detection result.
    post_button.click(
        fn=post_text,
        inputs=[text_input, hate_speech_result_box],
        outputs=[post_result_box, posted_text_box],
    )


if __name__ == "__main__":
    interface.launch()