Spaces:
Running
Running
import torch | |
from transformers import BertTokenizer, BertModel | |
from huggingface_hub import PyTorchModelHubMixin | |
import numpy as np | |
import gradio as gr | |
# Use a CUDA device when PyTorch reports one is available, otherwise the CPU.
# (The original stray bare `device` expression on the next line was a no-op
# notebook leftover and has been removed.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class BERTClass(torch.nn.Module, PyTorchModelHubMixin):
    """Multi-label text classifier on top of COVID-Twitter-BERT v2.

    Architecture: BERT encoder -> dropout (p=0.3) on the encoder's
    ``pooler_output`` -> linear layer producing 11 raw scores.
    ``forward`` applies no activation; callers apply their own.

    ``PyTorchModelHubMixin`` mixes in Hugging Face Hub save/load support
    (e.g. ``from_pretrained``).
    """

    def __init__(self):
        super().__init__()
        self.bert_model = BertModel.from_pretrained(
            'digitalepidemiologylab/covid-twitter-bert-v2', return_dict=True
        )
        self.dropout = torch.nn.Dropout(0.3)
        # Size the head from the backbone config rather than hard-coding 1024,
        # so its input width always matches the encoder's pooled output.
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, 11)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return raw (pre-activation) scores for the 11 outputs.

        Args:
            input_ids: Token ids; presumably shape (batch, seq_len) — TODO confirm.
            attn_mask: Attention mask aligned with ``input_ids``.
            token_type_ids: Segment ids aligned with ``input_ids``.

        Returns:
            Tensor of logits whose last dimension is 11.
        """
        encoded = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.linear(pooled)
# Load the fine-tuned classifier from the Hugging Face Hub. `from_pretrained`
# is a classmethod that builds its own instance, so call it on the class:
# instantiating `BERTClass()` first would load the BERT backbone an extra,
# wasted time before being discarded.
model = BERTClass.from_pretrained("Asutosh2003/ct-bert-v2-vaccine-concern")
model.to(device)
# Tokenizer for the same checkpoint the classifier's encoder is built from.
tokenizer = BertTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
# Fixed token length every input is padded/truncated to before inference.
MAX_LEN = 256
def return_vec(text):
    """Score *text* with the classifier and return per-output probabilities.

    The text is tokenized to exactly ``MAX_LEN`` tokens (padded and
    truncated) and run through ``model`` with gradients disabled. Each output
    logit is then passed through an independent sigmoid.

    Args:
        text: The input string to classify.

    Returns:
        list[float]: One sigmoid probability per model output, for the
        single example in the batch.
    """
    batch = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_token_type_ids=True,
        return_tensors='pt',
    )
    model.eval()  # inference mode for layers such as dropout
    with torch.no_grad():
        ids, mask, segments = (
            batch[key].to(device, dtype=torch.long)
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
        )
        logits = model(ids, mask, segments)
        # Row 0 is the only example; convert its probabilities to Python floats.
        return torch.sigmoid(logits)[0].detach().cpu().numpy().tolist()
def filter_threshold_lst(vector, threshold_list):
    """Binarize *vector* against per-position thresholds.

    Args:
        vector: Sequence of scores (e.g. probabilities).
        threshold_list: Cut-offs paired positionally with *vector*.

    Returns:
        list[int]: ``1`` where ``vector[i] >= threshold_list[i]``, else ``0``.
        Pairing stops at the shorter of the two inputs (``zip`` semantics).
    """
    # The previous version appended the result list to itself, returning a
    # self-referential list with one extra element; return just the flags.
    return [1 if val >= threshold else 0
            for val, threshold in zip(vector, threshold_list)]
def predict(text, threshold_lst):
    """Classify *text* into zero or more of the labels listed below.

    Args:
        text: Input string to classify.
        threshold_lst: Per-label probability cut-offs, in the same order as
            ``labels``.

    Returns:
        tuple[list[str], list[float]]: The predicted label names (``['none']``
        when no label reaches its threshold) and the per-label probabilities
        returned by ``return_vec``.
    """
    labels = ('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory',
              'unnecessary', 'political', 'ingredients', 'conspiracy',
              'country', 'religious')
    prob_lst = return_vec(text)
    # Only the first len(labels) flags map to a label name; anything beyond
    # that is ignored.
    flags = filter_threshold_lst(prob_lst, threshold_lst)[:len(labels)]
    if not any(flag == 1 for flag in flags):
        return ['none'], prob_lst
    return [label for label, flag in zip(labels, flags) if flag == 1], prob_lst
def gr_predict(text):
    """Gradio callback: return the predicted labels for *text* as one string.

    Uses a fixed list of per-label probability thresholds. They are paired
    positionally with the labels in ``predict``.

    Args:
        text: Input string from the UI.

    Returns:
        str: Predicted label names joined by commas, e.g. ``'rushed,pharma'``.
    """
    thres = [0.616, 0.212, 0.051, 0.131, 0.212, 0.111, 0.071, 0.566, 0.061, 0.02, 0.081]
    out_lst, _ = predict(text, thres)
    return ','.join(out_lst)
# Web UI: a free-text box goes in, and the comma-separated label string from
# gr_predict is shown in a Label widget.
_text_in = gr.Textbox()
_labels_out = gr.Label()
iface = gr.Interface(fn=gr_predict, inputs=_text_in, outputs=_labels_out)

# Start the app in debug mode.
iface.launch(debug=True)