Spaces:
Sleeping
Sleeping
File size: 5,049 Bytes
b805738 f1866af b805738 f1866af b805738 f1866af bd209d0 b805738 bd209d0 b805738 bd209d0 b805738 3cba291 b805738 3cba291 b805738 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import torch
from transformers import BertTokenizer, BertModel
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import gradio as gr
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
import re
# Run on the first CUDA device when one is available, otherwise on the CPU.
# (The bare `device` expression that followed was a notebook leftover and a no-op.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class BERTClass(torch.nn.Module, PyTorchModelHubMixin):
    """COVID-Twitter-BERT v2 encoder with a dropout + linear classification head.

    ``forward`` returns the raw (un-activated) output of an 11-way linear layer
    applied to the encoder's pooled output. The head expects a 1024-wide pooled
    vector. Submodule attribute names (``bert_model``, ``dropout``, ``linear``)
    must not change: they are the state-dict keys used when loading weights.
    """

    def __init__(self):
        super().__init__()
        self.bert_model = BertModel.from_pretrained(
            'digitalepidemiologylab/covid-twitter-bert-v2', return_dict=True
        )
        self.dropout = torch.nn.Dropout(0.3)
        self.linear = torch.nn.Linear(1024, 11)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return 11 logits per example for the given encoded batch."""
        encoded = self.bert_model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        # Dropout is only active in train mode; inference callers use eval().
        pooled = self.dropout(encoded.pooler_output)
        return self.linear(pooled)
# Load the fine-tuned classifier straight from the Hub. `from_pretrained` is a
# classmethod that constructs the model itself, so building a `BERTClass()`
# first (which downloads/instantiates the base encoder in __init__) was wasted work.
model = BERTClass.from_pretrained("Asutosh2003/ct-bert-v2-vaccine-concern")
model.to(device)
model.eval()  # inference-only app: disable dropout
tokenizer = BertTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
# Fixed encoded length: inputs are truncated / padded to this many tokens.
MAX_LEN = 256
def rmTrash(raw_string, remuser, remstop, remurls):
    """Normalise raw tweet text before tokenisation.

    The text is always lower-cased and stripped of every character that is
    neither a word character nor whitespace.

    Args:
        raw_string: Input text.
        remuser: If truthy, drop every whitespace-separated token that
            contains '@' (e.g. user mentions).
        remstop: If truthy, drop NLTK English stopwords; the surviving tokens
            are re-joined with single spaces.
        remurls: If truthy, remove every run starting with 'http' up to the
            next whitespace.

    Returns:
        The cleaned string.
    """
    if remuser:
        text = ' '.join(tok for tok in raw_string.split() if '@' not in tok)
    else:
        text = raw_string
    # Punctuation is stripped first; a URL stays one whitespace-delimited run
    # (e.g. 'https://t.co/x' -> 'httpstcox'), so the pattern below still removes it.
    text = re.sub(r'[^\w\s]', '', text.lower())
    if remurls:
        text = re.sub(r'http\S+', '', text)
    if remstop:
        # Build the stopword set once per call instead of re-reading the
        # corpus list and scanning it linearly for every token.
        stop = set(stopwords.words('english'))
        text = ' '.join(tok for tok in text.split() if tok not in stop)
    return text
def return_vec(text):
    """Return per-label sigmoid probabilities for a single piece of text.

    The text is cleaned with ``rmTrash`` (mentions, stopwords and URLs removed),
    encoded to exactly ``MAX_LEN`` tokens, and scored by the global ``model``.

    Returns:
        A list of Python floats, one per model output.
    """
    cleaned = rmTrash(text, True, True, True)
    batch = tokenizer.encode_plus(
        cleaned,
        text_pair=None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    model.eval()
    with torch.no_grad():
        ids, mask, segments = (
            batch[key].to(device, dtype=torch.long)
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
        )
        logits = model(ids, mask, segments)
        # Single-example batch: keep row 0 only.
        probs = torch.sigmoid(logits).cpu()[0]
    return probs.tolist()
def filter_threshold_lst(vector, threshold_list):
    """Binarise per-label probabilities against per-label thresholds.

    Args:
        vector: Per-label probabilities.
        threshold_list: Per-label cut-offs, aligned index-by-index with ``vector``.

    Returns:
        A list holding 1 where ``value >= threshold`` and 0 otherwise. If the
        inputs differ in length, trailing extras are ignored (``zip`` semantics).
    """
    # The previous version appended the list to itself, returning a
    # self-referential list with a bogus extra element; return only the flags.
    return [1 if val >= threshold else 0 for val, threshold in zip(vector, threshold_list)]
def predict(text, threshold_lst):
    """Predict the vaccine-concern labels expressed in ``text``.

    Args:
        text: Raw tweet text.
        threshold_lst: Per-label probability cut-offs, in the same order as
            ``labels`` below.

    Returns:
        ``(pred_lbl_lst, prob_lst)``. ``pred_lbl_lst`` holds the names of every
        label whose flag is 1, or ``['none']`` when no label passes its threshold.
        ``prob_lst`` is the raw probability list from ``return_vec``.
    """
    labels = ('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory', 'unnecessary',
              'political', 'ingredients', 'conspiracy', 'country', 'religious')
    prob_lst = return_vec(text)
    flags = filter_threshold_lst(prob_lst, threshold_lst)
    # zip stops after the last label, so any extra trailing entries in `flags`
    # are ignored rather than indexed past the end of `labels`.
    pred_lbl_lst = [label for label, flag in zip(labels, flags) if flag == 1]
    if not pred_lbl_lst:
        pred_lbl_lst = ['none']
    return pred_lbl_lst, prob_lst
def gr_predict(text):
    """Gradio callback: return the predicted concern labels as a comma-separated string."""
    # Per-label decision thresholds; entry i is applied to the i-th label in predict().
    thres = [0.616, 0.212, 0.051, 0.131, 0.212, 0.111, 0.071, 0.566, 0.061, 0.02, 0.081]
    out_lst, _ = predict(text, thres)
    return ','.join(out_lst)
# Markdown shown under the app title. The label list mirrors `labels` in predict(),
# and the fallback label name matches what predict() actually returns ('none').
descr = """
This app uses [Covid-twitter-BERT-v2](https://huggingface.co/digitalepidemiologylab/covid-twitter-bert-v2)
fine tuned on a custom subset of [Caves dataset](https://arxiv.org/abs/2204.13746) sent by [FIRE 2023](http://fire.irsi.res.in/fire/2023/home)
conference to do multi-label classification of tweets expressing concerns towards vaccines. The different concerns/classes are
('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory', 'unnecessary', 'political', 'ingredients', 'conspiracy', 'country', 'religious').
Each tweet can be expressing multiple of these concerns. If a tweet is not expressing any concern falling into any of these categories
it will be classified as 'none'.\n
[Source files](https://github.com/Ranjit246/AISoME_FIRE_2023)\n
Try it out with some ridiculous statements about vaccines. You can use the examples below as a start.
"""
# Gradio Interface: one text box in, the comma-separated label string out.
iface = gr.Interface(
    fn=gr_predict,
    inputs=gr.Textbox(),
    outputs=gr.Label(),  # Use Label widget for output
    examples=["This vaccine gave me mumps", "Chinese vaccine will infect our brain",
              "Trump is gonna use these vaccines to control us and become the president"],
    title="Vaccine Concerns ML",
    description=descr
)
# Launch the Gradio app
iface.launch(debug=True)