Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch.nn as nn | |
import torch | |
from transformers import BertTokenizerFast as BertTokenizer, BertModel | |
import pytorch_lightning as pl | |
# Pretrained checkpoint name shared by the tokenizer and the BERT encoder.
BERT_MODEL_NAME = 'bert-base-cased'

# Output labels; ``predict`` pairs them positionally with the model's output
# columns, and the classifier is built with ``len(LABEL_COLUMNS)`` outputs.
LABEL_COLUMNS = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]

tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
class ToxicCommentTagger(pl.LightningModule):
    """BERT encoder with a linear multi-label head for toxic-comment tagging.

    Only what inference needs is defined here (``__init__`` and ``forward``);
    no training hooks are present in this file.
    """

    def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        # One output unit per label, fed from BERT's hidden size.
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        # NOTE(review): not read anywhere in this file; presumably used by the
        # (absent) training/scheduler code.
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        # BCELoss expects probabilities, hence the sigmoid in ``forward``.
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(loss, probabilities)`` for a batch of token sequences.

        ``predict`` unpacks exactly this pair. Without this method the inherited
        ``nn.Module.forward`` raises ``NotImplementedError``.

        Args:
            input_ids: token ids, one row per sequence.
            attention_mask: 1 for real tokens, 0 for padding; same shape.
            labels: optional float targets in [0, 1], one column per label.

        Returns:
            ``(loss, probabilities)`` where ``loss`` is ``None`` when no labels
            are given and ``probabilities`` are per-label sigmoid scores.
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        # NOTE(review): assumes the head was trained on ``pooler_output`` —
        # confirm against the training code that produced the checkpoint.
        probabilities = torch.sigmoid(self.classifier(output.pooler_output))
        loss = None
        if labels is not None:
            loss = self.criterion(probabilities, labels)
        return loss, probabilities
def _chunk_tokens(tokenizer, sentence, chunksize):
    """Tokenize *sentence* into ``(input_ids, attention_mask)`` of shape (n_chunks, chunksize)."""
    window = chunksize - 2  # room left in each chunk for [CLS] and [SEP]
    # Deliberately no max_length / padding / truncation here: either would cap
    # the text at a single window (defeating the chunking) and padding would
    # push [SEP] behind the pad tokens.
    token_ids = tokenizer.encode(sentence, add_special_tokens=False)
    pieces = [token_ids[start:start + window] for start in range(0, len(token_ids), window)]
    if not pieces:
        pieces = [[]]  # empty text still yields one "[CLS] [SEP]" chunk

    cls_id, sep_id = tokenizer.cls_token_id, tokenizer.sep_token_id
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    id_rows, mask_rows = [], []
    for piece in pieces:
        ids = [cls_id, *piece, sep_id]
        pad_len = chunksize - len(ids)
        id_rows.append(ids + [pad_id] * pad_len)
        mask_rows.append([1] * len(ids) + [0] * pad_len)
    return (
        torch.tensor(id_rows, dtype=torch.long),
        torch.tensor(mask_rows, dtype=torch.long),
    )


def predict(model, tokenizer, sentence, chunksize=512):
    """Score *sentence* against every label in ``LABEL_COLUMNS``.

    The text is tokenized without special tokens and split into windows of
    ``chunksize - 2`` tokens. Each window is wrapped as ``[CLS] ... [SEP]`` and
    padded to ``chunksize``, so texts longer than one BERT input are scored in
    full. All chunks are scored in one batch. Each label keeps its highest
    score across chunks, so a toxic passage anywhere in the text is reported.

    Args:
        model: callable returning ``(loss, scores)`` where ``scores`` has one
            row per chunk and one column per label.
        tokenizer: BERT tokenizer providing ``encode`` and the
            ``cls/sep/pad_token_id`` attributes.
        sentence: input text; may be empty.
        chunksize: tokens per chunk including [CLS]/[SEP]; must be > 2.

    Returns:
        dict mapping each label in ``LABEL_COLUMNS`` to its max score (float).

    Raises:
        ValueError: if ``chunksize`` leaves no room for content tokens.
    """
    if chunksize <= 2:
        raise ValueError(f"chunksize must be > 2, got {chunksize}")
    input_ids, attention_mask = _chunk_tokens(tokenizer, sentence, chunksize)
    with torch.no_grad():
        _, scores = model(input_ids=input_ids, attention_mask=attention_mask)
    # Reduce over the chunk axis: a label's score is its worst-case chunk.
    best = scores.detach().cpu().max(dim=0).values
    return {label: float(score) for label, score in zip(LABEL_COLUMNS, best)}
import os

# Checkpoint location. The default looks like a Google Colab Drive mount, which
# will not exist on other hosts; set TOXIC_CHECKPOINT_PATH to point elsewhere.
CHECKPOINT_PATH = os.environ.get(
    'TOXIC_CHECKPOINT_PATH',
    '/content/drive/MyDrive/checkpoints/best-checkpoint.ckpt',
)

model = ToxicCommentTagger.load_from_checkpoint(
    CHECKPOINT_PATH,
    n_classes=len(LABEL_COLUMNS),
    map_location='cpu',  # allow GPU-saved weights to load on CPU-only hosts
)
# freeze() turns off gradients for all parameters and also calls eval().
model.freeze()
# Every label the model can flag, shown in the UI's JSON panel. Derived from
# LABEL_COLUMNS so the displayed categories are exactly the keys ``predict``
# can return (a hand-maintained list here had drifted to names the model never
# produces).
all_categories = {'all_categories': list(LABEL_COLUMNS)}
# Clickable UI examples: each row is [sentence, threshold-in-percent], matching
# the (Textbox, Slider) inputs of the interface.
examples = [
    [text, 50]
    for text in (
        'Yes indeed. She sort of reminds me of the elder lady that played the part in the movie "Titanic" who was telling her story!!! And I wouldn\'t have wanted to cover who I really am!! I would be proud!!!! WE should be proud of our race no matter what it is!!',
        'The trans women reading this tweet right now is beautiful',
        "Question: These 4 broads who criticize America, what country did they flee to get here? And now they want to make OUR America like THEIR former HELL HOLE. I don't think so!!!!!!!!!! Let them explain their GRATITUDE for letting them in OUR country.",
        'Well being subtle you could not have been born because white people don\'t matter',
        "For starters bend over the one in pink and kick that ass and pussy to get a taste until she's begging for a dick inside her.",
    )
]
def toxicity(sentence, threshold):
    """Gradio callback: return the labels whose score exceeds *threshold* percent.

    Args:
        sentence: text entered in the textbox.
        threshold: slider value on a 0-100 scale; divided by 100 before being
            compared with the model scores.

    Returns:
        ``(flagged, all_categories)``: labels whose score is strictly greater
        than the threshold, plus the static category dict for the JSON output.
    """
    scores = predict(model, tokenizer, sentence)
    flagged = [label for label, score in scores.items() if score > threshold / 100]
    return flagged, all_categories
# The two inputs map onto toxicity(sentence, threshold); the two outputs receive
# the pair it returns (flagged labels as text, category dict as JSON).
demo = gr.Interface(
    fn=toxicity,
    inputs=[gr.Textbox(placeholder="Enter sentence here..."), gr.Slider(0, 100)],
    outputs=['text', gr.JSON(all_categories)],
    examples=examples,
)
demo.launch()