import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoConfig
from model import Classifier
# Load model directly
MODEL_NAME = "cahya/roberta-base-indonesian-522M"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class_names = ['Action', 'Adventure', 'Comedy', 'Drama', 'Fantasy', 'Romance', 'Sci-Fi']  # not referenced below
config = AutoConfig.from_pretrained(MODEL_NAME)
transformer = AutoModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
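# weight.pt is expected to hold two state dicts: 'w_t' for the fine-tuned
# transformer backbone and 'w_c' for the classification heads.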
cp = torch.load("weight.pt", map_location="cpu")
transformer.load_state_dict(cp['w_t'])
classifier = Classifier(input_size=config.hidden_size, output_sizes=[1, 1, 1, 3, 5])
classifier.load_state_dict(cp['w_c'])
transformer.to(device)
classifier.to(device)
# Switch to eval mode so dropout is disabled during inference
transformer.eval()
classifier.eval()
target_names = ["Individual", 'Group']
strength_names = ["Weak", 'Moderate', 'Strong']
type_names = ['Religion','Race','Physical','Gender','Other']
act_sig = nn.Sigmoid()
act_soft = nn.Softmax(dim=-1)
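# Sigmoid covers the binary heads (hate speech, abusive, target) and the
# multi-label type head; softmax covers the 3-way strength head.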
def predict(sentence):
    # Tokenize the input sentence
    inputs = tokenizer(sentence,
                       add_special_tokens=True,
                       max_length=256,
                       padding="max_length",
                       truncation=True,
                       return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)
    att_masks = inputs['attention_mask'].to(device)
    # Get model predictions
    with torch.no_grad():
        out = transformer(input_ids, attention_mask=att_masks)
        pooled = out.pooler_output
        out = classifier(pooled)
    hs_out, abusive_out, target_out, strength_out, type_out = out[0], out[1], out[2], out[3], out[4]
    hs_act = act_sig(hs_out).squeeze()
    abusive_act = act_sig(abusive_out).squeeze()
    target_act = act_sig(target_out).squeeze(0)
    strength_act = act_soft(strength_out)
    type_act = act_sig(type_out).squeeze(0)
    # Interpret the predictions
    is_hate_speech = bool(hs_act >= 0.5)
    is_abusive = bool(abusive_act >= 0.5)
    hate_speech_target = int(target_act >= 0.5)
    hate_speech_strength = strength_act.argmax().item()
    if is_hate_speech:
        hate_speech_target_label = target_names[hate_speech_target]
        hate_speech_strength_label = strength_names[hate_speech_strength]
        hate_speech_type_label = []
        for idx, prob in enumerate(type_act):
            if prob >= 0.5:
                hate_speech_type_label.append(type_names[idx])
        if len(hate_speech_type_label) == 0:
            hate_speech_type_label.append("Other")
    else:
        hate_speech_target_label = "Non-HS"
        hate_speech_strength_label = "Non-HS"
        hate_speech_type_label = "Non-HS"
    return is_hate_speech, is_abusive, hate_speech_target_label, hate_speech_strength_label, {"hs_type": hate_speech_type_label}
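# Optional local sanity check; "kamu hebat sekali" is just a hypothetical
# Indonesian sentence, and weight.pt must be present for the loads above to work.
# print(predict("kamu hebat sekali"))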
# Create the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Enter a sentence"),
    outputs=[
        gr.Label(label="Is Hate Speech"),
        gr.Label(label="Is Abusive"),
        gr.Label(label="Hate Speech Target"),
        gr.Label(label="Hate Speech Strength"),
        gr.JSON(label="Hate Speech Type"),
    ],
    title="Hate Speech Detection",
)
iface.launch() # Launches the mini app!
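# To run locally (assuming this file is saved as app.py, the usual Hugging Face
# Spaces entry point): install gradio, torch, and transformers, place weight.pt
# and model.py alongside it, then run `python app.py`.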