"""Gradio demo: multi-label classification of vaccine-concern tweets.

Loads a COVID-Twitter-BERT-v2 model fine-tuned for the labels in ``LABELS``
from the Hugging Face Hub, cleans the input text, applies per-label
probability thresholds and serves the result through a Gradio interface.
"""
import re

import gradio as gr
import nltk
import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin
from nltk.corpus import stopwords
from transformers import BertModel, BertTokenizer

nltk.download('stopwords')

BASE_MODEL_NAME = 'digitalepidemiologylab/covid-twitter-bert-v2'
FINETUNED_MODEL_NAME = 'Asutosh2003/ct-bert-v2-vaccine-concern'
MAX_LEN = 256

# Output order of the classifier head; index i of the model output is LABELS[i].
LABELS = ('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory',
          'unnecessary', 'political', 'ingredients', 'conspiracy', 'country',
          'religious')

# Per-label decision thresholds used by the web UI (same order as LABELS).
DEFAULT_THRESHOLDS = (0.616, 0.212, 0.051, 0.131, 0.212, 0.111, 0.071,
                      0.566, 0.061, 0.02, 0.081)

# Built once: looking words up in a set is O(1), and stopwords.words()
# re-creates its list on every call.
STOP_WORDS = frozenset(stopwords.words('english'))
_URL_RE = re.compile(r'http\S+')
_PUNCT_RE = re.compile(r'[^\w\s]')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class BERTClass(torch.nn.Module, PyTorchModelHubMixin):
    """COVID-Twitter-BERT-v2 encoder followed by dropout and a linear head.

    Attribute names must stay as they are: they are the keys of the
    state dict stored in the fine-tuned Hub checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.bert_model = BertModel.from_pretrained(BASE_MODEL_NAME, return_dict=True)
        self.dropout = torch.nn.Dropout(0.3)
        # 1024 must equal the base model's hidden size (width of pooler_output).
        self.linear = torch.nn.Linear(1024, len(LABELS))

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return one raw logit per label for each sequence in the batch."""
        output = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        output_dropout = self.dropout(output.pooler_output)
        return self.linear(output_dropout)


# from_pretrained is a classmethod that builds the instance itself, so a
# throwaway BERTClass() (which would load the base weights again) is not needed.
model = BERTClass.from_pretrained(FINETUNED_MODEL_NAME)
model.to(device)
model.eval()  # inference only: disables dropout

tokenizer = BertTokenizer.from_pretrained(BASE_MODEL_NAME)


def rmTrash(raw_string, remuser, remstop, remurls):
    """Normalise raw tweet text before tokenisation.

    Args:
        raw_string: input text.
        remuser: if truthy, drop every whitespace-separated token containing '@'.
        remstop: if truthy, drop NLTK English stop words.
        remurls: if truthy, drop substrings matching ``http\\S+``.

    Returns:
        The lower-cased text with punctuation (anything that is not a word
        character or whitespace) removed and the selected tokens dropped.
    """
    text = raw_string
    if remuser:
        text = ' '.join(tok for tok in text.split() if '@' not in tok)
    text = text.lower()
    if remurls:
        # Remove URLs before punctuation stripping so each URL goes as one unit.
        text = _URL_RE.sub('', text)
    text = _PUNCT_RE.sub('', text)
    if remstop:
        text = ' '.join(tok for tok in text.split() if tok not in STOP_WORDS)
    return text


def return_vec(text):
    """Return the per-label sigmoid probabilities for *text* as a list of floats.

    The text is cleaned with every ``rmTrash`` option enabled, then
    tokenised and padded/truncated to ``MAX_LEN`` tokens.
    """
    text = rmTrash(text, True, True, True)
    encodings = tokenizer(
        text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        input_ids = encodings['input_ids'].to(device, dtype=torch.long)
        attention_mask = encodings['attention_mask'].to(device, dtype=torch.long)
        token_type_ids = encodings['token_type_ids'].to(device, dtype=torch.long)
        logits = model(input_ids, attention_mask, token_type_ids)
    # Batch of one: take row 0 and convert to plain Python floats.
    return torch.sigmoid(logits)[0].cpu().tolist()


def filter_threshold_lst(vector, threshold_list):
    """Binarise probabilities against per-label thresholds.

    Returns:
        One 0/1 flag per label (1 when the value is >= its threshold),
        followed by a trailing "none" flag that is 1 when no label fired.
    """
    flags = [1 if val >= threshold else 0
             for val, threshold in zip(vector, threshold_list)]
    flags.append(0 if any(flags) else 1)
    return flags


def predict(text, threshold_lst):
    """Classify *text* into vaccine-concern labels.

    Args:
        text: raw input text.
        threshold_lst: one decision threshold per entry of ``LABELS``.

    Returns:
        ``(labels, probabilities)``: the names of the labels whose
        probability reached its threshold (``['none']`` if none did) and
        the full list of per-label probabilities.

    Raises:
        ValueError: if ``threshold_lst`` does not have one entry per label.
    """
    if len(threshold_lst) != len(LABELS):
        raise ValueError(
            f'expected {len(LABELS)} thresholds, got {len(threshold_lst)}'
        )
    prob_lst = return_vec(text)
    flags = filter_threshold_lst(prob_lst, threshold_lst)[:len(LABELS)]
    pred_lbl_lst = [label for label, flag in zip(LABELS, flags) if flag]
    if not pred_lbl_lst:
        pred_lbl_lst = ['none']
    return pred_lbl_lst, prob_lst


def gr_predict(text):
    """Gradio callback: return the predicted labels as a comma-separated string."""
    labels, _ = predict(text, DEFAULT_THRESHOLDS)
    return ','.join(labels)


descr = """
This app uses [Covid-twitter-BERT-v2](https://huggingface.co/digitalepidemiologylab/covid-twitter-bert-v2) fine tuned on a custom subset of [Caves dataset](https://arxiv.org/abs/2204.13746) sent by [FIRE 2023](http://fire.irsi.res.in/fire/2023/home) conference to do multi-label classification of tweets expressing concerns towards vaccines. The different concerns/classes are ('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory', 'unnecessary', 'political', 'ingredients', 'conspiracy', 'country', 'religious'). Each tweet can be expressing multiple of these concerns.
If a tweet is not expressing any concern falling into any of these categories it will be classified as 'None'.\n
[Source files](https://github.com/Ranjit246/AISoME_FIRE_2023)\n
Try it out with some ridiculous statements about vaccines. You can use the examples below as a start.
"""

# Gradio Interface
iface = gr.Interface(
    fn=gr_predict,
    inputs=gr.Textbox(),
    outputs=gr.Label(),  # Use Label widget for output
    examples=["This vaccine gave me mumps",
              "Chinese vaccine will infect our brain",
              "Trump is gonna use these vaccines to control us and become the president"],
    title="Vaccine Concerns ML",
    description=descr,
)

if __name__ == '__main__':
    # Launch the Gradio app
    iface.launch(debug=True)