# Hugging Face Space: tweet content-moderation classifier (Gradio app).
import json

import boto3
import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import transformers
from flask import Flask, jsonify, request
from torch import cuda
from transformers import BertTokenizer, BertModel
#bucket = 'data-ai-dev2'
# Run on GPU when one is visible to torch, otherwise fall back to CPU.
device = 'cuda' if cuda.is_available() else 'cpu'
class RobertaClass(torch.nn.Module):
    """Multilingual BERT encoder with a two-layer head classifying into 8 classes.

    NOTE(review): despite the name, the backbone is bert-base-multilingual-cased,
    not RoBERTa. The class name is kept unchanged so that pickled checkpoints
    referencing it still load.
    """

    def __init__(self):
        super(RobertaClass, self).__init__()
        # Pretrained encoder; hidden size 768.
        self.l1 = BertModel.from_pretrained("bert-base-multilingual-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        # 8 output logits, one per class (see the id2class mapping in this file).
        self.classifier = torch.nn.Linear(768, 8)
        # Stateless, parameter-free activation built once here instead of a new
        # torch.nn.ReLU() instance on every forward call (no state_dict impact).
        self.activation = torch.nn.ReLU()

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return raw (unnormalized) class logits, one row of 8 per input."""
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_state = output_1[0]    # last hidden states from the encoder
        pooler = hidden_state[:, 0]   # first ([CLS]) token representation
        pooler = self.pre_classifier(pooler)
        pooler = self.activation(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output
# Load the fine-tuned model. torch.load here unpickles a full nn.Module object
# (not a state_dict), so RobertaClass must be defined above.
# NOTE(review): torch.load executes arbitrary pickle — only load trusted files.
# The original code built a fresh RobertaClass(), moved *it* to `device`, then
# rebound `model` to the loaded checkpoint — so the loaded model stayed on CPU
# and kept dropout active during inference. Load once, move it, set eval mode.
model = torch.load('./tweet_model_v1.bin', map_location=torch.device('cpu'))
model.to(device)
model.eval()  # disable dropout for inference
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', truncation=True, do_lower_case=True)
def id2class_fun(lst, map_cl):
    """Translate each integer class id in *lst* to its label via *map_cl*.

    Returns a plain Python list. Ids absent from *map_cl* become NaN
    (pandas ``Series.map`` semantics are deliberately preserved).
    """
    return pd.Series(lst).map(map_cl).tolist()
# Mapping from model output index -> human-readable class label.
# Must stay in sync with the 8-way classifier head defined in RobertaClass.
id2class = {0: 'InappropriateUndesirable', 1 : 'GreenContent', 2 : 'IllegalActivities',
            3 : 'DiscriminatoryHate', 4 :'ViolentGraphic', 5:'PotentialAddiction',
            6 : 'ExtremismTerrorism', 7 : 'SexualExplicit'}
def process(text):
    """Classify a tweet and return its top-2 classes with softmax probabilities.

    Parameters
    ----------
    text : str
        Raw tweet text; tokenized and truncated to 512 tokens.

    Returns
    -------
    dict
        ``{'output': {class_name: probability, ...}}`` for the two
        highest-scoring classes, or ``{'output': 'something went wrong'}``
        on any failure (the UI never crashes).
    """
    try:
        inputs = tokenizer.encode_plus(
            text, None, add_special_tokens=True, max_length=512,
            return_token_type_ids=True, padding=True,
            truncation=True, return_tensors='pt')
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]
        # Inference only — no autograd graph needed.
        with torch.no_grad():
            outputs = model(ids, mask, token_type_ids)
        # Softmax over the class dimension of the full logit row, THEN take the
        # top 2. The original applied softmax along dim=0 to the already
        # selected top-2 values (all-ones over a singleton batch axis) and then
        # discarded that result, so it reported raw logits as scores.
        probs = F.softmax(outputs, dim=1)
        top_values, top_indices = torch.topk(probs, k=2, dim=1)
        prd_cls = top_indices.cpu().numpy().tolist()
        prd_cls = [item for sublist in prd_cls for item in sublist]
        prd_cls_1 = id2class_fun(prd_cls, id2class)
        prd_score = top_values.cpu().numpy().tolist()
        prd_score = [item for sublist in prd_score for item in sublist]
        otp = dict(zip(prd_cls_1, prd_score))
        return {'output': otp}
    except Exception:
        # Narrowed from a bare except:; still best-effort by design so the
        # Gradio handler always returns a displayable value.
        return {'output': 'something went wrong'}
# Gradio UI: a single textbox in, a textbox out, wired to process().
# NOTE(review): gr.inputs / gr.outputs are the legacy (pre-3.x) Gradio API;
# current versions expose gr.Textbox directly — confirm the pinned gradio version.
inputs = [gr.inputs.Textbox(lines=2, label="Enter the tweet")]
outputs = gr.outputs.Textbox(label="result")
gr.Interface(fn=process, inputs=inputs, outputs=outputs, title="twitter_classifier",
             theme="compact").launch()