Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -1,38 +1,263 @@
+from flask import Flask, request
 import streamlit as st
+import os
+import pandas as pd
+# import numpy as np
+from transformers import AutoTokenizer, get_cosine_schedule_with_warmup, AutoModel
+from datasets import Dataset
+import math
+
+from sklearn.preprocessing import LabelEncoder
+
+import pytorch_lightning as pl
+from pytorch_lightning import seed_everything
+
 import torch
-import
-from
-
-
-
-
-
-
-
+import torch.nn.functional as F
+from torch import nn
+from torch.utils.data import DataLoader
+
+
+data = {'discourse_type': [''], 'discourse_text': ['']}
+data_path = pd.DataFrame(data)
+test_path = pd.DataFrame(data)
+
+
+attributes = ["Adequate", "Effective", "Ineffective"]
+distilbert_config = {'name': 'distilbert',
+                     'model_name': 'distilbert-base-uncased',
+                     'newly_tuned_model_path': './20220820-043647.pth',
+                     'wandb': False,
+                     'param': {
+                         'n_labels': 3,
+                         'batch_size': 64,
+                         'lr': 8e-4,  # 6e-5,
+                         'warmup': 0,
+                         'weight_decay': 0.01,  # default is 0.01
+                         'n_epochs': 5,  # 4,
+                         'n_freeze': 5,
+                         'p_dropout': 0,  # 0.2, # 0.6,
+                         'scheduler': False,
+                         'precision': 16,  # default is 32
+                         'sample_mode': True,
+                         'sample_size': 100,
+                         'swa': False,
+                         'swa_lrs': 1e-2
+
+                     }
+                     }
+
+seed_everything(91, workers=True)
+
+
+# Freeze the hidden layers within the pretrained model
+def freeze(module):
+    for parameter in module.parameters():
+        parameter.requires_grad = False
+
+def get_freezed_parameters(module):
+    freezed_parameters = []
+    for name, parameter in module.named_parameters():
+        if not parameter.requires_grad:
+            freezed_parameters.append(name)
+    return freezed_parameters
+
+
+class _Dataset(Dataset):
+    def __init__(self, data_path, test_path, tokenizer, label_encoder, attributes, config, max_token_len: int = 512, is_train=True, is_test=False):
+        self.data_path = data_path
+        self.test_path = test_path
+        self.tokenizer = tokenizer
+        self.attributes = attributes
+        self.max_token_len = max_token_len
+        self.is_train = is_train
+        self.is_test = is_test
+        self.label_encoder = label_encoder
+        self.config = config
+        self._prepare_data()
+
+    def _prepare_data(self):
+        SEP = self.tokenizer.sep_token  # different models use different separator tokens (e.g. [SEP], </s>)
+        df = self.test_path
+        df['text'] = df['discourse_type'] + SEP + df['discourse_text']
+        df = df.loc[:, ['text']]
+        self.df = df
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, index):
+        item = self.df.iloc[index]
+        text = str(item.text)
+        tokens = self.tokenizer.encode_plus(text,
+                                            add_special_tokens=True,
+                                            return_tensors='pt',
+                                            truncation=True,
+                                            max_length=self.max_token_len,
+                                            return_attention_mask=True)
+        if self.is_test:
+            return {'input_ids': tokens.input_ids.flatten(), 'attention_mask': tokens.attention_mask.flatten()}
+        else:
+            attributes = item['labels'].split()
+            self.label_encoder.fit(self.attributes)
+            attributes = self.label_encoder.transform(attributes)
+            attributes = torch.as_tensor(attributes)
+            return {'input_ids': tokens.input_ids.flatten(), 'attention_mask': tokens.attention_mask.flatten(), 'labels': attributes}
+
+
+class Collate:
+    def __init__(self, tokenizer, isTrain=True):
+        self.tokenizer = tokenizer
+        self.isTrain = isTrain
+
+    def __call__(self, batch):
+        output = dict()
+        output["input_ids"] = [sample["input_ids"] for sample in batch]
+        output["attention_mask"] = [sample["attention_mask"] for sample in batch]
+        if self.isTrain:
+            output["labels"] = [sample["labels"] for sample in batch]
+
+        # calculate max token length of this batch
+        batch_max = max([len(ids) for ids in output["input_ids"]])
+
+        # add padding
+        if self.tokenizer.padding_side == "right":
+            output["input_ids"] = [s.tolist() + (batch_max - len(s)) * [self.tokenizer.pad_token_id] for s in output["input_ids"]]
+            output["attention_mask"] = [s.tolist() + (batch_max - len(s)) * [0] for s in output["attention_mask"]]
+
+        else:
+            output["input_ids"] = [torch.FloatTensor((batch_max - len(s)) * [self.tokenizer.pad_token_id].tolist()) + s.tolist() for s in output["input_ids"]]
+            output["attention_mask"] = [torch.FloatTensor((batch_max - len(s)) * [0]) + s.tolist() for s in output["attention_mask"]]
+
+        # convert to tensors
+        output["input_ids"] = torch.tensor(output["input_ids"], dtype=torch.long)
+        output["attention_mask"] = torch.tensor(output["attention_mask"], dtype=torch.long)
+        if self.isTrain:
+            output["labels"] = torch.tensor(output["labels"], dtype=torch.long)
+        return output
+
+class _Data_Module(pl.LightningDataModule):
+
+    def __init__(self, data_path, test_path, attributes, label_encoder, tokenizer, config, batch_size: int = 8, max_token_length: int = 512):
+        super().__init__()
+        self.data_path = data_path
+        self.test_path = test_path
+        self.attributes = attributes
+        self.batch_size = batch_size
+        self.max_token_length = max_token_length
+        self.tokenizer = tokenizer
+        self.label_encoder = label_encoder
+        self.config = config
+
+    def setup(self, stage=None):
+        if stage == 'predict':
+            self.test_dataset = _Dataset(self.data_path, self.test_path, label_encoder=self.label_encoder, attributes=self.attributes, is_train=False, is_test=True, tokenizer=self.tokenizer, config=self.config)
+
+    def predict_dataloader(self):
+        collate_fn = Collate(self.tokenizer,
+                             isTrain=False)
+
+        return DataLoader(self.test_dataset,
+                          batch_size=self.batch_size,
+                          num_workers=2,
+                          shuffle=False,
+                          collate_fn=collate_fn)
+
+
+class DistilBert_Text_Classifier(pl.LightningModule):
+
+    def __init__(self, config: dict, data_module):
+        super().__init__()
+        self.config = config
+        self.data_module = data_module
+        self.pretrained_model = AutoModel.from_pretrained(config['model_name'], return_dict=True)
+        freeze((self.pretrained_model).embeddings)
+        freeze((self.pretrained_model).transformer.layer[:config['param']['n_freeze']])
+        self.classifier = torch.nn.Linear(self.pretrained_model.config.hidden_size, self.config['param']['n_labels'])
+        self.loss_func = nn.CrossEntropyLoss()  # do not add a SoftMax layer; CrossEntropyLoss expects raw logits
+
+        self.dropout = nn.Dropout(config['param']['p_dropout'])
+
+    # For inference
+    def forward(self, input_ids, attention_mask, labels=None):
+        output = self.pretrained_model(input_ids=input_ids, attention_mask=attention_mask)
+        pooled_output = torch.mean(output.last_hidden_state, 1)  # mean over the sequence length
+        pooled_output = F.relu(pooled_output)
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        loss = 0
+        if labels is not None:
+            loss = self.loss_func(logits, labels)
+        return loss, logits
+
+    def predict_step(self, batch, batch_index):
+        loss, logits = self(**batch)
+        return logits
+
+    def configure_optimizers(self):
+        train_size = len(self.data_module.train_dataloader())
+
+        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config['param']['lr'], weight_decay=self.config['param']['weight_decay'])
+        if self.config['param']['scheduler']:
+            total_steps = train_size / self.config['param']['batch_size']
+            warmup_steps = math.floor(total_steps * self.config['param']['warmup'])
+            scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
+            return [optimizer], [scheduler]
+        else:
+            return optimizer
+
+def predict(_Text_Classifier, config, test_path):
+    attributes = ["Adequate", "Effective", "Ineffective"]
+    tokenizer = AutoTokenizer.from_pretrained(config['model_name'], use_fast=True)
+    le = LabelEncoder()
+
+    # Initialize data module
+    test_data_module = _Data_Module(data_path,
+                                    test_path,
+                                    attributes,
+                                    le,
+                                    tokenizer,
+                                    batch_size=config['param']['batch_size'],
+                                    config=config
+                                    )
+    test_data_module.setup()
+
+    # Initialize Model
+    model = _Text_Classifier(config, test_data_module)
+    model.load_state_dict(torch.load(config['newly_tuned_model_path']))
+
+    # Initialize Trainer
+    trainer = pl.Trainer(accelerator='auto')
+
+    output = trainer.predict(model, datamodule=test_data_module)
+    predictions = output[0].argmax(dim=-1).item()
+    return predictions
+
 option = st.selectbox(
     'Discourse Type',
     ('Position', 'Concluding Statement', 'Claim', 'Counterclaim', 'Evidence', 'Lead', 'Position', 'Rebuttal'))
 text = st.text_area('Input Here!')
 
 if text:
-
-
-
-    # inputs2 = tokenizer2(text, padding=True, truncation=True, return_tensors="pt")
-    # inputs3 = tokenizer3(text, padding=True, truncation=True, return_tensors="pt")
-
-    outputs1 = model1(**inputs1)
-    # outputs2 = model2(**inputs2)
-    # outputs3 = model3(**inputs3)
+    discourse_type = option
+    discourse_text = text
+    test_path = pd.DataFrame({'discourse_type': [discourse_type], 'discourse_text': [discourse_text]})
 
-
-
+    # prediction = predict(DistilBert_Text_Classifier, distilbert_config, test_path)
+    prediction = int(discourse_text)
+    if prediction == 0:
+        out = 'Adequate'
+    elif prediction == 1:
+        out = 'Effective'
+    elif prediction == 2:
+        out = 'Ineffective'
+    st.text(out)
+    # return {'response': out}
+
+
+# if __name__ == '__main__':
+#     app.run(host='0.0.0.0', debug=True, port=int(os.environ.get("PORT", 8080)))
+
 
-
-        out = 'Adequate'
-    elif prediction == 1:
-        out = 'Effective'
-    elif prediction == 2:
-        out = 'Ineffective'
-
-    st.text(out)
+
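A note on the new Collate class: the left-padding branch of __call__ is broken as committed. It calls .tolist() on a plain Python list ([self.tokenizer.pad_token_id].tolist()) and then adds a torch.FloatTensor to a list, so it would raise the moment a tokenizer with padding_side == "left" is used; DistilBERT pads on the right, which is why the app never reaches that branch. A minimal sketch of a fix that mirrors the working right-padding branch (pad_left is a hypothetical helper, not part of the commit):

    # Hypothetical helper mirroring the right-padding logic:
    # prepend pad ids until the sequence reaches batch_max.
    def pad_left(seq, batch_max, pad_id):
        return (batch_max - len(seq)) * [pad_id] + seq.tolist()

    # In Collate.__call__, the else-branch would then read:
    #   output["input_ids"] = [pad_left(s, batch_max, self.tokenizer.pad_token_id) for s in output["input_ids"]]
    #   output["attention_mask"] = [pad_left(s, batch_max, 0) for s in output["attention_mask"]]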
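predict() restores the fine-tuned weights with a bare torch.load(config['newly_tuned_model_path']). If ./20220820-043647.pth was saved on a GPU machine, that call raises a RuntimeError on the CPU-only hardware a default Space runs on, which is one plausible cause of the "Runtime error" badge above. A hedged sketch of a more defensive load, assuming the file is a plain state_dict as the code implies:

    import torch

    # Remap any GPU-saved tensors onto the CPU so the checkpoint also
    # loads on CPU-only hosts; this is a no-op for CPU-saved files.
    state_dict = torch.load('./20220820-043647.pth', map_location=torch.device('cpu'))
    # model.load_state_dict(state_dict)  # model built as in predict()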
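The likeliest crash, though, is in the handler itself: the real model call is commented out and replaced by the placeholder prediction = int(discourse_text), which raises ValueError as soon as the user types anything non-numeric, and leaves out undefined whenever the value falls outside 0-2. A minimal sketch of the intended wiring, assuming the commented-out predict() call works once the checkpoint loads:

    # Sketch only: relies on predict() returning a class index in {0, 1, 2},
    # as the code above intends.
    if text:
        test_path = pd.DataFrame({'discourse_type': [option], 'discourse_text': [text]})
        prediction = predict(DistilBert_Text_Classifier, distilbert_config, test_path)
        labels = {0: 'Adequate', 1: 'Effective', 2: 'Ineffective'}
        st.text(labels.get(prediction, 'Unknown'))  # no NameError on unexpected indices

The from flask import Flask, request import and the commented-out app.run(...) block look like leftovers from an earlier Flask deployment; neither is used in the Streamlit version.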