import gc
import random
import re
import unicodedata
from copy import deepcopy

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, RandomSampler, TensorDataset
from torchmetrics import functional as fn


# Pre-trained model
class Encoder(nn.Module):
    def __init__(self, layers, freeze_bert, model):
        super(Encoder, self).__init__()

        # Dummy Parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Pre-trained model
        self.model = deepcopy(model)

        # Freezing bert parameters
        if freeze_bert:
            for param in self.model.parameters():
                param.requires_grad = False

        # Selecting hidden layers of the pre-trained model
        old_model_encoder = self.model.encoder.layer
        new_model_encoder = nn.ModuleList()

        for i in layers:
            new_model_encoder.append(old_model_encoder[i])

        self.model.encoder.layer = new_model_encoder

    # Feed forward
    def forward(self, **x):
        return self.model(**x)['pooler_output']


# Complete model
class SLR_Classifier(nn.Module):
    def __init__(self, **data):
        super(SLR_Classifier, self).__init__()

        # Dummy Parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Loss function: Binary Cross Entropy with logits, reduced to mean
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean',
                                            pos_weight=torch.FloatTensor([data.get("pos_weight", 2.5)]))

        # Pre-trained model
        self.Encoder = Encoder(layers=data.get("bert_layers", range(12)),
                               freeze_bert=data.get("freeze_bert", False),
                               model=data.get("model"))

        # Feature Map Layer
        self.feature_map = nn.Sequential(
            # nn.LayerNorm(self.Encoder.model.config.hidden_size),
            nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
            # nn.Dropout(data.get("drop", 0.5)),
            nn.Linear(self.Encoder.model.config.hidden_size, 200),
            nn.Dropout(data.get("drop", 0.5)),
        )

        # Classifier Layer
        self.classifier = nn.Sequential(
            # nn.LayerNorm(self.Encoder.model.config.hidden_size),
            # nn.Dropout(data.get("drop", 0.5)),
            # nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
            # nn.Dropout(data.get("drop", 0.5)),
            nn.Tanh(),
            nn.Linear(200, 1)
        )

        # Initializing layer parameters
        nn.init.normal_(self.feature_map[1].weight, mean=0, std=0.00001)
        nn.init.zeros_(self.feature_map[1].bias)

    # Feed forward
    def forward(self, input_ids, attention_mask, token_type_ids, labels):
        predict = self.Encoder(**{"input_ids": input_ids,
                                  "attention_mask": attention_mask,
                                  "token_type_ids": token_type_ids})
        feature = self.feature_map(predict)
        logit = self.classifier(feature)

        predict = torch.sigmoid(logit)

        # Loss function
        loss = self.loss_fn(logit.to(torch.float), labels.to(torch.float).unsqueeze(1))

        return [loss, [feature, logit], predict]
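# Example (sketch): wrapping a Hugging Face BERT encoder in SLR_Classifier.
# The checkpoint name and hyperparameter values below are illustrative
# assumptions, not values prescribed by this module.
#
# from transformers import AutoModel
# bert = AutoModel.from_pretrained("bert-base-uncased")
# classifier = SLR_Classifier(model=bert, bert_layers=range(6),
#                             freeze_bert=False, drop=0.5, pos_weight=2.5)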

# Undesirable patterns within texts
patterns = {
    'CONCLUSIONS AND IMPLICATIONS': '',
    'BACKGROUND AND PURPOSE': '',
    'EXPERIMENTAL APPROACH': '',
    'KEY RESULTS AEA': '',
    '©': '',
    '®': '',
    'μ': '',
    '(C)': '',
    'OBJECTIVE:': '',
    'MATERIALS AND METHODS:': '',
    'SIGNIFICANCE:': '',
    'BACKGROUND:': '',
    'RESULTS:': '',
    'METHODS:': '',
    'CONCLUSIONS:': '',
    'AIM:': '',
    'STUDY DESIGN:': '',
    'CLINICAL RELEVANCE:': '',
    'CONCLUSION:': '',
    'HYPOTHESIS:': '',
    'Questions/Purposes:': '',
    'Introduction:': '',
    'PURPOSE:': '',
    'PATIENTS AND METHODS:': '',
    'FINDINGS:': '',
    'INTERPRETATIONS:': '',
    'FUNDING:': '',
    'PROGRESS:': '',
    'CONTEXT:': '',
    'MEASURES:': '',
    'DESIGN:': '',
    'BACKGROUND AND OBJECTIVES:': '',
    '\n\n': '',
    '<>': '',
    '+/-': '',
    r'\(.+\)': '',
    r'\[.+\]': '',
    r' \d ': '',
    '<': '',
    '>': '',
    '- ': '',
    ' +': ' ',
    ', ,': ',',
    ',,': ',',
    '%': ' percent',
    'per cent': ' percent',
}
patterns = {x.lower(): y for x, y in patterns.items()}

LABEL_MAP = {
    'negative': 0,
    'not included': 0,
    '0': 0,
    0: 0,
    'excluded': 0,
    'positive': 1,
    'included': 1,
    '1': 1,
    1: 1,
}


class SLR_DataSet(Dataset):
    def __init__(self, treat_text=None, **args):
        self.tokenizer = args.get('tokenizer')
        self.data = args.get('data')
        self.max_seq_length = args.get("max_seq_length", 512)
        self.INPUT_NAME = args.get("input", 'x')
        self.LABEL_NAME = args.get("output", 'y')
        self.treat_text = treat_text

    # Tokenizing and processing text
    def encode_text(self, example):
        comment_text = example[self.INPUT_NAME]

        if self.treat_text:
            comment_text = self.treat_text(comment_text)

        try:
            labels = LABEL_MAP[example[self.LABEL_NAME].lower()]
        except (KeyError, AttributeError):
            labels = -1

        encoding = self.tokenizer.encode_plus(
            (comment_text, "It is great text"),
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return tuple((
            encoding["input_ids"].flatten(),
            encoding["attention_mask"].flatten(),
            encoding["token_type_ids"].flatten(),
            torch.tensor([torch.tensor(labels).to(int)])
        ))

    def __len__(self):
        return len(self.data)

    # Returning data
    def __getitem__(self, index: int):
        data_row = self.data.reset_index().iloc[index]
        temp_data = self.encode_text(data_row)
        return temp_data
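# Example (sketch): building a fine-tuning DataLoader from a pandas DataFrame
# with the text in column 'x' and labels in column 'y' (the dataset's default
# column names). The checkpoint, data and batch size are illustrative
# assumptions.
#
# import pandas as pd
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# frame = pd.DataFrame({"x": ["Example abstract text."], "y": ["included"]})
# dataset = SLR_DataSet(data=frame, tokenizer=tokenizer, max_seq_length=128)
# loader = DataLoader(dataset, batch_size=8, shuffle=True)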
class Learner(nn.Module):

    def __init__(self, **args):
        """
        :param args:
        """
        super(Learner, self).__init__()

        self.inner_print = args.get('inner_print')
        self.inner_batch_size = args.get('inner_batch_size')
        self.outer_update_lr = args.get('outer_update_lr')
        self.inner_update_lr = args.get('inner_update_lr')
        self.inner_update_step = args.get('inner_update_step')
        self.inner_update_step_eval = args.get('inner_update_step_eval')
        self.model = args.get('model')
        self.device = args.get('device')

        # Outer optimizer
        self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
        self.model.train()

    def forward(self, batch_tasks, training=True, valid_train=True):
        """
        batch_tasks = [(support TensorDataset, query TensorDataset, task name),
                       (support TensorDataset, query TensorDataset, task name),
                       ...]

        # support = TensorDataset(all_input_ids, all_attention_mask,
        #                         all_token_type_ids, all_label_ids)
        """
        task_accs = []
        task_f1 = []
        task_recall = []
        sum_gradients = []
        num_task = len(batch_tasks)
        num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval

        # Outer loop tasks
        for task_id, task in enumerate(batch_tasks):
            support = task[0]
            query = task[1]
            name = task[2]

            # Copying model
            fast_model = deepcopy(self.model)
            fast_model.to(self.device)

            # Inner trainer optimizer
            inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)

            # Creating training data loaders
            if len(support) % self.inner_batch_size == 1:
                support_dataloader = DataLoader(support, sampler=RandomSampler(support),
                                                batch_size=self.inner_batch_size,
                                                drop_last=True)
            else:
                support_dataloader = DataLoader(support, sampler=RandomSampler(support),
                                                batch_size=self.inner_batch_size,
                                                drop_last=False)

            # steps_per_epoch = len(support) // self.inner_batch_size
            # total_training_steps = steps_per_epoch * 5
            # warmup_steps = total_training_steps // 3
            #
            # scheduler = get_linear_schedule_with_warmup(
            #     inner_optimizer,
            #     num_warmup_steps=warmup_steps,
            #     num_training_steps=total_training_steps
            # )

            fast_model.train()

            # Inner loop training epoch (support set)
            if valid_train:
                print('----Task', task_id, ":", name, '----')

            for i in range(0, num_inner_update_step):
                all_loss = []

                # Inner loop training batch (support set)
                for inner_step, batch in enumerate(support_dataloader):
                    batch = tuple(t.to(self.device) for t in batch)
                    input_ids, attention_mask, token_type_ids, label_id = batch

                    # Feed forward
                    loss, _, _ = fast_model(input_ids, attention_mask,
                                            token_type_ids=token_type_ids,
                                            labels=label_id)

                    # Computing gradients
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(fast_model.parameters(), max_norm=1)

                    # Updating inner training parameters
                    inner_optimizer.step()
                    inner_optimizer.zero_grad()

                    # Appending losses
                    all_loss.append(loss.item())

                    del batch, input_ids, attention_mask, label_id
                    torch.cuda.empty_cache()

                if valid_train:
                    if (i + 1) % self.inner_print == 0:
                        print("Inner Loss: ", np.mean(all_loss))

            fast_model.to(torch.device('cpu'))

            # Inner training phase weights
            if training:
                meta_weights = list(self.model.parameters())
                fast_weights = list(fast_model.parameters())

                # Appending gradients
                gradients = []
                for i, (meta_params, fast_params) in enumerate(zip(meta_weights, fast_weights)):
                    gradient = meta_params - fast_params
                    if task_id == 0:
                        sum_gradients.append(gradient)
                    else:
                        sum_gradients[i] += gradient

            if valid_train:
                # Inner test (query set)
                fast_model.to(self.device)
                fast_model.eval()

                with torch.no_grad():
                    # Data loader
                    query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
                    query_batch = next(iter(query_dataloader))
                    query_batch = tuple(t.to(self.device) for t in query_batch)
                    q_input_ids, q_attention_mask, q_token_type_ids, q_label_id = query_batch

                    # Feed forward
                    _, _, pre_label_id = fast_model(q_input_ids, q_attention_mask,
                                                    q_token_type_ids,
                                                    labels=q_label_id)

                    # Predictions
                    pre_label_id = pre_label_id.detach().cpu().squeeze()
                    # Labels
                    q_label_id = q_label_id.detach().cpu()

                    # Calculating metrics
                    acc = fn.accuracy(pre_label_id, q_label_id).item()
                    recall = fn.recall(pre_label_id, q_label_id).item()
                    f1 = fn.f1_score(pre_label_id, q_label_id).item()

                    # Appending metrics
                    task_accs.append(acc)
                    task_f1.append(f1)
                    task_recall.append(recall)

                fast_model.to(torch.device('cpu'))

            del fast_model, inner_optimizer
            torch.cuda.empty_cache()

        if valid_train:
            print("\n")
            print("f1:", np.mean(task_f1))
            print("recall:", np.mean(task_recall))

        # Updating outer training parameters
        if training:
            # Mean of gradients
            for i in range(0, len(sum_gradients)):
                sum_gradients[i] = sum_gradients[i] / float(num_task)

            # Indexing parameters to model
            for i, params in enumerate(self.model.parameters()):
                params.grad = sum_gradients[i]

            # Updating parameters
            self.outer_optimizer.step()
            self.outer_optimizer.zero_grad()

            del sum_gradients
            gc.collect()
            torch.cuda.empty_cache()

        if valid_train:
            return np.mean(task_accs)
        else:
            return np.array(0)
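# Example (sketch): one outer meta-update over a list of tasks, where `model`
# is an SLR_Classifier instance and each task is a (support TensorDataset,
# query TensorDataset, task name) triple, e.g. as produced by MetaTask below.
# The learner hyperparameters are illustrative assumptions.
#
# learner = Learner(model=model, device=torch.device("cpu"),
#                   outer_update_lr=1e-5, inner_update_lr=1e-5,
#                   inner_batch_size=4, inner_update_step=2,
#                   inner_update_step_eval=1, inner_print=1)
# mean_query_acc = learner(batch_tasks, training=True, valid_train=True)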
# Creating Meta Tasks
class MetaTask(Dataset):

    def __init__(self, examples, num_task, k_support, k_query,
                 tokenizer, training=True, max_seq_length=512,
                 treat_text=None, **args):
        """
        :param examples: list of samples
        :param num_task: number of training tasks
        :param k_support: number of support samples per class per task
        :param k_query: number of query samples per class per task
        """
        self.examples = examples

        self.num_task = num_task
        self.k_support = k_support
        self.k_query = k_query
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.treat_text = treat_text

        # Randomly generating tasks
        self.create_batch(self.num_task, training)

    # Creating batch
    def create_batch(self, num_task, training):
        self.supports = []  # support set
        self.queries = []  # query set
        self.task_names = []  # name of task
        self.supports_indexs = []  # index of supports
        self.queries_indexs = []  # index of queries
        self.num_task = num_task

        # Available tasks
        domains = self.examples['domain'].unique()

        # If not training, create all tasks
        if not(training):
            self.task_names = domains
            num_task = len(self.task_names)
            self.num_task = num_task

        # For each task,
        for b in range(num_task):
            total_per_class = self.k_support + self.k_query
            task_size = 2 * self.k_support + 2 * self.k_query

            # Select a task at random
            if training:
                domain = random.choice(domains)
                self.task_names.append(domain)
            else:
                domain = self.task_names[b]

            # Task data
            domainExamples = self.examples[self.examples['domain'] == domain]

            # Minimal label quantity
            min_per_class = min(domainExamples['label'].value_counts())

            if total_per_class > min_per_class:
                total_per_class = min_per_class

            # Select k_support + k_query examples per label (class)
            selected_examples = domainExamples.groupby("label").sample(total_per_class, replace=False)

            # Split data into support (training) and query (testing) sets
            s, q = train_test_split(selected_examples,
                                    stratify=selected_examples["label"],
                                    test_size=2 * self.k_query / task_size,
                                    shuffle=True)

            # Permutating data
            s = s.sample(frac=1)
            q = q.sample(frac=1)

            # Appending indexes
            if not(training):
                self.supports_indexs.append(s.index)
                self.queries_indexs.append(q.index)

            # Creating list of support (training) and query (testing) tasks
            self.supports.append(s.to_dict('records'))
            self.queries.append(q.to_dict('records'))

    # Creating task tensors
    def create_feature_set(self, examples):
        all_input_ids = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
        all_attention_mask = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
        all_token_type_ids = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
        all_label_ids = torch.empty(len(examples), dtype=torch.long)

        for _id, e in enumerate(examples):
            all_input_ids[_id], all_attention_mask[_id], all_token_type_ids[_id], all_label_ids[_id] = self.encode_text(e)

        return TensorDataset(
            all_input_ids,
            all_attention_mask,
            all_token_type_ids,
            all_label_ids
        )

    # Data encoding
    def encode_text(self, example):
        comment_text = example["text"]

        if self.treat_text:
            comment_text = self.treat_text(comment_text)

        labels = LABEL_MAP[example["label"]]

        encoding = self.tokenizer.encode_plus(
            (comment_text, "It is a great text."),
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return tuple((
            encoding["input_ids"].flatten(),
            encoding["attention_mask"].flatten(),
            encoding["token_type_ids"].flatten(),
            torch.tensor([torch.tensor(labels).to(int)])
        ))

    # Returns data upon calling
    def __getitem__(self, index):
        support_set = self.create_feature_set(self.supports[index])
        query_set = self.create_feature_set(self.queries[index])
        name = self.task_names[index]
        return support_set, query_set, name

    def __len__(self):
        return self.num_task
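# Example (sketch): sampling meta-training tasks from a pandas DataFrame with
# 'text', 'label' and 'domain' columns (the column names used by encode_text
# and create_batch above). The counts and sequence length are illustrative
# assumptions.
#
# meta_tasks = MetaTask(frame, num_task=8, k_support=4, k_query=4,
#                       tokenizer=tokenizer, training=True, max_seq_length=128)
# support_set, query_set, task_name = meta_tasks[0]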
class treat_text:
    def __init__(self, patterns):
        self.patterns = patterns

    def __call__(self, text):
        text = unicodedata.normalize("NFKD", str(text))
        text = multiple_replace(self.patterns, text.lower())
        text = re.sub(r'(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )', '', text)
        text = re.sub(r'( +)', ' ', text)
        text = re.sub(r'(, ,)|(,,)', ',', text)
        text = re.sub(r'(%)|(per cent)', ' percent', text)
        return text


# Regex multiple replace function
def multiple_replace(patterns_dict, text):
    # Building regex from dict keys
    regex = re.compile("(%s)" % "|".join(map(re.escape, patterns_dict.keys())))

    # Substitution
    return regex.sub(lambda mo: patterns_dict[mo.group(0)], text)
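# Example (sketch): a minimal end-to-end run tying the pieces together.
# The checkpoint name, synthetic DataFrame and hyperparameters are illustrative
# assumptions; the metric calls above assume a torchmetrics version whose
# functional API accepts (preds, target) without a `task` argument.
if __name__ == "__main__":
    import pandas as pd
    from transformers import AutoModel, AutoTokenizer

    checkpoint = "bert-base-uncased"  # assumed BERT-style checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    bert = AutoModel.from_pretrained(checkpoint)

    # Tiny synthetic corpus: two review domains with binary labels
    frame = pd.DataFrame({
        "text": ["Relevant trial of drug X."] * 8 + ["Unrelated case report."] * 8,
        "label": ["included", "excluded"] * 8,
        "domain": ["domain_A"] * 8 + ["domain_B"] * 8,
    })

    cleaner = treat_text(patterns)
    meta_tasks = MetaTask(frame, num_task=2, k_support=2, k_query=2,
                          tokenizer=tokenizer, training=True,
                          max_seq_length=128, treat_text=cleaner)

    model = SLR_Classifier(model=bert, bert_layers=range(6),
                           freeze_bert=False, drop=0.5, pos_weight=2.5)
    learner = Learner(model=model, device=torch.device("cpu"),
                      outer_update_lr=1e-5, inner_update_lr=1e-5,
                      inner_batch_size=4, inner_update_step=1,
                      inner_update_step_eval=1, inner_print=1)

    batch_tasks = [meta_tasks[i] for i in range(len(meta_tasks))]
    mean_acc = learner(batch_tasks, training=True, valid_train=True)
    print("Mean query accuracy:", mean_acc)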