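# Web-page header captured together with app.py (not part of the program):
#   Author: Asutosh2003
#   Last commit: "Update app.py" (bd209d0, verified)
#   Page links: raw / history / blame
#   Scan status: no virus
#   File size: 3.94 kB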
import torch
from transformers import BertTokenizer, BertModel
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import gradio as gr
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
import re
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class BERTClass(torch.nn.Module, PyTorchModelHubMixin):
    """COVID-Twitter-BERT v2 encoder with a dropout + linear head on top.

    The head maps the encoder's 1024-wide pooled output to 11 raw scores;
    no activation is applied inside the module.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are left untouched: they name the submodules that
        # saved weights are matched against.
        self.bert_model = BertModel.from_pretrained(
            'digitalepidemiologylab/covid-twitter-bert-v2',
            return_dict=True,
        )
        self.dropout = torch.nn.Dropout(0.3)
        self.linear = torch.nn.Linear(1024, 11)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Run the encoder and return the linear head's raw scores.

        Args:
            input_ids: token ids passed positionally to the encoder.
            attn_mask: forwarded as the encoder's ``attention_mask``.
            token_type_ids: forwarded as the encoder's ``token_type_ids``.

        Returns:
            Output of the linear layer applied to the dropped-out pooled
            encoder output.
        """
        encoded = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.linear(pooled)
# `from_pretrained` is a classmethod that builds the BERTClass instance itself,
# so it is called on the class: constructing a throwaway `BERTClass()` first
# would load the base encoder weights twice for no benefit.
model = BERTClass.from_pretrained("Asutosh2003/ct-bert-v2-vaccine-concern")
model.to(device)

# Tokenizer of the same base checkpoint that BERTClass wraps.
tokenizer = BertTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')

# Token length passed to the tokenizer as `max_length`.
MAX_LEN = 256
def rmTrash(raw_string, remuser, remstop, remurls):
    """Clean a raw text string before it is tokenized.

    Steps, in order:
      1. optionally drop whitespace-separated tokens that contain '@';
      2. lowercase and delete every character that is neither a word
         character nor whitespace;
      3. optionally delete runs that start with "http";
      4. optionally drop NLTK English stopwords.

    Args:
        raw_string: text to clean.
        remuser: when truthy, drop tokens containing '@'.
        remstop: when truthy, drop English stopwords.
        remurls: when truthy, delete "http..." runs.

    Returns:
        The cleaned, lowercased string. Whenever tokens are re-joined
        (``remuser`` or ``remstop`` enabled) each kept token is preceded by a
        single space, so a non-empty result starts with a space.
    """
    if remuser:
        cleaned = ''.join(
            ' ' + token for token in raw_string.split() if '@' not in token
        )
    else:
        cleaned = raw_string

    cleaned = re.sub(r'[^\w\s]', '', cleaned.lower())

    if remurls:
        # Punctuation is already stripped, so a URL has collapsed into one
        # unbroken "http..." run that this pattern removes whole.
        cleaned = re.sub(r'http\S+', '', cleaned)

    if not remstop:
        return cleaned

    # Load the stopword list once per call and use a set: the list was
    # previously re-read and scanned linearly for every single token.
    english_stopwords = set(stopwords.words('english'))
    return ''.join(
        ' ' + token
        for token in cleaned.split()
        if token not in english_stopwords
    )
def return_vec(text):
    """Return the model's sigmoid scores for a single piece of text.

    The text is cleaned with ``rmTrash`` (all three removal options on),
    encoded with the module-level tokenizer padded/truncated to ``MAX_LEN``,
    and passed through the module-level model in eval mode without gradients.

    Args:
        text: raw input string.

    Returns:
        A list of floats: the sigmoid of the first row of the model output.
    """
    cleaned = rmTrash(text, True, True, True)
    encoded = tokenizer.encode_plus(
        cleaned,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    def _on_device(key):
        # Move one encoded tensor to the inference device as int64.
        return encoded[key].to(device, dtype=torch.long)

    model.eval()
    with torch.no_grad():
        ids = _on_device('input_ids')
        mask = _on_device('attention_mask')
        segments = _on_device('token_type_ids')
        logits = model(ids, mask, segments)
        scores = torch.sigmoid(logits).cpu().detach().numpy().tolist()
    return list(scores[0])
def filter_threshold_lst(vector, threshold_list):
    """Binarize scores against per-position thresholds.

    Args:
        vector: iterable of scores.
        threshold_list: iterable of thresholds, paired with ``vector`` by
            position.

    Returns:
        A list of ints, 1 where the score is >= its threshold and 0
        otherwise. Its length is that of the shorter input.
    """
    # The result is returned as-is. It used to be appended to itself, which
    # produced a self-referential extra element at the end of the list.
    return [
        1 if score >= cutoff else 0
        for score, cutoff in zip(vector, threshold_list)
    ]
def predict(text, threshold_lst):
    """Predict the labels whose score reaches its threshold.

    Args:
        text: raw input string, scored with ``return_vec``.
        threshold_lst: per-label thresholds, in the same order as the label
            names listed below.

    Returns:
        A tuple ``(labels, probabilities)``: the names of the labels whose
        score is >= its threshold (``['none']`` when no label qualifies) and
        the unthresholded scores from ``return_vec``.
    """
    labels = ('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory', 'unnecessary', 'political', 'ingredients', 'conspiracy', 'country', 'religious')
    prob_lst = return_vec(text)
    flags = filter_threshold_lst(prob_lst, threshold_lst)

    # zip stops at the last label, so only the positions that correspond to a
    # label name are inspected, whatever the length of `flags`.
    pred_lbl_lst = [label for label, flag in zip(labels, flags) if flag == 1]
    if not pred_lbl_lst:
        pred_lbl_lst = ['none']
    return pred_lbl_lst, prob_lst
def gr_predict(text):
    """Gradio callback: return the predicted labels as one comma-separated string.

    Args:
        text: raw input string from the UI.

    Returns:
        The label names returned by ``predict`` (using the fixed per-label
        thresholds below), joined with ',' and no trailing separator.
    """
    thresholds = [0.616, 0.212, 0.051, 0.131, 0.212, 0.111, 0.071, 0.566, 0.061, 0.02, 0.081]
    predicted_labels, _ = predict(text, thresholds)
    return ','.join(predicted_labels)
# Gradio UI: a free-text input box whose prediction is shown in a Label widget.
text_input = gr.Textbox()
label_output = gr.Label()
iface = gr.Interface(fn=gr_predict, inputs=text_input, outputs=label_output)

# Start serving the app with debug mode enabled.
iface.launch(debug=True)