Spaces:
Runtime error
Runtime error
File size: 4,122 Bytes
b5473c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import gradio as gr
import torch
from torch import nn
from transformers import BertTokenizer, BertModel
# Define the BertClassifier class
class BertClassifier(nn.Module):
    """Multi-label classification head on top of a BERT encoder.

    The encoder's pooled output is fed through a single linear layer and a
    sigmoid, so every class receives its own independent score in [0, 1].

    Args:
        bert: Encoder module. It must expose ``config.hidden_size`` and its
            forward pass must return an object with a ``pooler_output``
            attribute.
        num_classes: Number of output labels.
    """

    def __init__(self, bert: "BertModel", num_classes: int):
        super().__init__()
        # Attribute names are part of the checkpoint format: state_dict keys
        # are derived from them, so they must not be renamed.
        self.bert = bert
        self.classifier = nn.Linear(bert.config.hidden_size, num_classes)
        # BCELoss consumes probabilities, which forward() produces via sigmoid.
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        """Run the encoder and classification head.

        Args:
            input_ids: Token ids, forwarded to the encoder unchanged.
            attention_mask: Optional mask forwarded to the encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            position_ids: Optional position ids forwarded to the encoder.
            head_mask: Optional head mask forwarded to the encoder.
            labels: Optional multi-hot targets with the same shape as the
                returned scores. Any numeric dtype is accepted.

        Returns:
            A ``(loss, scores)`` tuple. ``loss`` is the binary cross-entropy
            against ``labels``, or the plain integer ``0`` when no labels are
            given. ``scores`` holds the per-class sigmoid outputs.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        scores = torch.sigmoid(self.classifier(outputs.pooler_output))

        loss = 0
        if labels is not None:
            # BCELoss rejects integer targets; match the dtype of the scores
            # so 0/1 labels stored as long or bool tensors work too.
            loss = self.criterion(scores, labels.to(scores.dtype))
        return loss, scores
# Build the tokenizer and the classifier around the pretrained
# 'bert-base-uncased' encoder.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
model = BertClassifier(bert_model, num_classes=7)

# Restore the fine-tuned weights onto the CPU.
# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source.
_cpu = torch.device('cpu')
_checkpoint = torch.load('bert_classifier_mltc.pkl', map_location=_cpu)
model.load_state_dict(_checkpoint)
# Drop the temporaries so the loaded tensors are not kept alive twice.
del _checkpoint, _cpu

# Switch to inference mode.
model.eval()
# Define prediction function
def predict(text):
    """Score one comment against every label.

    Args:
        text: The comment to classify.

    Returns:
        A dict mapping each label name to the score the module-level
        ``model`` produced for it.
    """
    label_names = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', 'clean']

    # Encode with special tokens, truncated to at most 512 tokens.
    token_ids = tokenizer.encode(text, add_special_tokens=True, max_length=512, truncation=True)
    input_ids = torch.tensor([token_ids])  # batch containing a single sequence
    attention = input_ids.ne(tokenizer.pad_token_id).float()

    with torch.no_grad():
        _, scores = model(input_ids, attention_mask=attention)

    # Only one sequence was submitted, so row 0 holds all of its scores.
    return dict(zip(label_names, scores[0].tolist()))
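# Example inputs offered in the UI. Each entry is a one-element list (one value for the single Textbox input); several are deliberately abusive so the toxic labels are exercised.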
examples = [
["This is a positive and friendly comment."],
["I'm not happy with your service."],
["You are a great person!"],
["How can one defame someone who thinks the Fort Hood Shooting is justified by God's law?"],
["A Bisexual, like a homosexual or a heterosexual, is not defined by sexual activity. (Much like a 15 year old boy who is attracted to a girl sexually but has never had sex is still straight). A person who is actually sexually attracted/aroused by the same sex as well as the opposite sex is bisexual."],
["I'LL SAY NIGGER IF I WANT TO BRO"],
["Do you need a verifiable source for that statement? Check your own user page you fat, miserable balding fuck."],
["Fuck you bitch! I will go to your house and kill you with my gun as well! Prepare to die!"],
["you're hot, i will rape you ;)"],
["WOULDN'T BE THE FIRST TIME BITCH. FUCK YOU I'LL FIND OUT WHERE YOU LIVE, SODOMIZE YOUR WIFE AND THEN BURN YOUR HOUSE DOWN. FUCK YOU YOU FUCKING QUEER."],
["What a motherfucking piece of crap those fuckheads for blocking us!"],
["Get a life you animal fucker. Shut up you fucking nigger. Fuck off and shit your pants full of all the shit you can fill them with. 144.131.176.126"],
["HOPE YOUR HEAD GETS CUT OFF AND SOMEONE WIPS THERE ASS WITH IT AND THEN STABS YOU IN YOUR HEART"],
["you people are pretty overzealous with this whole free thing. get a fucking life, you fucking niggers !!!23 16!!!"],
["Stupid peace of shit stop deleting my stuff asshole go die and fall in a hole go to hell!"],
["Bye! Don't look, come or think of comming back! Tosser."]
]
# Wire predict() into a Gradio UI: a multi-line textbox in, a label widget
# showing up to seven class scores out.
_comment_box = gr.Textbox(lines=10, placeholder="Enter a comment here...")
_score_view = gr.Label(num_top_classes=7)

iface = gr.Interface(
    title="Toxic Comment Classification",
    description="Classify comments into toxic and non-toxic categories using BERT and GNN model.",
    fn=predict,
    inputs=_comment_box,
    outputs=_score_view,
    examples=examples,
)

# Start serving the app.
iface.launch()
|