Asutosh2003 committed on
Commit
b805738
1 Parent(s): 38697ec

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BertTokenizer, BertModel
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+ import numpy as np
5
+ import gradio as gr
6
+
7
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# (The stray bare `device` expression that followed was a notebook-style echo
# with no effect in a script, so it has been removed.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+
10
class BERTClass(torch.nn.Module, PyTorchModelHubMixin):
    """BERT encoder followed by dropout and an 11-way linear head.

    The encoder weights come from the
    'digitalepidemiologylab/covid-twitter-bert-v2' checkpoint; the head maps
    a 1024-dimensional pooled vector to 11 raw scores (logits).
    """

    def __init__(self):
        super().__init__()
        self.bert_model = BertModel.from_pretrained(
            'digitalepidemiologylab/covid-twitter-bert-v2',
            return_dict=True,
        )
        self.dropout = torch.nn.Dropout(0.3)
        # NOTE(review): 1024 presumably matches the encoder's pooled output
        # width - confirm against the checkpoint config.
        self.linear = torch.nn.Linear(1024, 11)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return the linear head's output for the given encoded inputs.

        Args:
            input_ids: token ids passed to the encoder.
            attn_mask: forwarded to the encoder as ``attention_mask``.
            token_type_ids: forwarded to the encoder as ``token_type_ids``.

        Returns:
            The output of ``self.linear`` applied to the dropped-out pooled
            encoder output (no sigmoid is applied here).
        """
        encoded = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.linear(pooled)
26
+
27
# Load the fine-tuned weights straight from the Hub. `from_pretrained` is
# called on the class rather than on a throw-away `BERTClass()` instance, so
# the network is only constructed once.
model = BERTClass.from_pretrained("Asutosh2003/ct-bert-v2-vaccine-concern")
model.to(device)

# Tokenizer of the base checkpoint used by BERTClass.
tokenizer = BertTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')
# Every input is padded/truncated to this many tokens in return_vec.
MAX_LEN = 256
34
+
35
+
36
def return_vec(text):
    """Run the model on a single text and return its sigmoid scores.

    The text is tokenized with the module-level ``tokenizer`` (padded and
    truncated to ``MAX_LEN``), moved to ``device`` and fed through ``model``
    in eval mode without gradient tracking.

    Args:
        text: the input text to score.

    Returns:
        A list with the sigmoid of each model output for the first (only)
        row of the batch.
    """
    encodings = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    model.eval()
    with torch.no_grad():
        # Move each encoded tensor to the target device as int64.
        on_device = {
            key: encodings[key].to(device, dtype=torch.long)
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
        }
        logits = model(
            on_device['input_ids'],
            on_device['attention_mask'],
            on_device['token_type_ids'],
        )
        probabilities = torch.sigmoid(logits).cpu().detach().numpy().tolist()

    first_row = probabilities[0]
    return list(first_row)
56
+
57
+
58
def filter_threshold_lst(vector, threshold_list):
    """Binarise scores against per-position thresholds.

    Args:
        vector: iterable of scores.
        threshold_list: iterable of thresholds, paired with ``vector`` by
            position (pairing stops at the shorter of the two).

    Returns:
        A list of ints with one entry per pair: 1 where the score is greater
        than or equal to its threshold, 0 otherwise.
    """
    # The previous implementation appended the result list to itself, which
    # produced a self-referential extra element; only the flags are returned.
    return [
        1 if val >= threshold else 0
        for val, threshold in zip(vector, threshold_list)
    ]
64
+
65
+
66
def predict(text, threshold_lst):
    """Predict the concern labels for a text.

    Args:
        text: the input text, passed to ``return_vec``.
        threshold_lst: per-label thresholds, passed to
            ``filter_threshold_lst`` together with the scores.

    Returns:
        A ``(pred_lbl_lst, prob_lst)`` tuple: the names of the labels whose
        flag is 1 (or ``['none']`` when no label is flagged) and the scores
        returned by ``return_vec``.
    """
    labels = ('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory', 'unnecessary', 'political', 'ingredients', 'conspiracy', 'country', 'religious')
    prob_lst = return_vec(text)
    vec = filter_threshold_lst(prob_lst, threshold_lst)

    # zip stops at the shorter sequence, so any entries of `vec` beyond the
    # known labels are ignored instead of being indexed into `labels`.
    pred_lbl_lst = [label for label, flag in zip(labels, vec) if flag == 1]
    if not pred_lbl_lst:
        pred_lbl_lst = ['none']
    return pred_lbl_lst, prob_lst
80
+
81
def gr_predict(text):
    """Gradio callback: return the predicted labels as one comma-separated string.

    Args:
        text: the text entered by the user.

    Returns:
        The labels returned by ``predict`` joined with ``','`` (an empty
        string if there are none).
    """
    # Fixed per-label decision thresholds, in the label order used by predict.
    thres = [0.616, 0.212, 0.051, 0.131, 0.212, 0.111, 0.071, 0.566, 0.061, 0.02, 0.081]
    out_lst, _ = predict(text, thres)
    return ','.join(out_lst)
90
+
91
# Gradio interface: a free-text box in, a Label widget out, backed by gr_predict.
iface = gr.Interface(fn=gr_predict, inputs=gr.Textbox(), outputs=gr.Label())

# Start the Gradio app with debug output enabled.
iface.launch(debug=True)