BecomeAllan committed
Commit 6b71499
1 Parent(s): 4a0a4f7
.vscode/settings.json ADDED
@@ -0,0 +1,7 @@
+ {
+     "workbench.colorCustomizations": {
+         "activityBar.background": "#590F35",
+         "titleBar.activeBackground": "#7C154B",
+         "titleBar.activeForeground": "#FEFCFD"
+     }
+ }
ML-SLRC/Info.json ADDED
@@ -0,0 +1 @@
+ {"inner_print": 2, "bert_layers": 4, "max_seq_length": 512, "meta_epoch": 20, "k_spt": 8, "k_qry": 8, "outer_batch_size": 5, "inner_batch_size": 4, "outer_update_lr": 5e-05, "inner_update_lr": 5e-05, "inner_update_step": 4, "inner_update_step_eval": 4, "num_task_train": 20, "pos_weight": 1.5, "tresh": 0.9, "model": "allenai/scibert_scivocab_uncased"}
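Note: these keys mirror the constructor arguments used throughout ML_SLRC.py and Util_funs.py (inner/outer learning rates, support/query sizes k_spt/k_qry, the base encoder, etc.). A minimal sketch of consuming the file (the path is an assumption about where the repo is checked out):

    # Minimal sketch: read the shipped hyperparameters (file path assumed).
    import json

    with open("ML-SLRC/Info.json") as f:
        Info = json.load(f)

    print(Info["model"])        # base encoder: allenai/scibert_scivocab_uncased
    print(Info["bert_layers"])  # 4 -> the code keeps this many encoder layers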
ML-SLRC/ML_SLRC.py ADDED
@@ -0,0 +1,574 @@
+ from torch import nn
+ import torch
+ import numpy as np
+ from copy import deepcopy
+ import re
+ import unicodedata
+ from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler
+ from sklearn.model_selection import train_test_split
+ from torch.optim import Adam
+ import gc
+ from torchmetrics import functional as fn
+ import random
+
+
+ # Pre-trained model
+ class Encoder(nn.Module):
+     def __init__(self, layers, freeze_bert, model):
+         super(Encoder, self).__init__()
+
+         # Dummy parameter (handy for querying the module's device)
+         self.dummy_param = nn.Parameter(torch.empty(0))
+
+         # Pre-trained model
+         self.model = deepcopy(model)
+
+         # Freezing BERT parameters
+         if freeze_bert:
+             for param in self.model.parameters():
+                 param.requires_grad = False
+
+         # Selecting hidden layers of the pre-trained model
+         old_model_encoder = self.model.encoder.layer
+         new_model_encoder = nn.ModuleList()
+
+         for i in layers:
+             new_model_encoder.append(old_model_encoder[i])
+
+         self.model.encoder.layer = new_model_encoder
+
+     # Feed forward
+     def forward(self, **x):
+         return self.model(**x)['pooler_output']
+
+ # Complete model
+ class SLR_Classifier(nn.Module):
+     def __init__(self, **data):
+         super(SLR_Classifier, self).__init__()
+
+         # Dummy parameter (handy for querying the module's device)
+         self.dummy_param = nn.Parameter(torch.empty(0))
+
+         # Loss function: binary cross-entropy with logits, reduced to the mean
+         self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean',
+                                             pos_weight=torch.FloatTensor([data.get("pos_weight", 2.5)]))
+
+         # Pre-trained model
+         self.Encoder = Encoder(layers=data.get("bert_layers", range(12)),
+                                freeze_bert=data.get("freeze_bert", False),
+                                model=data.get("model"),
+                                )
+
+         # Feature-map layer
+         self.feature_map = nn.Sequential(
+             nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
+             nn.Linear(self.Encoder.model.config.hidden_size, 200),
+             nn.Dropout(data.get("drop", 0.5)),
+         )
+
+         # Classifier layer
+         self.classifier = nn.Sequential(
+             nn.Tanh(),
+             nn.Linear(200, 1)
+         )
+
+         # Initializing the linear layer's parameters
+         nn.init.normal_(self.feature_map[1].weight, mean=0, std=0.00001)
+         nn.init.zeros_(self.feature_map[1].bias)
+
+     # Feed forward
+     def forward(self, input_ids, attention_mask, token_type_ids, labels):
+         predict = self.Encoder(**{"input_ids": input_ids,
+                                   "attention_mask": attention_mask,
+                                   "token_type_ids": token_type_ids})
+         feature = self.feature_map(predict)
+         logit = self.classifier(feature)
+
+         predict = torch.sigmoid(logit)
+
+         # Loss function
+         loss = self.loss_fn(logit.to(torch.float), labels.to(torch.float).unsqueeze(1))
+
+         return [loss, [feature, logit], predict]
+
+ # Undesirable patterns within texts
+ patterns = {
+     'CONCLUSIONS AND IMPLICATIONS': '',
+     'BACKGROUND AND PURPOSE': '',
+     'EXPERIMENTAL APPROACH': '',
+     'KEY RESULTS AEA': '',
+     '©': '',
+     '®': '',
+     'μ': '',
+     '(C)': '',
+     'OBJECTIVE:': '',
+     'MATERIALS AND METHODS:': '',
+     'SIGNIFICANCE:': '',
+     'BACKGROUND:': '',
+     'RESULTS:': '',
+     'METHODS:': '',
+     'CONCLUSIONS:': '',
+     'AIM:': '',
+     'STUDY DESIGN:': '',
+     'CLINICAL RELEVANCE:': '',
+     'CONCLUSION:': '',
+     'HYPOTHESIS:': '',
+     'Questions/Purposes:': '',
+     'Introduction:': '',
+     'PURPOSE:': '',
+     'PATIENTS AND METHODS:': '',
+     'FINDINGS:': '',
+     'INTERPRETATIONS:': '',
+     'FUNDING:': '',
+     'PROGRESS:': '',
+     'CONTEXT:': '',
+     'MEASURES:': '',
+     'DESIGN:': '',
+     'BACKGROUND AND OBJECTIVES:': '',
+     '<p>': '',
+     '</p>': '',
+     '<<ETX>>': '',
+     '+/-': '',
+     r'\(.+\)': '',
+     r'\[.+\]': '',
+     r' \d ': '',
+     '<': '',
+     '>': '',
+     '- ': '',
+     ' +': ' ',
+     ', ,': ',',
+     ',,': ',',
+     '%': ' percent',
+     'per cent': ' percent'
+ }
+
+ patterns = {x.lower(): y for x, y in patterns.items()}
+
+
+ LABEL_MAP = {'negative': 0,
+              'not included': 0,
+              '0': 0,
+              0: 0,
+              'excluded': 0,
+              'positive': 1,
+              'included': 1,
+              '1': 1,
+              1: 1,
+              }
+
+ class SLR_DataSet(Dataset):
+     def __init__(self, treat_text=None, **args):
+         self.tokenizer = args.get('tokenizer')
+         self.data = args.get('data')
+         self.max_seq_length = args.get("max_seq_length", 512)
+         self.INPUT_NAME = args.get("input", 'x')
+         self.LABEL_NAME = args.get("output", 'y')
+         self.treat_text = treat_text
+
+     # Tokenizing and processing text
+     def encode_text(self, example):
+         comment_text = example[self.INPUT_NAME]
+         if self.treat_text:
+             comment_text = self.treat_text(comment_text)
+
+         try:
+             labels = LABEL_MAP[example[self.LABEL_NAME].lower()]
+         except (KeyError, AttributeError):
+             labels = -1
+
+         encoding = self.tokenizer.encode_plus(
+             (comment_text, "It is great text"),
+             add_special_tokens=True,
+             max_length=self.max_seq_length,
+             return_token_type_ids=True,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors='pt',
+         )
+
+         return tuple((
+             encoding["input_ids"].flatten(),
+             encoding["attention_mask"].flatten(),
+             encoding["token_type_ids"].flatten(),
+             torch.tensor([torch.tensor(labels).to(int)])
+         ))
+
+     def __len__(self):
+         return len(self.data)
+
+     # Returning data
+     def __getitem__(self, index: int):
+         data_row = self.data.reset_index().iloc[index]
+         temp_data = self.encode_text(data_row)
+         return temp_data
+
+
+ class Learner(nn.Module):
+
+     def __init__(self, **args):
+         """
+         :param args: hyperparameters for the inner and outer loops
+         """
+         super(Learner, self).__init__()
+
+         self.inner_print = args.get('inner_print')
+         self.inner_batch_size = args.get('inner_batch_size')
+         self.outer_update_lr = args.get('outer_update_lr')
+         self.inner_update_lr = args.get('inner_update_lr')
+         self.inner_update_step = args.get('inner_update_step')
+         self.inner_update_step_eval = args.get('inner_update_step_eval')
+         self.model = args.get('model')
+         self.device = args.get('device')
+
+         # Outer optimizer
+         self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
+         self.model.train()
+
+     def forward(self, batch_tasks, training=True, valid_train=True):
+         """
+         batch = [(support TensorDataset, query TensorDataset),
+                  (support TensorDataset, query TensorDataset),
+                  (support TensorDataset, query TensorDataset),
+                  (support TensorDataset, query TensorDataset)]
+
+         # support = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)
+         """
+         task_accs = []
+         task_f1 = []
+         task_recall = []
+         sum_gradients = []
+         num_task = len(batch_tasks)
+         num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval
+
+         # Outer loop over tasks
+         for task_id, task in enumerate(batch_tasks):
+             support = task[0]
+             query = task[1]
+             name = task[2]
+
+             # Copying the model
+             fast_model = deepcopy(self.model)
+             fast_model.to(self.device)
+
+             # Inner optimizer
+             inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)
+
+             # Training data loader (drop the last batch only when it would hold a
+             # single example, which BatchNorm cannot handle)
+             support_dataloader = DataLoader(support, sampler=RandomSampler(support),
+                                             batch_size=self.inner_batch_size,
+                                             drop_last=(len(support) % self.inner_batch_size == 1))
+
+             fast_model.train()
+
+             # Inner training epochs (support set)
+             if valid_train:
+                 print('----Task', task_id, ":", name, '----')
+
+             for i in range(0, num_inner_update_step):
+                 all_loss = []
+
+                 # Inner training batches (support set)
+                 for inner_step, batch in enumerate(support_dataloader):
+                     batch = tuple(t.to(self.device) for t in batch)
+                     input_ids, attention_mask, token_type_ids, label_id = batch
+
+                     # Feed forward
+                     loss, _, _ = fast_model(input_ids, attention_mask, token_type_ids=token_type_ids, labels=label_id)
+
+                     # Computing gradients
+                     loss.backward()
+
+                     # Updating the inner (task-specific) parameters
+                     inner_optimizer.step()
+                     inner_optimizer.zero_grad()
+
+                     # Appending losses
+                     all_loss.append(loss.item())
+
+                     del batch, input_ids, attention_mask, label_id
+                     torch.cuda.empty_cache()
+
+                 if valid_train:
+                     if (i + 1) % self.inner_print == 0:
+                         print("Inner Loss: ", np.mean(all_loss))
+
+             fast_model.to(torch.device('cpu'))
+
+             # Meta-gradients from the inner training phase
+             if training:
+                 meta_weights = list(self.model.parameters())
+                 fast_weights = list(fast_model.parameters())
+
+                 # Accumulating gradients across tasks
+                 for i, (meta_params, fast_params) in enumerate(zip(meta_weights, fast_weights)):
+                     gradient = meta_params - fast_params
+                     if task_id == 0:
+                         sum_gradients.append(gradient)
+                     else:
+                         sum_gradients[i] += gradient
+
+             # Inner test (query set)
+             if valid_train:
+                 fast_model.to(self.device)
+                 fast_model.eval()
+
+                 with torch.no_grad():
+                     # Data loader
+                     query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
+                     query_batch = next(iter(query_dataloader))
+                     query_batch = tuple(t.to(self.device) for t in query_batch)
+                     q_input_ids, q_attention_mask, q_token_type_ids, q_label_id = query_batch
+
+                     # Feed forward
+                     _, _, pre_label_id = fast_model(q_input_ids, q_attention_mask, q_token_type_ids, labels=q_label_id)
+
+                     # Predictions
+                     pre_label_id = pre_label_id.detach().cpu().squeeze()
+                     # Labels
+                     q_label_id = q_label_id.detach().cpu()
+
+                     # Calculating metrics
+                     acc = fn.accuracy(pre_label_id, q_label_id).item()
+                     recall = fn.recall(pre_label_id, q_label_id).item()
+                     f1 = fn.f1_score(pre_label_id, q_label_id).item()
+
+                     # Appending metrics
+                     task_accs.append(acc)
+                     task_f1.append(f1)
+                     task_recall.append(recall)
+
+                 fast_model.to(torch.device('cpu'))
+
+             del fast_model, inner_optimizer
+             torch.cuda.empty_cache()
+
+         if valid_train:
+             print("\n")
+             print("f1:", np.mean(task_f1))
+             print("recall:", np.mean(task_recall))
+
+         # Updating the outer (meta) parameters
+         if training:
+             # Mean of gradients
+             for i in range(0, len(sum_gradients)):
+                 sum_gradients[i] = sum_gradients[i] / float(num_task)
+
+             # Assigning the averaged gradients to the meta-model
+             for i, params in enumerate(self.model.parameters()):
+                 params.grad = sum_gradients[i]
+
+             # Updating parameters
+             self.outer_optimizer.step()
+             self.outer_optimizer.zero_grad()
+
+             del sum_gradients
+             gc.collect()
+             torch.cuda.empty_cache()
+
+         if valid_train:
+             return np.mean(task_accs)
+         else:
+             return np.array(0)
+
+
+ # Creating meta tasks
+ class MetaTask(Dataset):
+     def __init__(self, examples, num_task, k_support, k_query,
+                  tokenizer, training=True, max_seq_length=512,
+                  treat_text=None, **args):
+         """
+         :param examples: list of samples
+         :param num_task: number of training tasks
+         :param k_support: number of support samples per class per task
+         :param k_query: number of query samples per class per task
+         """
+         self.examples = examples
+
+         self.num_task = num_task
+         self.k_support = k_support
+         self.k_query = k_query
+         self.tokenizer = tokenizer
+         self.max_seq_length = max_seq_length
+         self.treat_text = treat_text
+
+         # Randomly generating tasks
+         self.create_batch(self.num_task, training)
+
+     # Creating batch
+     def create_batch(self, num_task, training):
+         self.supports = []  # support set
+         self.queries = []  # query set
+         self.task_names = []  # name of each task
+         self.supports_indexs = []  # indexes of supports
+         self.queries_indexs = []  # indexes of queries
+         self.num_task = num_task
+
+         # Available tasks
+         domains = self.examples['domain'].unique()
+
+         # If not training, create all tasks
+         if not training:
+             self.task_names = domains
+             num_task = len(self.task_names)
+             self.num_task = num_task
+
+         for b in range(num_task):  # For each task,
+             total_per_class = self.k_support + self.k_query
+             task_size = 2 * self.k_support + 2 * self.k_query
+
+             # Select a task at random
+             if training:
+                 domain = random.choice(domains)
+                 self.task_names.append(domain)
+             else:
+                 domain = self.task_names[b]
+
+             # Task data
+             domainExamples = self.examples[self.examples['domain'] == domain]
+
+             # Minimum number of examples over the labels
+             min_per_class = min(domainExamples['label'].value_counts())
+
+             if total_per_class > min_per_class:
+                 total_per_class = min_per_class
+
+             # Select k_support + k_query task examples,
+             # sampling (n) from each label (class)
+             selected_examples = domainExamples.groupby("label").sample(total_per_class, replace=False)
+
+             # Split the data into support (training) and query (testing) sets
+             s, q = train_test_split(selected_examples,
+                                     stratify=selected_examples["label"],
+                                     test_size=2 * self.k_query / task_size,
+                                     shuffle=True)
+
+             # Permuting the data
+             s = s.sample(frac=1)
+             q = q.sample(frac=1)
+
+             # Appending indexes
+             if not training:
+                 self.supports_indexs.append(s.index)
+                 self.queries_indexs.append(q.index)
+
+             # Creating lists of support (training) and query (testing) tasks
+             self.supports.append(s.to_dict('records'))
+             self.queries.append(q.to_dict('records'))
+
+     # Creating task tensors
+     def create_feature_set(self, examples):
+         all_input_ids = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
+         all_attention_mask = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
+         all_token_type_ids = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
+         all_label_ids = torch.empty(len(examples), dtype=torch.long)
+
+         for _id, e in enumerate(examples):
+             all_input_ids[_id], all_attention_mask[_id], all_token_type_ids[_id], all_label_ids[_id] = self.encode_text(e)
+
+         return TensorDataset(
+             all_input_ids,
+             all_attention_mask,
+             all_token_type_ids,
+             all_label_ids
+         )
+
+     # Data encoding
+     def encode_text(self, example):
+         comment_text = example["text"]
+
+         if self.treat_text:
+             comment_text = self.treat_text(comment_text)
+
+         labels = LABEL_MAP[example["label"]]
+
+         encoding = self.tokenizer.encode_plus(
+             (comment_text, "It is a great text."),
+             add_special_tokens=True,
+             max_length=self.max_seq_length,
+             return_token_type_ids=True,
+             padding="max_length",
+             truncation=True,
+             return_attention_mask=True,
+             return_tensors='pt',
+         )
+
+         return tuple((
+             encoding["input_ids"].flatten(),
+             encoding["attention_mask"].flatten(),
+             encoding["token_type_ids"].flatten(),
+             torch.tensor([torch.tensor(labels).to(int)])
+         ))
+
+     # Returns a task's data upon indexing
+     def __getitem__(self, index):
+         support_set = self.create_feature_set(self.supports[index])
+         query_set = self.create_feature_set(self.queries[index])
+         name = self.task_names[index]
+         return support_set, query_set, name
+
+     def __len__(self):
+         return self.num_task
+
+
+ class treat_text:
+     def __init__(self, patterns):
+         self.patterns = patterns
+
+     def __call__(self, text):
+         text = unicodedata.normalize("NFKD", str(text))
+         text = multiple_replace(self.patterns, text.lower())
+         text = re.sub(r'(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )', '', text)
+         text = re.sub(r'( +)', ' ', text)
+         text = re.sub(r'(, ,)|(,,)', ',', text)
+         text = re.sub(r'(%)|(per cent)', ' percent', text)
+         return text
+
+
+ # Regex multiple-replace function
+ def multiple_replace(mapping, text):
+     # Building one regex from the dict keys
+     regex = re.compile("(%s)" % "|".join(map(re.escape, mapping.keys())))
+
+     # Substitution: look each match up in the dict
+     return regex.sub(lambda mo: mapping[mo.group(0)], text)
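Note: a minimal, hypothetical sketch of how the classes above fit together. The toy dataframe, the column names 'x'/'y', and the SciBERT checkpoint (taken from Info.json) are illustrative assumptions, not part of this commit:

    # Minimal usage sketch, assuming the `transformers` package is installed.
    from transformers import AutoTokenizer, AutoModel
    import pandas as pd

    tok = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
    base = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")

    clf = SLR_Classifier(model=base, bert_layers=range(4), pos_weight=1.5)
    cleaner = treat_text(patterns)

    # Toy data: 'x' holds the text, 'y' the screening label (hypothetical rows).
    df = pd.DataFrame({"x": ["BACKGROUND: a tiny abstract", "RESULTS: another one"],
                       "y": ["included", "excluded"]})
    ds = SLR_DataSet(data=df, tokenizer=tok, treat_text=cleaner)

    input_ids, mask, type_ids, label = ds[0]
    clf.eval()  # BatchNorm needs eval mode for a single-example batch
    loss, (feature, logit), prob = clf(input_ids.unsqueeze(0), mask.unsqueeze(0),
                                       type_ids.unsqueeze(0), label)
    print(loss.item(), prob.item())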
ML-SLRC/Util_funs.py ADDED
@@ -0,0 +1,740 @@
+ from ML_SLRC import *
+
+ import os
+ import numpy as np
+ import pandas as pd
+
+ from torch.utils.data import DataLoader
+ from torch.optim import Adam
+
+ import gc
+ from torchmetrics import functional as fn
+
+ import random
+
+ from tqdm import tqdm
+
+ from sklearn.metrics import confusion_matrix
+ from sklearn.metrics import roc_curve, auc
+ import ipywidgets as widgets
+ from IPython.display import display, clear_output
+ import matplotlib.pyplot as plt
+ import warnings
+ import torch
+
+ import time
+ from sklearn.manifold import TSNE
+ from copy import deepcopy
+ import seaborn as sns
+ import json
+ from pathlib import Path
+
+ import re
+ from collections import defaultdict
+
+ # SEED = 2222
+ # gen_seed = torch.Generator().manual_seed(SEED)
+
+
+ # Random seed function
+ def random_seed(value):
+     torch.backends.cudnn.deterministic = True
+     torch.manual_seed(value)
+     torch.cuda.manual_seed(value)
+     np.random.seed(value)
+     random.seed(value)
+
+ # Batches of tasks for the meta-learner
+ def create_batch_of_tasks(taskset, is_shuffle=True, batch_size=4):
+     idxs = list(range(0, len(taskset)))
+     if is_shuffle:
+         random.shuffle(idxs)
+     for i in range(0, len(idxs), batch_size):
+         yield [taskset[idxs[j]] for j in range(i, min(i + batch_size, len(taskset)))]
+
+
+ # Prepare data to be processed by the domain-learner
+ def prepare_data(data, batch_size, tokenizer, max_seq_length,
+                  input='text', output='label',
+                  train_size_per_class=5, global_datasets=False,
+                  treat_text_fun=None):
+     data = data.reset_index().drop("index", axis=1)
+
+     if global_datasets:
+         global data_train, data_test
+
+     # Sample a task for training
+     data_train = data.groupby('label').sample(train_size_per_class, replace=False)
+
+     # The test set to be labeled by the model (the whole domain)
+     data_test = data
+
+     # Wrap in datasets for the model
+     ## Train
+     dataset_train = SLR_DataSet(
+         data=data_train.sample(frac=1),
+         input=input,
+         output=output,
+         tokenizer=tokenizer,
+         max_seq_length=max_seq_length,
+         treat_text=treat_text_fun)
+
+     ## Test
+     dataset_test = SLR_DataSet(
+         data=data_test,
+         input=input,
+         output=output,
+         tokenizer=tokenizer,
+         max_seq_length=max_seq_length,
+         treat_text=treat_text_fun)
+
+     # Dataloaders
+     ## Train
+     data_train_loader = DataLoader(dataset_train,
+                                    shuffle=True,
+                                    batch_size=batch_size['train'])
+
+     ## Test (drop the last batch only when it would hold a single example)
+     data_test_loader = DataLoader(dataset_test,
+                                   batch_size=batch_size['test'],
+                                   drop_last=(len(dataset_test) % batch_size['test'] == 1))
+
+     return data_train_loader, data_test_loader, data_train, data_test
+
+
+ # Meta trainer
+ def meta_train(data, model, device, Info,
+                print_epoch=True,
+                Test_resource=None,
+                treat_text_fun=None):
+
+     # Meta-learner model
+     learner = Learner(model=model, device=device, **Info)
+
+     # Testing tasks
+     if isinstance(Test_resource, pd.DataFrame):
+         test = MetaTask(Test_resource, num_task=0, k_support=10, k_query=10,
+                         training=False, treat_text=treat_text_fun, **Info)
+
+     torch.clear_autocast_cache()
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     # Meta epochs (outer epochs)
+     for epoch in tqdm(range(Info['meta_epoch']), desc="Meta epoch ", ncols=80):
+
+         # Training tasks
+         train = MetaTask(data,
+                          num_task=Info['num_task_train'],
+                          k_support=Info['k_spt'],
+                          k_query=Info['k_qry'],
+                          treat_text=treat_text_fun, **Info)
+
+         # Batches of training tasks
+         db = create_batch_of_tasks(train, is_shuffle=True, batch_size=Info["outer_batch_size"])
+
+         if print_epoch:
+             # Outer-loop batch training
+             for step, task_batch in enumerate(db):
+                 print("\n-----------------Training Mode", "Meta_epoch:", epoch, "-----------------\n")
+
+                 # Meta feed forward (outer feed forward)
+                 acc = learner(task_batch, valid_train=print_epoch)
+                 print('Step:', step, '\ttraining Acc:', acc)
+
+                 if isinstance(Test_resource, pd.DataFrame):
+                     # Validating the model
+                     if ((epoch + 1) % 4) + step == 0:
+                         random_seed(123)
+                         print("\n-----------------Testing Mode-----------------\n")
+
+                         # Batches of testing tasks
+                         db_test = create_batch_of_tasks(test, is_shuffle=False, batch_size=1)
+                         acc_all_test = []
+
+                         # Looping over the testing tasks
+                         for test_batch in db_test:
+                             acc = learner(test_batch, training=False)
+                             acc_all_test.append(acc)
+
+                         print('Test acc:', np.mean(acc_all_test))
+                         del acc_all_test, db_test
+
+                         # Re-seeding the training randomness
+                         random_seed(int(time.time() % 10))
+
+         else:
+             for step, task_batch in enumerate(db):
+                 # Meta feed forward (outer feed forward)
+                 acc = learner(task_batch, valid_train=print_epoch)
+
+     torch.clear_autocast_cache()
+     gc.collect()
+     torch.cuda.empty_cache()
+
+
+ def train_loop(data_train_loader, data_test_loader, model, device,
+                epoch=4, lr=1, print_info=True, name='name', weight_decay=1):
+     # Start from a fresh copy of the model's parameters
+     model_meta = deepcopy(model)
+     optimizer = Adam(model_meta.parameters(), lr=lr, weight_decay=weight_decay)
+
+     model_meta.to(device)
+     model_meta.train()
+
+     # Task epochs (inner epochs)
+     for i in range(0, epoch):
+         all_loss = []
+
+         # Inner training batches (support set)
+         for inner_step, batch in enumerate(data_train_loader):
+             batch = tuple(t.to(device) for t in batch)
+             input_ids, attention_mask, q_token_type_ids, label_id = batch
+
+             # Inner feed forward
+             loss, _, _ = model_meta(input_ids, attention_mask, q_token_type_ids, labels=label_id.squeeze())
+
+             # Compute gradients
+             loss.backward()
+
+             # Update parameters
+             optimizer.step()
+             optimizer.zero_grad()
+
+             all_loss.append(loss.item())
+
+         if (i % 2 == 0) and print_info:
+             print("Loss: ", np.mean(all_loss))
+
+     # Test evaluation
+     model_meta.eval()
+     all_acc = []
+     features = []
+     labels = []
+     predi_logit = []
+
+     with torch.no_grad():
+         # Test batch loop
+         for inner_step, batch in enumerate(tqdm(data_test_loader,
+                                                 desc="Test validation | " + name,
+                                                 ncols=80)):
+             batch = tuple(t.to(device) for t in batch)
+             input_ids, attention_mask, q_token_type_ids, label_id = batch
+
+             # Predictions
+             _, feature, prediction = model_meta(input_ids, attention_mask, q_token_type_ids, labels=label_id.squeeze())
+
+             prediction = prediction.detach().cpu().squeeze()
+             label_id = label_id.detach().cpu()
+             logit = feature[1].detach().cpu()
+             feature_lat = feature[0].detach().cpu()
+
+             labels.append(label_id.numpy().squeeze())
+             features.append(feature_lat.numpy())
+             predi_logit.append(logit.numpy())
+
+             # Accuracy over the test batch
+             acc = fn.accuracy(prediction, label_id.squeeze()).item()
+             all_acc.append(acc)
+             del input_ids, attention_mask, label_id, batch
+
+     if print_info:
+         print("acc:", np.mean(all_acc))
+
+     model_meta.to('cpu')
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     del model_meta, optimizer
+
+     # Project the feature map with t-SNE and return, matching the callers below:
+     # (logits, X_embedded, labels, features)
+     return map_feature_tsne(features, labels, predi_logit)
+
+ # Process predictions and project the feature map with t-SNE
+ def map_feature_tsne(features, labels, predi_logit):
+     features = np.concatenate(np.array(features, dtype=object))
+     features = torch.tensor(features.astype(np.float32)).detach().clone()
+
+     labels = np.concatenate(np.array(labels, dtype=object))
+     labels = torch.tensor(labels.astype(int)).detach().clone()
+
+     logits = np.concatenate(np.array(predi_logit, dtype=object))
+     logits = torch.tensor(logits.astype(np.float32)).detach().clone()
+
+     # Dimension reduction
+     X_embedded = TSNE(n_components=2, learning_rate='auto',
+                       init='random').fit_transform(features.detach().clone())
+
+     return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()
+
+
+ def wss_calc(logit, labels, trsh=0.5):
+     # Predicted label given the threshold
+     predict_trash = torch.sigmoid(logit).squeeze() >= trsh
+
+     # Confusion-matrix values
+     CM = confusion_matrix(labels, predict_trash.to(int))
+     tn, fp, fne, tp = CM.ravel()
+
+     P = (tp + fne)
+     N = (tn + fp)
+     recall = tp / (tp + fne)
+
+     # WSS (work saved over sampling)
+     wss = (tn + fne) / len(labels) - (1 - recall)
+
+     # AWSS (adjusted WSS)
+     awss = (tn / N - fne / P)
+
+     return {
+         "wss": round(wss, 4),
+         "awss": round(awss, 4),
+         "R": round(recall, 4),
+         "CM": CM
+     }
+
+
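Note: to make the two scores above concrete, a small worked example with hypothetical counts (not results from this repo):

    # Hypothetical screening outcome on 100 documents:
    #   tn=80, fp=10, fn=1, tp=9  ->  recall = 9 / (9 + 1) = 0.9
    #   WSS  = (tn + fn)/n - (1 - recall) = 81/100 - 0.1   = 0.71
    #   AWSS = tn/N - fn/P                = 80/90 - 1/10  ~= 0.7889
    import torch
    logit = torch.tensor([[2.0], [-3.0]])   # toy logits for two documents
    labels = torch.tensor([1, 0])
    print(wss_calc(logit, labels, trsh=0.5))  # perfect split: recall 1.0, awss 1.0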
+ # Compute the metrics
+ def plot(logits, X_embedded, labels, threshold, show=True,
+          namefig="plot", make_plot=True, print_stats=True, save=True):
+     col = pd.MultiIndex.from_tuples([
+         ("Predict", "0"),
+         ("Predict", "1")
+     ])
+     index = pd.MultiIndex.from_tuples([
+         ("Real", "0"),
+         ("Real", "1")
+     ])
+
+     predict = torch.sigmoid(logits).detach().clone()
+
+     # ROC curve
+     fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
+
+     # Metrics at a recall of 95% (threshold evaluation)
+     ## WSS
+     ### Index of the 95%-recall cut
+     idx_wss95 = sum(tpr < 0.95)
+     ### Threshold
+     thresholds95 = thresholds[idx_wss95]
+
+     ### Compute the metrics
+     wss95_info = wss_calc(logits, labels, thresholds95)
+     acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
+     f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)
+
+     # Metrics at a given threshold (recall evaluation)
+     ### Compute the metrics
+     wss_info = wss_calc(logits, labels, threshold)
+     acc_wssR = fn.accuracy(predict, labels, threshold=threshold)
+     f1_wssR = fn.f1_score(predict, labels, threshold=threshold)
+
+     metrics = {
+         # WSS
+         "WSS@95": wss95_info['wss'],
+         "AWSS@95": wss95_info['awss'],
+         "WSS@R": wss_info['wss'],
+         "AWSS@R": wss_info['awss'],
+         # Recall
+         "Recall_WSS@95": wss95_info['R'],
+         "Recall_WSS@R": wss_info['R'],
+         # Accuracy
+         "acc@95": acc_wss95.item(),
+         "acc@R": acc_wssR.item(),
+         # F1
+         "f1@95": f1_wss95.item(),
+         "f1@R": f1_wssR.item(),
+         # Threshold at 95% recall
+         "threshold@95": thresholds95
+     }
+
+     # Print stats
+     if print_stats:
+         wss95 = f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
+         wss95_adj = f"AWSS@95:{wss95_info['awss']}"
+         print(wss95)
+         print(wss95_adj)
+         print('Acc.:', round(acc_wss95.item(), 4))
+         print('F1-score:', round(f1_wss95.item(), 4))
+         print(f"threshold to wss95: {round(thresholds95, 4)}")
+         cm = pd.DataFrame(wss95_info['CM'],
+                           index=index,
+                           columns=col)
+
+         print("\nConfusion matrix:")
+         print(cm)
+         print("\n---Metrics with threshold:", threshold, "----\n")
+         wss = f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
+         print(wss)
+         wss_adj = f"AWSS@R:{wss_info['awss']}"
+         print(wss_adj)
+         print('Acc.:', round(acc_wssR.item(), 4))
+         print('F1-score:', round(f1_wssR.item(), 4))
+         cm = pd.DataFrame(wss_info['CM'],
+                           index=index,
+                           columns=col)
+
+         print("\nConfusion matrix:")
+         print(cm)
+
+     # Plots
+     if make_plot:
+         fig, axes = plt.subplots(1, 4, figsize=(25, 10))
+         alpha = torch.squeeze(predict).numpy()
+
+         # t-SNE
+         sns.scatterplot(x=X_embedded[:, 0],
+                         y=X_embedded[:, 1],
+                         hue=labels,
+                         alpha=alpha, ax=axes[0]).set_title('Predictions-TSNE', size=20)
+
+         # WSS@95
+         t_wss = predict >= thresholds95
+         t_wss = t_wss.squeeze().numpy()
+         sns.scatterplot(x=X_embedded[t_wss, 0],
+                         y=X_embedded[t_wss, 1],
+                         hue=labels[t_wss],
+                         alpha=alpha[t_wss], ax=axes[1]).set_title('WSS@95', size=20)
+
+         # WSS@R
+         t = predict >= threshold
+         t = t.squeeze().numpy()
+         sns.scatterplot(x=X_embedded[t, 0],
+                         y=X_embedded[t, 1],
+                         hue=labels[t],
+                         alpha=alpha[t], ax=axes[2]).set_title(f'Predictions-threshold {threshold}', size=20)
+
+         # ROC curve
+         roc_auc = auc(fpr, tpr)
+         lw = 2
+         axes[3].plot(
+             fpr,
+             tpr,
+             color="darkorange",
+             lw=lw,
+             label="ROC curve (area = %0.2f)" % roc_auc)
+         axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
+         axes[3].axhline(y=0.95, color='r', linestyle='-')
+         axes[3].legend(loc="lower right")
+         axes[3].set_title(label="ROC", size=20)
+         axes[3].set_ylabel("True Positive Rate", fontsize=15)
+         axes[3].set_xlabel("False Positive Rate", fontsize=15)
+
+         if show:
+             plt.show()
+
+         if save:
+             fig.savefig(namefig, dpi=fig.dpi)
+
+     return metrics
+
+
+ def auc_plot(logits, labels, color="darkorange", label="test"):
+     predict = torch.sigmoid(logits).detach().clone()
+     fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
+     roc_auc = auc(fpr, tpr)
+     lw = 2
+
+     label = label + str(round(roc_auc, 2))
+
+     plt.plot(
+         fpr,
+         tpr,
+         color=color,
+         lw=lw,
+         label=label
+     )
+     plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
+     plt.axhline(y=0.95, color='r', linestyle='-')
+
+ # Interface for interactive evaluation
+ class diagnosis():
+     def __init__(self, names, Valid_resource, batch_size_test,
+                  model, Info, device, treat_text_fun=None, start=0):
+         self.names = names
+         self.Valid_resource = Valid_resource
+         self.batch_size_test = batch_size_test
+         self.model = model
+         self.start = start
+         self.Info = Info
+         self.device = device
+         self.treat_text_fun = treat_text_fun
+
+         # Input boxes
+         self.value_trash = widgets.FloatText(
+             value=0.95,
+             description='threshold',
+             disabled=False
+         )
+         self.valueb = widgets.IntText(
+             value=10,
+             description='size',
+             disabled=False
+         )
+
+         # Buttons
+         self.train_b = widgets.Button(description="Train")
+         self.next_b = widgets.Button(description="Next")
+         self.eval_b = widgets.Button(description="Evaluation")
+
+         self.hbox = widgets.HBox([self.train_b, self.valueb])
+
+         # Button click handlers
+         self.next_b.on_click(self.Next_button)
+         self.train_b.on_click(self.Train_button)
+         self.eval_b.on_click(self.Evaluation_button)
+
+     # Next button
+     def Next_button(self, p):
+         clear_output()
+         self.i = self.i + 1
+
+         # Select the domain data
+         self.domain = self.names[self.i]
+         self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]
+
+         print("Name:", self.domain)
+         print(self.data['label'].value_counts())
+         display(self.hbox)
+         display(self.next_b)
+
+     # Train button
+     def Train_button(self, y):
+         clear_output()
+         print(self.domain)
+
+         # Prepare the data for training (domain-learner)
+         self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(
+             self.data,
+             train_size_per_class=self.valueb.value,
+             batch_size={'train': self.Info['inner_batch_size'],
+                         'test': self.batch_size_test},
+             max_seq_length=self.Info['max_seq_length'],
+             tokenizer=self.Info['tokenizer'],
+             input="text",
+             output="label",
+             treat_text_fun=self.treat_text_fun)
+
+         # Train the model and predict on the test set
+         self.logits, self.X_embedded, self.labels, self.features = train_loop(
+             self.data_train_loader, self.data_test_loader,
+             self.model, self.device,
+             epoch=self.Info['inner_update_step'],
+             lr=self.Info['inner_update_lr'],
+             print_info=True,
+             name=self.domain)
+
+         tresh_box = widgets.HBox([self.eval_b, self.value_trash])
+         display(self.hbox)
+         display(tresh_box)
+         display(self.next_b)
+
+     # Evaluation button
+     def Evaluation_button(self, te):
+         clear_output()
+         tresh_box = widgets.HBox([self.eval_b, self.value_trash])
+
+         print(self.domain)
+         print("-------Train data-------")
+         print(self.data_train['label'].value_counts())
+         print("-------Test data-------")
+         print(self.data_test['label'].value_counts())
+
+         display(self.next_b)
+         display(tresh_box)
+         display(self.hbox)
+
+         # Compute metrics
+         metrics = plot(self.logits, self.X_embedded, self.labels,
+                        threshold=self.Info['threshold'], show=True,
+                        namefig='test',
+                        make_plot=True,
+                        print_stats=True,
+                        save=False)
+
+     def __call__(self):
+         self.i = self.start - 1
+         clear_output()
+         display(self.next_b)
+
+
+ # Simulated attempts of the domain learner
+ def pipeline_simulation(Valid_resource, names_to_valid, path_save,
+                         model, Info, device, initializer_model,
+                         treat_text_fun=None):
+     n_attempt = 5
+     batch_test = 100
+
+     # Create a directory to save the outputs
+     for name in names_to_valid:
+         name = re.sub(r"\.csv", "", name)
+         Path(path_save + name + "/img").mkdir(parents=True, exist_ok=True)
+
+     # Dict to save the ROC curves
+     roc_stats = defaultdict(lambda: defaultdict(
+         lambda: defaultdict(
+             list
+         )
+     )
+     )
+
+     all_metrics = []
+     # Loop over the list of domains
+     for name in names_to_valid:
+
+         # Select a domain dataset
+         data = Valid_resource[Valid_resource['domain'] == name].reset_index().drop("index", axis=1)
+
+         # Simulated attempts
+         for attempt in range(n_attempt):
+             print("---" * 4, "attempt", attempt, "---" * 4)
+
+             # Prepare the data to pass to the model
+             data_train_loader, data_test_loader, _, _ = prepare_data(
+                 data,
+                 train_size_per_class=Info['k_spt'],
+                 batch_size={'train': Info['inner_batch_size'],
+                             'test': batch_test},
+                 max_seq_length=Info['max_seq_length'],
+                 tokenizer=Info['tokenizer'],
+                 input="text",
+                 output="label",
+                 treat_text_fun=treat_text_fun)
+
+             # Train the model and evaluate on the domain's test set
+             logits, X_embedded, labels, features = train_loop(
+                 data_train_loader, data_test_loader,
+                 model, device,
+                 epoch=Info['inner_update_step'],
+                 lr=Info['inner_update_lr'],
+                 print_info=False,
+                 name=name)
+
+             name_domain = re.sub(r"\.csv", "", name)
+
+             # Compute the metrics
+             metrics = plot(logits, X_embedded, labels,
+                            threshold=Info['threshold'], show=False,
+                            namefig=path_save + name_domain + "/img/" + str(attempt) + 'plots',
+                            make_plot=True, print_stats=False, save=True)
+
+             # Compute the ROC curve
+             fpr, tpr, _ = roc_curve(labels, torch.sigmoid(logits).squeeze())
+
+             # Save the corresponding information for the domain
+             metrics['name'] = name_domain
+             metrics['layer_size'] = Info['bert_layers']
+             metrics['attempt'] = attempt
+             roc_stats[name_domain][str(Info['bert_layers'])]['fpr'].append(fpr.tolist())
+             roc_stats[name_domain][str(Info['bert_layers'])]['tpr'].append(tpr.tolist())
+             all_metrics.append(metrics)
+
+             # Save the metrics and the ROC curve of the attempt
+             pd.DataFrame(all_metrics).to_csv(path_save + "metrics.csv")
+             roc_path = path_save + "roc_stats.json"
+             with open(roc_path, 'w') as fp:
+                 json.dump(roc_stats, fp)
+
+             del fpr, tpr, logits, X_embedded, labels
+             del features, metrics, _
+
+     # Save the configuration used to evaluate the validation resource
+     save_info = Info.copy()
+     save_info['model'] = initializer_model.tokenizer.name_or_path
+     save_info.pop("tokenizer")
+     save_info.pop("bert_layers")
+
+     info_path = path_save + "info.json"
+     with open(info_path, 'w') as fp:
+         json.dump(save_info, fp)
+
+
+ # Loading dataset statistics
+ def load_data_statistics(paths, names):
+     size = []
+     pos = []
+     neg = []
+     for p in paths:
+         data = pd.read_csv(p)
+         data = data.dropna()
+         # Dataset size
+         size.append(len(data))
+         # Number of positive labels
+         pos.append(data['labels'].value_counts()[1])
+         # Number of negative labels
+         neg.append(data['labels'].value_counts()[0])
+         del data
+
+     info_load = pd.DataFrame({
+         "size": size,
+         "pos": pos,
+         "neg": neg,
+         "names": names,
+         "paths": paths})
+     return info_load
+
+ # Loading the datasets
+ def load_data(train_info_load):
+     col = ['abstract', 'title', 'labels', 'domain']
+
+     data_train = pd.DataFrame(columns=col)
+     for p in train_info_load['paths']:
+         data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
+         data_temp['domain'] = os.path.basename(p)
+         data_train = pd.concat([data_train, data_temp])
+
+     data_train['text'] = data_train['title'] + ' ' + data_train['abstract'].replace(np.nan, '')
+
+     return (data_train
+             .replace({"labels": {0: "negative", 1: 'positive'}})
+             .rename({"labels": "label"}, axis=1)
+             .loc[:, ("text", "domain", "label")]
+             )
+
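Note: putting the utilities together, a hypothetical end-to-end driver. The CSV path and its columns are assumptions; popping the "model" key before calling meta_train is also an assumption (otherwise Learner would receive `model` twice via **Info):

    # Hypothetical driver, assuming Info.json and a CSV with title/abstract/labels columns.
    import json
    import pandas as pd
    import torch
    from transformers import AutoTokenizer, AutoModel
    from ML_SLRC import SLR_Classifier, treat_text, patterns

    with open("ML-SLRC/Info.json") as f:
        Info = json.load(f)

    base_name = Info.pop("model")  # avoid clashing with the `model` kwarg inside Learner
    Info["tokenizer"] = AutoTokenizer.from_pretrained(base_name)

    base = AutoModel.from_pretrained(base_name)
    model = SLR_Classifier(model=base, bert_layers=range(Info["bert_layers"]),
                           pos_weight=Info["pos_weight"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load_data expects a frame of file paths; "train.csv" here is illustrative.
    data = load_data(pd.DataFrame({"paths": ["train.csv"]}))

    meta_train(data, model, device, Info,
               Test_resource=None, treat_text_fun=treat_text(patterns))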
ML-SLRC/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from . import Util_funs
+
+
+
ML-SLRC/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a859f39dc8ff55df919ef6794dcfc3ca08f873ae11fd0fd78c50d65089a6f019
+ size 213540902
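Note: the checkpoint is stored with Git LFS, and this commit does not document how it should be loaded. A minimal sketch, assuming the file pickles the full SLR_Classifier module (if it were a plain state_dict, torch.load would instead return an OrderedDict to feed into model.load_state_dict):

    # Minimal loading sketch (the pickled-module format is an assumption).
    import torch

    model = torch.load("ML-SLRC/model.pt", map_location="cpu")
    model.eval()  # inference mode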
app.py CHANGED
@@ -174,7 +174,6 @@ def treat_data_input(data, etailment_txt):
 
  import gc
  from torch.optim import Adam
- from scipy.stats import entropy
 
  def treat_train_evaluate(dataload_train, dataload_remain):
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')