Spaces: Build error

BecomeAllan committed • Commit 6b71499 • Parent(s): 4a0a4f7

update

Browse files:
- .vscode/settings.json +7 -0
- ML-SLRC/Info.json +1 -0
- ML-SLRC/ML_SLRC.py +574 -0
- ML-SLRC/Util_funs.py +740 -0
- ML-SLRC/__init__.py +4 -0
- ML-SLRC/model.pt +3 -0
- app.py +0 -1
.vscode/settings.json
ADDED
@@ -0,0 +1,7 @@
{
    "workbench.colorCustomizations": {
        "activityBar.background": "#590F35",
        "titleBar.activeBackground": "#7C154B",
        "titleBar.activeForeground": "#FEFCFD"
    }
}
ML-SLRC/Info.json
ADDED
@@ -0,0 +1 @@
{"inner_print": 2, "bert_layers": 4, "max_seq_length": 512, "meta_epoch": 20, "k_spt": 8, "k_qry": 8, "outer_batch_size": 5, "inner_batch_size": 4, "outer_update_lr": 5e-05, "inner_update_lr": 5e-05, "inner_update_step": 4, "inner_update_step_eval": 4, "num_task_train": 20, "pos_weight": 1.5, "tresh": 0.9, "model": "allenai/scibert_scivocab_uncased"}
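
These hyperparameters are consumed as a plain dict by the training utilities added below. A minimal sketch of how the file might be loaded and paired with its tokenizer; the variable names and the use of the transformers Auto classes are illustrative assumptions, not part of this commit:

import json
from transformers import AutoTokenizer, AutoModel  # assumed available in the Space

with open("ML-SLRC/Info.json") as fp:
    Info = json.load(fp)

# "model" names the pre-trained checkpoint (SciBERT); the tokenizer is typically
# added to the same dict before it is handed to the data/meta-learning helpers.
Info["tokenizer"] = AutoTokenizer.from_pretrained(Info["model"])
backbone = AutoModel.from_pretrained(Info["model"])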
ML-SLRC/ML_SLRC.py
ADDED
@@ -0,0 +1,574 @@
from torch import nn
import torch
import numpy as np
from copy import deepcopy
import re
import unicodedata
from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler
from sklearn.model_selection import train_test_split
from torch.optim import Adam
import gc
from torchmetrics import functional as fn
import random


# Pre-trained model
class Encoder(nn.Module):
    def __init__(self, layers, freeze_bert, model):
        super(Encoder, self).__init__()

        # Dummy Parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Pre-trained model
        self.model = deepcopy(model)

        # Freezing bert parameters
        if freeze_bert:
            for param in self.model.parameters():
                param.requires_grad = False

        # Selecting hidden layers of the pre-trained model
        old_model_encoder = self.model.encoder.layer
        new_model_encoder = nn.ModuleList()

        for i in layers:
            new_model_encoder.append(old_model_encoder[i])

        self.model.encoder.layer = new_model_encoder

    # Feed forward
    def forward(self, **x):
        return self.model(**x)['pooler_output']

# Complete model
class SLR_Classifier(nn.Module):
    def __init__(self, **data):
        super(SLR_Classifier, self).__init__()

        # Dummy Parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Loss function
        # Binary Cross Entropy with logits reduced to mean
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean',
                                            pos_weight=torch.FloatTensor([data.get("pos_weight", 2.5)]))

        # Pre-trained model
        self.Encoder = Encoder(layers=data.get("bert_layers", range(12)),
                               freeze_bert=data.get("freeze_bert", False),
                               model=data.get("model"),
                               )

        # Feature Map Layer
        self.feature_map = nn.Sequential(
            # nn.LayerNorm(self.Encoder.model.config.hidden_size),
            nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
            # nn.Dropout(data.get("drop", 0.5)),
            nn.Linear(self.Encoder.model.config.hidden_size, 200),
            nn.Dropout(data.get("drop", 0.5)),
        )

        # Classifier Layer
        self.classifier = nn.Sequential(
            # nn.LayerNorm(self.Encoder.model.config.hidden_size),
            # nn.Dropout(data.get("drop", 0.5)),
            # nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
            # nn.Dropout(data.get("drop", 0.5)),
            nn.Tanh(),
            nn.Linear(200, 1)
        )

        # Initializing layer parameters
        nn.init.normal_(self.feature_map[1].weight, mean=0, std=0.00001)
        nn.init.zeros_(self.feature_map[1].bias)

    # Feed forward
    def forward(self, input_ids, attention_mask, token_type_ids, labels):

        predict = self.Encoder(**{"input_ids": input_ids,
                                  "attention_mask": attention_mask,
                                  "token_type_ids": token_type_ids})
        feature = self.feature_map(predict)
        logit = self.classifier(feature)

        predict = torch.sigmoid(logit)

        # Loss function
        loss = self.loss_fn(logit.to(torch.float), labels.to(torch.float).unsqueeze(1))

        return [loss, [feature, logit], predict]

# Undesirable patterns within texts
patterns = {
    'CONCLUSIONS AND IMPLICATIONS': '',
    'BACKGROUND AND PURPOSE': '',
    'EXPERIMENTAL APPROACH': '',
    'KEY RESULTS AEA': '',
    '©': '',
    '®': '',
    'μ': '',
    '(C)': '',
    'OBJECTIVE:': '',
    'MATERIALS AND METHODS:': '',
    'SIGNIFICANCE:': '',
    'BACKGROUND:': '',
    'RESULTS:': '',
    'METHODS:': '',
    'CONCLUSIONS:': '',
    'AIM:': '',
    'STUDY DESIGN:': '',
    'CLINICAL RELEVANCE:': '',
    'CONCLUSION:': '',
    'HYPOTHESIS:': '',
    'Questions/Purposes:': '',
    'Introduction:': '',
    'PURPOSE:': '',
    'PATIENTS AND METHODS:': '',
    'FINDINGS:': '',
    'INTERPRETATIONS:': '',
    'FUNDING:': '',
    'PROGRESS:': '',
    'CONTEXT:': '',
    'MEASURES:': '',
    'DESIGN:': '',
    'BACKGROUND AND OBJECTIVES:': '',
    '<p>': '',
    '</p>': '',
    '<<ETX>>': '',
    '+/-': '',
    r'\(.+\)': '',
    r'\[.+\]': '',
    r' \d ': '',
    '<': '',
    '>': '',
    '- ': '',
    ' +': ' ',
    ', ,': ',',
    ',,': ',',
    '%': ' percent',
    'per cent': ' percent'
}

patterns = {x.lower(): y for x, y in patterns.items()}


LABEL_MAP = {'negative': 0,
             'not included': 0,
             '0': 0,
             0: 0,
             'excluded': 0,
             'positive': 1,
             'included': 1,
             '1': 1,
             1: 1,
             }

class SLR_DataSet(Dataset):
    def __init__(self, treat_text=None, **args):
        self.tokenizer = args.get('tokenizer')
        self.data = args.get('data')
        self.max_seq_length = args.get("max_seq_length", 512)
        self.INPUT_NAME = args.get("input", 'x')
        self.LABEL_NAME = args.get("output", 'y')
        self.treat_text = treat_text

    # Tokenizing and processing text
    def encode_text(self, example):
        comment_text = example[self.INPUT_NAME]
        if self.treat_text:
            comment_text = self.treat_text(comment_text)

        try:
            labels = LABEL_MAP[example[self.LABEL_NAME].lower()]
        except:
            labels = -1

        encoding = self.tokenizer.encode_plus(
            (comment_text, "It is great text"),
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return tuple((
            encoding["input_ids"].flatten(),
            encoding["attention_mask"].flatten(),
            encoding["token_type_ids"].flatten(),
            torch.tensor([torch.tensor(labels).to(int)])
        ))

    def __len__(self):
        return len(self.data)

    # Returning data
    def __getitem__(self, index: int):
        data_row = self.data.reset_index().iloc[index]
        temp_data = self.encode_text(data_row)
        return temp_data


class Learner(nn.Module):

    def __init__(self, **args):
        """
        :param args: training hyperparameters (see Info.json)
        """
        super(Learner, self).__init__()

        self.inner_print = args.get('inner_print')
        self.inner_batch_size = args.get('inner_batch_size')
        self.outer_update_lr = args.get('outer_update_lr')
        self.inner_update_lr = args.get('inner_update_lr')
        self.inner_update_step = args.get('inner_update_step')
        self.inner_update_step_eval = args.get('inner_update_step_eval')
        self.model = args.get('model')
        self.device = args.get('device')

        # Outer optimizer
        self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
        self.model.train()

    def forward(self, batch_tasks, training=True, valid_train=True):
        """
        batch = [(support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset),
                 (support TensorDataset, query TensorDataset)]

        # support = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)
        """
        task_accs = []
        task_f1 = []
        task_recall = []
        sum_gradients = []
        num_task = len(batch_tasks)
        num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval

        # Outer loop tasks
        for task_id, task in enumerate(batch_tasks):
            support = task[0]
            query = task[1]
            name = task[2]

            # Copying model
            fast_model = deepcopy(self.model)
            fast_model.to(self.device)

            # Inner trainer optimizer
            inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)

            # Creating training data loaders
            if len(support) % self.inner_batch_size == 1:
                support_dataloader = DataLoader(support, sampler=RandomSampler(support),
                                                batch_size=self.inner_batch_size,
                                                drop_last=True)
            else:
                support_dataloader = DataLoader(support, sampler=RandomSampler(support),
                                                batch_size=self.inner_batch_size,
                                                drop_last=False)

            # steps_per_epoch = len(support) // self.inner_batch_size
            # total_training_steps = steps_per_epoch * 5
            # warmup_steps = total_training_steps // 3

            # scheduler = get_linear_schedule_with_warmup(
            #     inner_optimizer,
            #     num_warmup_steps=warmup_steps,
            #     num_training_steps=total_training_steps
            # )

            fast_model.train()

            # Inner loop training epoch (support set)
            if valid_train:
                print('----Task', task_id, ":", name, '----')

            for i in range(0, num_inner_update_step):
                all_loss = []

                # Inner loop training batch (support set)
                for inner_step, batch in enumerate(support_dataloader):
                    batch = tuple(t.to(self.device) for t in batch)
                    input_ids, attention_mask, token_type_ids, label_id = batch

                    # Feed forward
                    loss, _, _ = fast_model(input_ids, attention_mask, token_type_ids=token_type_ids, labels=label_id)

                    # Computing gradients
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(fast_model.parameters(), max_norm=1)

                    # Updating inner training parameters
                    inner_optimizer.step()
                    inner_optimizer.zero_grad()

                    # Appending losses
                    all_loss.append(loss.item())

                    del batch, input_ids, attention_mask, label_id
                    torch.cuda.empty_cache()

                if valid_train:
                    if (i + 1) % self.inner_print == 0:
                        print("Inner Loss: ", np.mean(all_loss))

            fast_model.to(torch.device('cpu'))

            # Inner training phase weights
            if training:
                meta_weights = list(self.model.parameters())
                fast_weights = list(fast_model.parameters())

                # Appending gradients
                gradients = []
                for i, (meta_params, fast_params) in enumerate(zip(meta_weights, fast_weights)):
                    gradient = meta_params - fast_params
                    if task_id == 0:
                        sum_gradients.append(gradient)
                    else:
                        sum_gradients[i] += gradient

            # Inner test (query set)
            fast_model.to(self.device)
            fast_model.eval()

            if valid_train:
                with torch.no_grad():
                    # Data loader
                    query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
                    query_batch = next(iter(query_dataloader))
                    query_batch = tuple(t.to(self.device) for t in query_batch)
                    q_input_ids, q_attention_mask, q_token_type_ids, q_label_id = query_batch

                    # Feed forward
                    _, _, pre_label_id = fast_model(q_input_ids, q_attention_mask, q_token_type_ids, labels=q_label_id)

                    # Predictions
                    pre_label_id = pre_label_id.detach().cpu().squeeze()
                    # Labels
                    q_label_id = q_label_id.detach().cpu()

                    # Calculating metrics
                    acc = fn.accuracy(pre_label_id, q_label_id).item()
                    recall = fn.recall(pre_label_id, q_label_id).item()
                    f1 = fn.f1_score(pre_label_id, q_label_id).item()

                    # Appending metrics
                    task_accs.append(acc)
                    task_f1.append(f1)
                    task_recall.append(recall)

                fast_model.to(torch.device('cpu'))

            del fast_model, inner_optimizer
            torch.cuda.empty_cache()

        print("\n")
        print("f1:", np.mean(task_f1))
        print("recall:", np.mean(task_recall))

        # Updating outer training parameters
        if training:
            # Mean of gradients
            for i in range(0, len(sum_gradients)):
                sum_gradients[i] = sum_gradients[i] / float(num_task)

            # Indexing parameters to model
            for i, params in enumerate(self.model.parameters()):
                params.grad = sum_gradients[i]

            # Updating parameters
            self.outer_optimizer.step()
            self.outer_optimizer.zero_grad()

            del sum_gradients
            gc.collect()
            torch.cuda.empty_cache()

        if valid_train:
            return np.mean(task_accs)
        else:
            return np.array(0)


# Creating Meta Tasks
class MetaTask(Dataset):
    def __init__(self, examples, num_task, k_support, k_query,
                 tokenizer, training=True, max_seq_length=512,
                 treat_text=None, **args):
        """
        :param examples: DataFrame of samples with text/label/domain columns
        :param num_task: number of training tasks
        :param k_support: number of support samples per class, per task
        :param k_query: number of query samples per class, per task
        """
        self.examples = examples

        self.num_task = num_task
        self.k_support = k_support
        self.k_query = k_query
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.treat_text = treat_text

        # Randomly generating tasks
        self.create_batch(self.num_task, training)

    # Creating batch
    def create_batch(self, num_task, training):
        self.supports = []  # support set
        self.queries = []  # query set
        self.task_names = []  # name of task
        self.supports_indexs = []  # index of supports
        self.queries_indexs = []  # index of queries
        self.num_task = num_task

        # Available tasks
        domains = self.examples['domain'].unique()

        # If not training, create all tasks
        if not(training):
            self.task_names = domains
            num_task = len(self.task_names)
            self.num_task = num_task

        for b in range(num_task):  # For each task,
            total_per_class = self.k_support + self.k_query
            task_size = 2 * self.k_support + 2 * self.k_query

            # Select a task at random
            if training:
                domain = random.choice(domains)
                self.task_names.append(domain)
            else:
                domain = self.task_names[b]

            # Task data
            domainExamples = self.examples[self.examples['domain'] == domain]

            # Minimal label quantity
            min_per_class = min(domainExamples['label'].value_counts())

            if total_per_class > min_per_class:
                total_per_class = min_per_class

            # Select k_support + k_query task examples
            # Sample (n) from each label (class)
            selected_examples = domainExamples.groupby("label").sample(total_per_class, replace=False)

            # Split data into support (training) and query (testing) sets
            s, q = train_test_split(selected_examples,
                                    stratify=selected_examples["label"],
                                    test_size=2 * self.k_query / task_size,
                                    shuffle=True)

            # Shuffling data
            s = s.sample(frac=1)
            q = q.sample(frac=1)

            # Appending indexes
            if not(training):
                self.supports_indexs.append(s.index)
                self.queries_indexs.append(q.index)

            # Creating list of support (training) and query (testing) tasks
            self.supports.append(s.to_dict('records'))
            self.queries.append(q.to_dict('records'))

    # Creating task tensors
    def create_feature_set(self, examples):
        all_input_ids = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
        all_attention_mask = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
        all_token_type_ids = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
        all_label_ids = torch.empty(len(examples), dtype=torch.long)

        for _id, e in enumerate(examples):
            all_input_ids[_id], all_attention_mask[_id], all_token_type_ids[_id], all_label_ids[_id] = self.encode_text(e)

        return TensorDataset(
            all_input_ids,
            all_attention_mask,
            all_token_type_ids,
            all_label_ids
        )

    # Data encoding
    def encode_text(self, example):
        comment_text = example["text"]

        if self.treat_text:
            comment_text = self.treat_text(comment_text)

        labels = LABEL_MAP[example["label"]]

        encoding = self.tokenizer.encode_plus(
            (comment_text, "It is a great text."),
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return tuple((
            encoding["input_ids"].flatten(),
            encoding["attention_mask"].flatten(),
            encoding["token_type_ids"].flatten(),
            torch.tensor([torch.tensor(labels).to(int)])
        ))

    # Returns data upon calling
    def __getitem__(self, index):
        support_set = self.create_feature_set(self.supports[index])
        query_set = self.create_feature_set(self.queries[index])
        name = self.task_names[index]
        return support_set, query_set, name

    def __len__(self):
        return self.num_task


class treat_text:
    def __init__(self, patterns):
        self.patterns = patterns

    def __call__(self, text):
        text = unicodedata.normalize("NFKD", str(text))
        text = multiple_replace(self.patterns, text.lower())
        text = re.sub(r'(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )', '', text)
        text = re.sub(r'( +)', ' ', text)
        text = re.sub(r'(, ,)|(,,)', ',', text)
        text = re.sub(r'(%)|(per cent)', ' percent', text)
        return text


# Regex multiple replace function
def multiple_replace(patterns_dict, text):

    # Building regex from dict keys
    regex = re.compile("(%s)" % "|".join(map(re.escape, patterns_dict.keys())))

    # Substitution
    return regex.sub(lambda mo: patterns_dict[mo.string[mo.start():mo.end()]], text)
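
A rough usage sketch of the classes above, assuming the SciBERT checkpoint named in Info.json; the wiring below is illustrative and not part of the committed files:

from transformers import AutoModel

# Text cleaner built from the regex patterns defined in this module
clean = treat_text(patterns)

# Backbone trimmed to the first four encoder layers, with the 200-d feature map
# and single-logit head on top (pos_weight matches Info.json)
backbone = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
model = SLR_Classifier(model=backbone, bert_layers=[0, 1, 2, 3], pos_weight=1.5)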
ML-SLRC/Util_funs.py
ADDED
@@ -0,0 +1,740 @@
from ML_SLRC import *

import os
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader
from torch.optim import Adam

import gc
from torchmetrics import functional as fn

import random

from tqdm import tqdm

from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc
import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import warnings
import torch

import time
from sklearn.manifold import TSNE
from copy import deepcopy
import seaborn as sns
import json
from pathlib import Path

import re
from collections import defaultdict

# SEED = 2222
# gen_seed = torch.Generator().manual_seed(SEED)


# Random seed function
def random_seed(value):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)
    np.random.seed(value)
    random.seed(value)

# Tasks for the meta-learner
def create_batch_of_tasks(taskset, is_shuffle=True, batch_size=4):
    idxs = list(range(0, len(taskset)))
    if is_shuffle:
        random.shuffle(idxs)
    for i in range(0, len(idxs), batch_size):
        yield [taskset[idxs[i]] for i in range(i, min(i + batch_size, len(taskset)))]


# Prepare data to be processed by the domain-learner
def prepare_data(data, batch_size, tokenizer, max_seq_length,
                 input='text', output='label',
                 train_size_per_class=5, global_datasets=False,
                 treat_text_fun=None):
    data = data.reset_index().drop("index", axis=1)

    if global_datasets:
        global data_train, data_test

    # Sample task for training
    data_train = data.groupby('label').sample(train_size_per_class, replace=False)
    idex = data.index.isin(data_train.index)

    # The test set to be labeled by the model
    data_test = data

    # Transform into datasets for the model
    ## Train
    dataset_train = SLR_DataSet(
        data=data_train.sample(frac=1),
        input=input,
        output=output,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        treat_text=treat_text_fun)

    ## Test
    dataset_test = SLR_DataSet(
        data=data_test,
        input=input,
        output=output,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        treat_text=treat_text_fun)

    # Dataloaders
    ## Train
    data_train_loader = DataLoader(dataset_train,
                                   shuffle=True,
                                   batch_size=batch_size['train']
                                   )

    ## Test
    if len(dataset_test) % batch_size['test'] == 1:
        data_test_loader = DataLoader(dataset_test,
                                      batch_size=batch_size['test'],
                                      drop_last=True)
    else:
        data_test_loader = DataLoader(dataset_test,
                                      batch_size=batch_size['test'],
                                      drop_last=False)

    return data_train_loader, data_test_loader, data_train, data_test


# Meta trainer
def meta_train(data, model, device, Info,
               print_epoch=True,
               Test_resource=None,
               treat_text_fun=None):

    # Meta-learner model
    learner = Learner(model=model, device=device, **Info)

    # Testing tasks
    if isinstance(Test_resource, pd.DataFrame):
        test = MetaTask(Test_resource, num_task=0, k_support=10, k_query=10,
                        training=False, treat_text=treat_text_fun, **Info)

    torch.clear_autocast_cache()
    gc.collect()
    torch.cuda.empty_cache()

    # Meta epoch (outer epoch)
    for epoch in tqdm(range(Info['meta_epoch']), desc="Meta epoch ", ncols=80):

        # Train tasks
        train = MetaTask(data,
                         num_task=Info['num_task_train'],
                         k_support=Info['k_qry'],
                         k_query=Info['k_spt'],
                         treat_text=treat_text_fun, **Info)

        # Batch of train tasks
        db = create_batch_of_tasks(train, is_shuffle=True, batch_size=Info["outer_batch_size"])

        if print_epoch:
            # Outer loop batch training
            for step, task_batch in enumerate(db):
                print("\n-----------------Training Mode", "Meta_epoch:", epoch, "-----------------\n")

                # Meta feed-forward (outer feed-forward)
                acc = learner(task_batch, valid_train=print_epoch)
                print('Step:', step, '\ttraining Acc:', acc)

                if isinstance(Test_resource, pd.DataFrame):
                    # Validating model
                    if ((epoch + 1) % 4) + step == 0:
                        random_seed(123)
                        print("\n-----------------Testing Mode-----------------\n")

                        # Batch of test tasks
                        db_test = create_batch_of_tasks(test, is_shuffle=False, batch_size=1)
                        acc_all_test = []

                        # Looping over testing tasks
                        for test_batch in db_test:
                            acc = learner(test_batch, training=False)
                            acc_all_test.append(acc)

                        print('Test acc:', np.mean(acc_all_test))
                        del acc_all_test, db_test

                        # Restarting training randomly
                        random_seed(int(time.time() % 10))

        else:
            for step, task_batch in enumerate(db):
                # Meta feed-forward (outer feed-forward)
                acc = learner(task_batch, print_epoch, valid_train=print_epoch)

    torch.clear_autocast_cache()
    gc.collect()
    torch.cuda.empty_cache()


def train_loop(data_train_loader, data_test_loader, model, device, epoch=4, lr=1, print_info=True, name='name', weight_decay=1):
    # Start from a copy of the model's parameters
    model_meta = deepcopy(model)
    optimizer = Adam(model_meta.parameters(), lr=lr, weight_decay=weight_decay)

    model_meta.to(device)
    model_meta.train()

    # Task epoch (inner epoch)
    for i in range(0, epoch):
        all_loss = []

        # Inner training batch (support set)
        for inner_step, batch in enumerate(data_train_loader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, attention_mask, q_token_type_ids, label_id = batch

            # Inner feed-forward
            loss, _, _ = model_meta(input_ids, attention_mask, q_token_type_ids, labels=label_id.squeeze())

            # Compute grads
            loss.backward()

            # Update parameters
            optimizer.step()
            optimizer.zero_grad()

            all_loss.append(loss.item())

        if (i % 2 == 0) & print_info:
            print("Loss: ", np.mean(all_loss))

    # Test evaluation
    model_meta.eval()
    all_loss = []
    all_acc = []
    features = []
    labels = []
    predi_logit = []

    with torch.no_grad():
        # Test batch loop
        for inner_step, batch in enumerate(tqdm(data_test_loader,
                                                desc="Test validation | " + name,
                                                ncols=80)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, attention_mask, q_token_type_ids, label_id = batch

            # Predictions
            _, feature, _ = model_meta(input_ids, attention_mask, q_token_type_ids, labels=label_id.squeeze())

            # Collect logits, hidden features and labels so the caller can unpack
            # (logits, X_embedded, labels, features), as diagnosis and
            # pipeline_simulation expect
            logit = feature[1].detach().cpu()
            feature_lat = feature[0].detach().cpu()
            label_id = label_id.detach().cpu()

            labels.append(label_id.numpy().squeeze())
            features.append(feature_lat.numpy())
            predi_logit.append(logit.numpy())

            del input_ids, attention_mask, label_id, batch

        if print_info and len(all_acc) > 0:
            print("acc:", np.mean(all_acc))

    model_meta.to('cpu')
    gc.collect()
    torch.cuda.empty_cache()

    del model_meta, optimizer

    # Project features with t-SNE and return (logits, X_embedded, labels, features)
    return map_feature_tsne(features, labels, predi_logit)

# Process predictions and map the feature space with t-SNE
def map_feature_tsne(features, labels, predi_logit):

    features = np.concatenate(np.array(features, dtype=object))
    features = torch.tensor(features.astype(np.float32)).detach().clone()

    labels = np.concatenate(np.array(labels, dtype=object))
    labels = torch.tensor(labels.astype(int)).detach().clone()

    logits = np.concatenate(np.array(predi_logit, dtype=object))
    logits = torch.tensor(logits.astype(np.float32)).detach().clone()

    # Dimension reduction
    X_embedded = TSNE(n_components=2, learning_rate='auto',
                      init='random').fit_transform(features.detach().clone())

    return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()

def wss_calc(logit, labels, trsh=0.5):

    # Predicted label given the threshold
    predict_trash = torch.sigmoid(logit).squeeze() >= trsh

    # Compute confusion matrix values
    CM = confusion_matrix(labels, predict_trash.to(int))
    tn, fp, fne, tp = CM.ravel()

    P = (tp + fne)
    N = (tn + fp)
    recall = tp / (tp + fne)

    # WSS (work saved over sampling)
    wss = (tn + fne) / len(labels) - (1 - recall)

    # AWSS (adjusted WSS)
    awss = (tn / N - fne / P)

    return {
        "wss": round(wss, 4),
        "awss": round(awss, 4),
        "R": round(recall, 4),
        "CM": CM
    }


# Compute the metrics
def plot(logits, X_embedded, labels, threshold, show=True,
         namefig="plot", make_plot=True, print_stats=True, save=True):
    col = pd.MultiIndex.from_tuples([
        ("Predict", "0"),
        ("Predict", "1")
    ])
    index = pd.MultiIndex.from_tuples([
        ("Real", "0"),
        ("Real", "1")
    ])

    predict = torch.sigmoid(logits).detach().clone()

    # ROC curve
    fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())

    # Given by a recall of 95% (threshold evaluation)
    ## WSS
    ### Index to recall
    idx_wss95 = sum(tpr < 0.95)
    ### Threshold
    thresholds95 = thresholds[idx_wss95]

    ### Compute the metrics
    wss95_info = wss_calc(logits, labels, thresholds95)
    acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
    f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)

    # Given by a fixed threshold (recall evaluation)
    ### Compute the metrics
    wss_info = wss_calc(logits, labels, threshold)
    acc_wssR = fn.accuracy(predict, labels, threshold=threshold)
    f1_wssR = fn.f1_score(predict, labels, threshold=threshold)

    metrics = {
        # WSS
        "WSS@95": wss95_info['wss'],
        "AWSS@95": wss95_info['awss'],
        "WSS@R": wss_info['wss'],
        "AWSS@R": wss_info['awss'],
        # Recall
        "Recall_WSS@95": wss95_info['R'],
        "Recall_WSS@R": wss_info['R'],
        # Accuracy
        "acc@95": acc_wss95.item(),
        "acc@R": acc_wssR.item(),
        # F1
        "f1@95": f1_wss95.item(),
        "f1@R": f1_wssR.item(),
        # Threshold at 95% recall
        "threshold@95": thresholds95
    }

    # Print stats
    if print_stats:
        wss95 = f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
        wss95_adj = f"AWSS@95:{wss95_info['awss']}"
        print(wss95)
        print(wss95_adj)
        print('Acc.:', round(acc_wss95.item(), 4))
        print('F1-score:', round(f1_wss95.item(), 4))
        print(f"threshold to wss95: {round(thresholds95, 4)}")
        cm = pd.DataFrame(wss95_info['CM'],
                          index=index,
                          columns=col)

        print("\nConfusion matrix:")
        print(cm)
        print("\n---Metrics with threshold:", threshold, "----\n")
        wss = f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
        print(wss)
        wss_adj = f"AWSS@R:{wss_info['awss']}"
        print(wss_adj)
        print('Acc.:', round(acc_wssR.item(), 4))
        print('F1-score:', round(f1_wssR.item(), 4))
        cm = pd.DataFrame(wss_info['CM'],
                          index=index,
                          columns=col)

        print("\nConfusion matrix:")
        print(cm)

    # Plots
    if make_plot:

        fig, axes = plt.subplots(1, 4, figsize=(25, 10))
        alpha = torch.squeeze(predict).numpy()

        # t-SNE
        p1 = sns.scatterplot(x=X_embedded[:, 0],
                             y=X_embedded[:, 1],
                             hue=labels,
                             alpha=alpha, ax=axes[0]).set_title('Predictions-TSNE', size=20)

        # WSS@95
        t_wss = predict >= thresholds95
        t_wss = t_wss.squeeze().numpy()
        p2 = sns.scatterplot(x=X_embedded[t_wss, 0],
                             y=X_embedded[t_wss, 1],
                             hue=labels[t_wss],
                             alpha=alpha[t_wss], ax=axes[1]).set_title('WSS@95', size=20)

        # WSS@R
        t = predict >= threshold
        t = t.squeeze().numpy()
        p3 = sns.scatterplot(x=X_embedded[t, 0],
                             y=X_embedded[t, 1],
                             hue=labels[t],
                             alpha=alpha[t], ax=axes[2]).set_title(f'Predictions-threshold {threshold}', size=20)

        # ROC curve
        roc_auc = auc(fpr, tpr)
        lw = 2
        axes[3].plot(
            fpr,
            tpr,
            color="darkorange",
            lw=lw,
            label="ROC curve (area = %0.2f)" % roc_auc)
        axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
        axes[3].axhline(y=0.95, color='r', linestyle='-')
        axes[3].legend(loc="lower right")
        axes[3].set_title(label="ROC", size=20)
        axes[3].set_ylabel("True Positive Rate", fontsize=15)
        axes[3].set_xlabel("False Positive Rate", fontsize=15)

    if show:
        plt.show()

    if save:
        fig.savefig(namefig, dpi=fig.dpi)

    return metrics


def auc_plot(logits, labels, color="darkorange", label="test"):
    predict = torch.sigmoid(logits).detach().clone()
    fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
    roc_auc = auc(fpr, tpr)
    lw = 2

    label = label + str(round(roc_auc, 2))

    plt.plot(
        fpr,
        tpr,
        color=color,
        lw=lw,
        label=label
    )
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.axhline(y=0.95, color='r', linestyle='-')

# Interactive interface for evaluation
class diagnosis():
    def __init__(self, names, Valid_resource, batch_size_test,
                 model, Info, device, treat_text_fun=None, start=0):
        self.names = names
        self.Valid_resource = Valid_resource
        self.batch_size_test = batch_size_test
        self.model = model
        self.start = start
        self.Info = Info
        self.device = device
        self.treat_text_fun = treat_text_fun

        # Box inputs
        self.value_trash = widgets.FloatText(
            value=0.95,
            description='threshold',
            disabled=False
        )
        self.valueb = widgets.IntText(
            value=10,
            description='size',
            disabled=False
        )

        # Buttons
        self.train_b = widgets.Button(description="Train")
        self.next_b = widgets.Button(description="Next")
        self.eval_b = widgets.Button(description="Evaluation")

        self.hbox = widgets.HBox([self.train_b, self.valueb])

        # Button click callbacks
        self.next_b.on_click(self.Next_button)
        self.train_b.on_click(self.Train_button)
        self.eval_b.on_click(self.Evaluation_button)

    # Next button
    def Next_button(self, p):
        clear_output()
        self.i = self.i + 1

        # Select the domain data
        self.domain = self.names[self.i]
        self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]

        print("Name:", self.domain)
        print(self.data['label'].value_counts())
        display(self.hbox)
        display(self.next_b)

    # Train button
    def Train_button(self, y):
        clear_output()
        print(self.domain)

        # Prepare data for training (domain-learner)
        self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(
            self.data,
            train_size_per_class=self.valueb.value,
            batch_size={'train': self.Info['inner_batch_size'],
                        'test': self.batch_size_test},
            max_seq_length=self.Info['max_seq_length'],
            tokenizer=self.Info['tokenizer'],
            input="text",
            output="label",
            treat_text_fun=self.treat_text_fun)

        # Train the model and predict on the test set
        self.logits, self.X_embedded, self.labels, self.features = train_loop(
            self.data_train_loader, self.data_test_loader,
            self.model, self.device,
            epoch=self.Info['inner_update_step'],
            lr=self.Info['inner_update_lr'],
            print_info=True,
            name=self.domain)

        tresh_box = widgets.HBox([self.eval_b, self.value_trash])
        display(self.hbox)
        display(tresh_box)
        display(self.next_b)

    # Evaluation button
    def Evaluation_button(self, te):
        clear_output()
        tresh_box = widgets.HBox([self.eval_b, self.value_trash])

        print(self.domain)
        print("-------Train data-------")
        print(self.data_train['label'].value_counts())
        print("-------Test data-------")
        print(self.data_test['label'].value_counts())

        display(self.next_b)
        display(tresh_box)
        display(self.hbox)

        # Compute metrics
        metrics = plot(self.logits, self.X_embedded, self.labels,
                       threshold=self.Info['threshold'], show=True,
                       namefig='test',
                       make_plot=True,
                       print_stats=True,
                       save=False)

    def __call__(self):
        self.i = self.start - 1
        clear_output()
        display(self.next_b)


# Simulation attempts of the domain learner
def pipeline_simulation(Valid_resource, names_to_valid, path_save,
                        model, Info, device, initializer_model,
                        treat_text_fun=None):
    n_attempt = 5
    batch_test = 100

    # Create a directory to save information
    for name in names_to_valid:
        name = re.sub(r"\.csv", "", name)
        Path(path_save + name + "/img").mkdir(parents=True, exist_ok=True)

    # Dict to save ROC curves
    roc_stats = defaultdict(lambda: defaultdict(
        lambda: defaultdict(
            list
        )
    )
    )

    all_metrics = []
    # Loop over a list of domains
    for name in names_to_valid:

        # Select a domain dataset
        data = Valid_resource[Valid_resource['domain'] == name].reset_index().drop("index", axis=1)

        # Attempts simulation
        for attempt in range(n_attempt):
            print("---" * 4, "attempt", attempt, "---" * 4)

            # Prepare data to pass to the model
            data_train_loader, data_test_loader, _, _ = prepare_data(
                data,
                train_size_per_class=Info['k_spt'],
                batch_size={'train': Info['inner_batch_size'],
                            'test': batch_test},
                max_seq_length=Info['max_seq_length'],
                tokenizer=Info['tokenizer'],
                input="text",
                output="label",
                treat_text_fun=treat_text_fun)

            # Train the model and evaluate on the test set of the domain
            logits, X_embedded, labels, features = train_loop(
                data_train_loader, data_test_loader,
                model, device,
                epoch=Info['inner_update_step'],
                lr=Info['inner_update_lr'],
                print_info=False,
                name=name)

            name_domain = re.sub(r"\.csv", "", name)

            # Compute the metrics
            metrics = plot(logits, X_embedded, labels,
                           threshold=Info['threshold'], show=False,
                           namefig=path_save + name_domain + "/img/" + str(attempt) + 'plots',
                           make_plot=True, print_stats=False, save=True)

            # Compute the ROC curve
            fpr, tpr, _ = roc_curve(labels, torch.sigmoid(logits).squeeze())

            # Save the corresponding information of the domain
            metrics['name'] = name_domain
            metrics['layer_size'] = Info['bert_layers']
            metrics['attempt'] = attempt
            roc_stats[name_domain][str(Info['bert_layers'])]['fpr'].append(fpr.tolist())
            roc_stats[name_domain][str(Info['bert_layers'])]['tpr'].append(tpr.tolist())
            all_metrics.append(metrics)

            # Save the metrics and the ROC curve of the attempt
            pd.DataFrame(all_metrics).to_csv(path_save + "metrics.csv")
            roc_path = path_save + "roc_stats.json"
            with open(roc_path, 'w') as fp:
                json.dump(roc_stats, fp)

            del fpr, tpr, logits, X_embedded, labels
            del features, metrics, _

    # Save the information used to evaluate the validation resource
    save_info = Info.copy()
    save_info['model'] = initializer_model.tokenizer.name_or_path
    save_info.pop("tokenizer")
    save_info.pop("bert_layers")

    info_path = path_save + "info.json"
    with open(info_path, 'w') as fp:
        json.dump(save_info, fp)


# Loading dataset statistics
def load_data_statistics(paths, names):
    size = []
    pos = []
    neg = []
    for p in paths:
        data = pd.read_csv(p)
        data = data.dropna()
        # Dataset size
        size.append(len(data))
        # Number of positive labels
        pos.append(data['labels'].value_counts()[1])
        # Number of negative labels
        neg.append(data['labels'].value_counts()[0])
        del data

    info_load = pd.DataFrame({
        "size": size,
        "pos": pos,
        "neg": neg,
        "names": names,
        "paths": paths})
    return info_load

# Loading the datasets
def load_data(train_info_load):

    col = ['abstract', 'title', 'labels', 'domain']

    data_train = pd.DataFrame(columns=col)
    for p in train_info_load['paths']:
        data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
        data_temp['domain'] = os.path.basename(p)
        data_train = pd.concat([data_train, data_temp])

    data_train['text'] = data_train['title'] + data_train['abstract'].replace(np.nan, '')

    return (data_train
            .replace({"labels": {0: "negative", 1: 'positive'}})
            .rename({"labels": "label"}, axis=1)
            .loc[:, ("text", "domain", "label")]
            )
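
A hedged sketch of how these helpers might be wired together for meta-training; `data` is assumed to be the DataFrame produced by load_data() (text/domain/label columns), the checkpoint name comes from Info.json, and the small epoch/task counts are only for a quick illustrative run:

import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "allenai/scibert_scivocab_uncased"
Info = {
    "meta_epoch": 2, "num_task_train": 4, "k_spt": 8, "k_qry": 8,
    "outer_batch_size": 2, "inner_batch_size": 4,
    "outer_update_lr": 5e-5, "inner_update_lr": 5e-5,
    "inner_update_step": 4, "inner_update_step_eval": 4,
    "inner_print": 2, "max_seq_length": 512,
    "tokenizer": AutoTokenizer.from_pretrained(checkpoint),
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SLR_Classifier(model=AutoModel.from_pretrained(checkpoint), bert_layers=[0, 1, 2, 3])

# Outer loop: samples MetaTask batches from `data` and updates the shared weights
meta_train(data, model, device, Info, print_epoch=True)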
ML-SLRC/__init__.py
ADDED
@@ -0,0 +1,4 @@
from . import Util_funs
ML-SLRC/model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a859f39dc8ff55df919ef6794dcfc3ca08f873ae11fd0fd78c50d65089a6f019
size 213540902
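
model.pt is tracked through Git LFS (about 214 MB). How it deserializes depends on how it was saved, which this commit does not show; a cautious inspection sketch, assuming the LFS object has been fetched:

import torch

# The pointer file above resolves to a torch checkpoint once LFS fetches it.
obj = torch.load("ML-SLRC/model.pt", map_location="cpu")
print(type(obj))  # a state_dict (dict) or a pickled nn.Module, depending on how it was saved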
app.py
CHANGED
@@ -174,7 +174,6 @@ def treat_data_input(data, etailment_txt):
 
 import gc
 from torch.optim import Adam
-from scipy.stats import entropy
 
 def treat_train_evaluate(dataload_train, dataload_remain):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')