meta-demo-app / ML-SLRC /ML_SLRC.py
BecomeAllan
update
6b71499
raw
history blame contribute delete
No virus
19.2 kB
from torch import nn
import torch
import numpy as np
from copy import deepcopy
import re
import unicodedata
from torch.utils.data import Dataset, DataLoader,TensorDataset, RandomSampler
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from copy import deepcopy
import gc
import torch
import numpy as np
from torchmetrics import functional as fn
import random
# Pre-trained model
class Encoder(nn.Module):
    """Wrap a pre-trained transformer and keep only selected encoder layers.

    Args:
        layers: iterable of integer indices into ``model.encoder.layer``;
            the kept layers appear in the order the indices are given.
        freeze_bert: when truthy, every parameter of the copied model has
            ``requires_grad`` switched off so it is excluded from training.
        model: pre-trained model exposing ``encoder.layer`` and returning a
            mapping with a ``'pooler_output'`` entry when called.
    """

    def __init__(self, layers, freeze_bert, model):
        super(Encoder, self).__init__()

        # Dummy Parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Work on a private copy so the caller's model is never mutated.
        self.model = deepcopy(model)

        # Freezing bert parameters.
        # The flag must be set to False here: assigning ``freeze_bert``
        # itself (True) would leave every parameter trainable.
        if freeze_bert:
            for param in self.model.parameters():
                param.requires_grad = False

        # Selecting hidden layers of the pre-trained model
        old_model_encoder = self.model.encoder.layer
        self.model.encoder.layer = nn.ModuleList(
            [old_model_encoder[i] for i in layers]
        )

    def forward(self, **x):
        """Run the wrapped model on keyword inputs and return its pooled output."""
        return self.model(**x)['pooler_output']
# Complete model
class SLR_Classifier(nn.Module):
    """Encoder followed by a feature-map block and a single-logit head.

    Keyword options read from ``data``:
        pos_weight: positive-class weight of the BCE loss (default 2.5).
        bert_layers: encoder layer indices to keep (default ``range(12)``).
        freeze_bert: forwarded to ``Encoder`` (default False).
        model: the pre-trained model handed to ``Encoder``.
        drop: dropout probability of the feature map (default 0.5).
    """

    def __init__(self, **data):
        super().__init__()

        # Dummy Parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Loss: binary cross entropy on logits, averaged over the batch.
        positive_weight = torch.FloatTensor([data.get("pos_weight", 2.5)])
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean',
                                            pos_weight=positive_weight)

        # Pre-trained model
        self.Encoder = Encoder(
            layers=data.get("bert_layers", range(12)),
            freeze_bert=data.get("freeze_bert", False),
            model=data.get("model"),
        )

        # Feature map: batch-norm -> linear projection to 200 -> dropout.
        hidden_size = self.Encoder.model.config.hidden_size
        normalizer = nn.BatchNorm1d(hidden_size)
        projection = nn.Linear(hidden_size, 200)
        dropout = nn.Dropout(data.get("drop", 0.5))
        self.feature_map = nn.Sequential(normalizer, projection, dropout)

        # Classifier: tanh -> linear projection to one logit.
        self.classifier = nn.Sequential(nn.Tanh(), nn.Linear(200, 1))

        # The projection starts with near-zero weights and a zero bias.
        nn.init.normal_(projection.weight, mean=0, std=0.00001)
        nn.init.zeros_(projection.bias)

    def forward(self, input_ids, attention_mask, token_type_ids, labels):
        """Return ``[loss, [feature, logit], predict]`` for one batch.

        ``predict`` is the sigmoid of ``logit``; ``loss`` compares ``logit``
        against ``labels`` cast to float with a trailing dimension added.
        """
        pooled = self.Encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        feature = self.feature_map(pooled)
        logit = self.classifier(feature)
        predict = torch.sigmoid(logit)

        target = labels.to(torch.float).unsqueeze(1)
        loss = self.loss_fn(logit.to(torch.float), target)

        return [loss, [feature, logit], predict]
# Undesirable patterns within texts, mapped to their replacement.
# Keys are lower-cased below. Insertion order is preserved and matters to
# multiple_replace, which joins the keys into one regex alternation.
# The regex-looking keys are written as raw strings so the backslashes are
# not treated as (invalid) string escapes; their values are unchanged.
# NOTE(review): multiple_replace escapes every key with re.escape, so these
# regex-looking entries only match literally there — confirm that is intended.
patterns = {
    'CONCLUSIONS AND IMPLICATIONS': '',
    'BACKGROUND AND PURPOSE': '',
    'EXPERIMENTAL APPROACH': '',
    'KEY RESULTS AEA': '',
    '©': '',
    '®': '',
    'μ': '',
    '(C)': '',
    'OBJECTIVE:': '',
    'MATERIALS AND METHODS:': '',
    'SIGNIFICANCE:': '',
    'BACKGROUND:': '',
    'RESULTS:': '',
    'METHODS:': '',
    'CONCLUSIONS:': '',
    'AIM:': '',
    'STUDY DESIGN:': '',
    'CLINICAL RELEVANCE:': '',
    'CONCLUSION:': '',
    'HYPOTHESIS:': '',
    'Questions/Purposes:': '',
    'Introduction:': '',
    'PURPOSE:': '',
    'PATIENTS AND METHODS:': '',
    'FINDINGS:': '',
    'INTERPRETATIONS:': '',
    'FUNDING:': '',
    'PROGRESS:': '',
    'CONTEXT:': '',
    'MEASURES:': '',
    'DESIGN:': '',
    'BACKGROUND AND OBJECTIVES:': '',
    '<p>': '',
    '</p>': '',
    '<<ETX>>': '',
    '+/-': '',
    r'\(.+\)': '',
    r'\[.+\]': '',
    r' \d ': '',
    '<': '',
    '>': '',
    '- ': '',
    ' +': ' ',
    ', ,': ',',
    ',,': ',',
    '%': ' percent',
    'per cent': ' percent',
}

# Lower-case every key so it can be matched against lower-cased text.
patterns = {x.lower(): y for x, y in patterns.items()}
# Maps every accepted label spelling (string or integer) to its class id:
# 0 for the negative / excluded class, 1 for the positive / included class.
LABEL_MAP = {
    **dict.fromkeys(('negative', 'not included', '0', 0, 'excluded'), 0),
    **dict.fromkeys(('positive', 'included', '1', 1), 1),
}
class SLR_DataSet(Dataset):
    """Dataset that tokenizes one row of ``data`` per item.

    Args:
        treat_text: optional callable applied to the raw text before
            tokenization.
        **args: reads ``tokenizer`` (must expose ``encode_plus``), ``data``
            (must support ``len``, ``reset_index`` and ``iloc`` — presumably
            a pandas DataFrame), ``max_seq_length`` (default 512), ``input``
            (text column name, default ``'x'``) and ``output`` (label column
            name, default ``'y'``).
    """

    def __init__(self, treat_text=None, **args):
        self.tokenizer = args.get('tokenizer')
        self.data = args.get('data')
        self.max_seq_length = args.get("max_seq_length", 512)
        self.INPUT_NAME = args.get("input", 'x')
        self.LABEL_NAME = args.get("output", 'y')
        self.treat_text = treat_text

    def _resolve_label(self, example):
        """Return the class id for ``example``, or -1 when it has no known label."""
        try:
            raw_label = example[self.LABEL_NAME]
        except (KeyError, IndexError):
            # Unlabelled data: the label column is simply absent.
            return -1

        # Only strings are normalised; integer labels are looked up as-is
        # (calling .lower() on them used to turn every one into -1).
        if isinstance(raw_label, str):
            raw_label = raw_label.strip().lower()

        try:
            return LABEL_MAP[raw_label]
        except (KeyError, TypeError):
            # Unknown spelling, NaN, or an unhashable value.
            return -1

    # Tokenizing and processing text
    def encode_text(self, example):
        """Tokenize one row.

        Returns:
            tuple ``(input_ids, attention_mask, token_type_ids, label)`` where
            the first three are flattened tokenizer outputs and ``label`` is a
            one-element int64 tensor (-1 when the label is missing/unknown).
        """
        comment_text = example[self.INPUT_NAME]

        if self.treat_text:
            comment_text = self.treat_text(comment_text)

        labels = self._resolve_label(example)

        # NOTE(review): MetaTask.encode_text pairs the text with
        # "It is a great text." while this class uses "It is great text" —
        # confirm which sentence the model was trained with.
        encoding = self.tokenizer.encode_plus(
            (comment_text, "It is great text"),
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return (
            encoding["input_ids"].flatten(),
            encoding["attention_mask"].flatten(),
            encoding["token_type_ids"].flatten(),
            torch.tensor([labels], dtype=torch.long),
        )

    def __len__(self):
        return len(self.data)

    # Returning data
    def __getitem__(self, index: int):
        # Positional row lookup; reset_index is kept so the row carries the
        # same fields as before.
        data_row = self.data.reset_index().iloc[index]
        return self.encode_text(data_row)
class Learner(nn.Module):
    """Meta-learner: adapts a copy of the model per task, then updates the model.

    For every task a deep copy of ``model`` is trained on the support set
    with its own Adam optimizer. When ``training`` is set, the difference
    ``meta_parameter - adapted_parameter`` is averaged over tasks and fed to
    the outer Adam optimizer as the gradient (this resembles a Reptile-style
    first-order update).

    Keyword arguments read from ``args``:
        inner_print: print the mean inner loss every this many inner epochs.
        inner_batch_size: batch size of the support-set loader.
        outer_update_lr: learning rate of the outer optimizer.
        inner_update_lr: learning rate of each per-task optimizer.
        inner_update_step: inner epochs when ``training`` is True.
        inner_update_step_eval: inner epochs when ``training`` is False.
        model: module called as ``model(input_ids, attention_mask,
            token_type_ids=..., labels=...)`` returning ``(loss, _, predictions)``.
        device: device the per-task copy is trained and evaluated on.
    """

    def __init__(self, **args):
        super(Learner, self).__init__()

        self.inner_print = args.get('inner_print')
        self.inner_batch_size = args.get('inner_batch_size')
        self.outer_update_lr = args.get('outer_update_lr')
        self.inner_update_lr = args.get('inner_update_lr')
        self.inner_update_step = args.get('inner_update_step')
        self.inner_update_step_eval = args.get('inner_update_step_eval')
        self.model = args.get('model')
        self.device = args.get('device')

        # Outer optimizer
        self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
        self.model.train()

    def forward(self, batch_tasks, training=True, valid_train=True):
        """Adapt to each task and optionally update the meta-model.

        Args:
            batch_tasks: sequence of ``(support, query, name)`` items, where
                support/query are datasets yielding
                ``(input_ids, attention_mask, token_type_ids, label_id)``.
            training: when True, use ``inner_update_step`` epochs and apply
                the outer update; otherwise use ``inner_update_step_eval``.
            valid_train: when True, print progress, evaluate each adapted
                copy on its query set and return the mean accuracy.

        Returns:
            ``np.mean`` of the per-task accuracies when ``valid_train`` is
            True, otherwise ``np.array(0)``.
        """
        task_accs = []
        task_f1 = []
        task_recall = []
        sum_gradients = []
        num_task = len(batch_tasks)
        num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval

        # Outer loop tasks
        for task_id, task in enumerate(batch_tasks):
            support, query, name = task[0], task[1], task[2]

            # Copying model
            fast_model = deepcopy(self.model)
            fast_model.to(self.device)

            # Inner trainer optimizer
            inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)

            # Drop a trailing batch of exactly one sample.
            # NOTE(review): presumably because BatchNorm1d cannot train on a
            # single-sample batch — confirm.
            drop_last = len(support) % self.inner_batch_size == 1
            support_dataloader = DataLoader(support, sampler=RandomSampler(support),
                                            batch_size=self.inner_batch_size,
                                            drop_last=drop_last)

            fast_model.train()

            if valid_train:
                print('----Task',task_id,":", name, '----')

            # Inner loop training epoch (support set)
            for epoch in range(num_inner_update_step):
                all_loss = []

                # Inner loop training batch (support set)
                for batch in support_dataloader:
                    batch = tuple(t.to(self.device) for t in batch)
                    input_ids, attention_mask, token_type_ids, label_id = batch

                    # Feed forward
                    loss, _, _ = fast_model(input_ids, attention_mask,
                                            token_type_ids=token_type_ids,
                                            labels=label_id)

                    # Computing gradients and updating the task copy
                    loss.backward()
                    inner_optimizer.step()
                    inner_optimizer.zero_grad()

                    all_loss.append(loss.item())

                    del batch, input_ids, attention_mask, token_type_ids, label_id
                    torch.cuda.empty_cache()

                # inner_print may be None/0; skip the report instead of
                # failing on the modulo.
                if valid_train and self.inner_print and (epoch + 1) % self.inner_print == 0:
                    print("Inner Loss: ", np.mean(all_loss))

            fast_model.to(torch.device('cpu'))

            # Accumulate (meta - adapted) parameter differences. Done without
            # autograd tracking: these tensors are used as plain gradients.
            if training:
                with torch.no_grad():
                    deltas = [
                        meta_params - fast_params.to(meta_params.device)
                        for meta_params, fast_params in zip(self.model.parameters(),
                                                            fast_model.parameters())
                    ]
                    if sum_gradients:
                        for total, delta in zip(sum_gradients, deltas):
                            total += delta
                    else:
                        sum_gradients = deltas

            # Inner test (query set)
            if valid_train:
                fast_model.to(self.device)
                fast_model.eval()

                with torch.no_grad():
                    # The whole query set is evaluated as a single batch.
                    query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
                    query_batch = next(iter(query_dataloader))
                    query_batch = tuple(t.to(self.device) for t in query_batch)
                    q_input_ids, q_attention_mask, q_token_type_ids, q_label_id = query_batch

                    # Feed forward
                    _, _, pre_label_id = fast_model(q_input_ids, q_attention_mask,
                                                    q_token_type_ids, labels=q_label_id)

                    # Predictions
                    pre_label_id = pre_label_id.detach().cpu().squeeze()

                    # Labels
                    q_label_id = q_label_id.detach().cpu()

                    # Calculating metrics
                    # NOTE(review): recent torchmetrics releases require a
                    # ``task`` argument for these functions — confirm the
                    # pinned version.
                    acc = fn.accuracy(pre_label_id, q_label_id).item()
                    recall = fn.recall(pre_label_id, q_label_id).item()
                    f1 = fn.f1_score(pre_label_id, q_label_id).item()

                    # Appending metrics
                    task_accs.append(acc)
                    task_f1.append(f1)
                    task_recall.append(recall)

            fast_model.to(torch.device('cpu'))
            del fast_model, inner_optimizer
            torch.cuda.empty_cache()

        print("\n")
        # Metrics exist only when query evaluation ran; averaging an empty
        # list would just print nan with a RuntimeWarning.
        if task_f1:
            print("f1:",np.mean(task_f1))
            print("recall:",np.mean(task_recall))

        # Updating outer training parameters (skipped when no task produced
        # a delta, e.g. an empty task list).
        if training and sum_gradients:
            for params, total in zip(self.model.parameters(), sum_gradients):
                # Mean of the accumulated differences over all tasks
                params.grad = total / float(num_task)

            self.outer_optimizer.step()
            self.outer_optimizer.zero_grad()

        del sum_gradients
        gc.collect()
        torch.cuda.empty_cache()

        if valid_train:
            return np.mean(task_accs)
        return np.array(0)
# Creating Meta Tasks
class MetaTask(Dataset):
    """Dataset of meta-learning tasks, each a (support, query, name) triple.

    Args:
        examples: table of samples supporting ``examples['domain']``,
            boolean-mask selection, ``groupby('label').sample`` and
            ``to_dict('records')`` (presumably a pandas DataFrame with
            ``domain``, ``label`` and ``text`` columns).
        num_task: number of tasks to draw when ``training`` is True.
        k_support: support samples per class in each task.
        k_query: query samples per class in each task.
        tokenizer: object exposing ``encode_plus``.
        training: when True tasks are drawn at random from the domains;
            when False exactly one task is built per domain.
        max_seq_length: padded/truncated token length.
        treat_text: optional callable applied to the text before tokenizing.
    """

    def __init__(self, examples, num_task, k_support, k_query,
                 tokenizer, training=True, max_seq_length=512,
                 treat_text=None, **args):
        self.examples = examples
        self.num_task = num_task
        self.k_support = k_support
        self.k_query = k_query
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.treat_text = treat_text

        # Randomly generating tasks
        self.create_batch(self.num_task, training)

    # Creating batch
    def create_batch(self, num_task, training):
        """(Re)build the support/query record lists for every task."""
        self.supports = []  # support set
        self.queries = []  # query set
        self.task_names = []  # Name of task
        self.supports_indexs = []  # index of supports
        self.queries_indexs = []  # index of queries
        self.num_task = num_task

        # Available tasks
        domains = self.examples['domain'].unique()

        # If not training, create all tasks (one per domain)
        if not training:
            self.task_names = domains
            num_task = len(self.task_names)
            self.num_task = num_task

        task_size = 2 * self.k_support + 2 * self.k_query

        for b in range(num_task):  # For each task,
            total_per_class = self.k_support + self.k_query

            # Select a task at random
            if training:
                domain = random.choice(domains)
                self.task_names.append(domain)
            else:
                domain = self.task_names[b]

            # Task data
            domainExamples = self.examples[self.examples['domain'] == domain]

            # Never request more samples per class than the rarest class has.
            min_per_class = min(domainExamples['label'].value_counts())
            if total_per_class > min_per_class:
                total_per_class = min_per_class

            # Sample the same number of examples from each label (class)
            selected_examples = domainExamples.groupby("label").sample(total_per_class, replace=False)

            # Split data into support (training) and query (testing) sets
            s, q = train_test_split(selected_examples,
                                    stratify=selected_examples["label"],
                                    test_size=2 * self.k_query / task_size,
                                    shuffle=True)

            # Permutating data
            s = s.sample(frac=1)
            q = q.sample(frac=1)

            # Appending indexes (kept only for evaluation tasks)
            if not training:
                self.supports_indexs.append(s.index)
                self.queries_indexs.append(q.index)

            # Creating list of support (training) and query (testing) tasks
            self.supports.append(s.to_dict('records'))
            self.queries.append(q.to_dict('records'))

    # Creating task tensors
    def create_feature_set(self, examples):
        """Tokenize ``examples`` into a TensorDataset of ids, masks, types, labels."""
        count = len(examples)
        all_input_ids = torch.empty(count, self.max_seq_length, dtype=torch.long)
        all_attention_mask = torch.empty(count, self.max_seq_length, dtype=torch.long)
        all_token_type_ids = torch.empty(count, self.max_seq_length, dtype=torch.long)
        all_label_ids = torch.empty(count, dtype=torch.long)

        for _id, e in enumerate(examples):
            all_input_ids[_id], all_attention_mask[_id], all_token_type_ids[_id], all_label_ids[_id] = self.encode_text(e)

        return TensorDataset(
            all_input_ids,
            all_attention_mask,
            all_token_type_ids,
            all_label_ids
        )

    # Data encoding
    def encode_text(self, example):
        """Tokenize one record (a mapping with ``text`` and ``label`` keys).

        Raises:
            KeyError: if the label is not a key of ``LABEL_MAP``.
        """
        comment_text = example["text"]

        if self.treat_text:
            comment_text = self.treat_text(comment_text)

        # String labels are normalised the way LABEL_MAP spells its keys
        # (lower case), so e.g. 'Included' is accepted; other types are
        # looked up unchanged.
        label = example["label"]
        if isinstance(label, str):
            label = label.strip().lower()
        labels = LABEL_MAP[label]

        encoding = self.tokenizer.encode_plus(
            (comment_text, "It is a great text."),
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return (
            encoding["input_ids"].flatten(),
            encoding["attention_mask"].flatten(),
            encoding["token_type_ids"].flatten(),
            torch.tensor([labels], dtype=torch.long),
        )

    # Returns data upon calling
    def __getitem__(self, index):
        support_set = self.create_feature_set(self.supports[index])
        query_set = self.create_feature_set(self.queries[index])
        name = self.task_names[index]
        return support_set, query_set, name

    def __len__(self):
        return self.num_task
class treat_text:
    """Callable text cleaner.

    Calling an instance normalises the text to NFKD, lower-cases it, applies
    the literal replacements in ``patterns`` through ``multiple_replace`` and
    then runs four regex clean-up passes (remove bracketed spans and stray
    symbols, collapse spaces, collapse doubled commas, spell out percent).
    """

    # Compiled once; raw strings keep the backslashes from being read as
    # (invalid) string escapes.
    # NOTE(review): ``.+`` is greedy, so everything between the first "("
    # and the last ")" on a line is removed — confirm this is intended.
    _REMOVE_RE = re.compile(r'(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )')
    _SPACES_RE = re.compile(r'( +)')
    _COMMAS_RE = re.compile(r'(, ,)|(,,)')
    _PERCENT_RE = re.compile(r'(%)|(per cent)')

    def __init__(self, patterns):
        # Mapping of literal substrings to their replacements.
        self.patterns = patterns

    def __call__(self, text):
        """Return the cleaned, lower-cased version of ``text``."""
        text = unicodedata.normalize("NFKD", str(text))
        text = multiple_replace(self.patterns, text.lower())
        text = self._REMOVE_RE.sub('', text)
        text = self._SPACES_RE.sub(' ', text)
        text = self._COMMAS_RE.sub(',', text)
        text = self._PERCENT_RE.sub(' percent', text)
        return text
# Regex multiple replace function
def multiple_replace(dict, text):
    """Replace every occurrence of each key of ``dict`` in ``text`` in one pass.

    Keys are matched literally (they are regex-escaped). When several keys
    match at the same position, the one that comes first in the mapping wins.

    Args:
        dict: mapping of substring -> replacement string.
        text: the string to process.

    Returns:
        The text with all replacements applied; unchanged when ``dict`` is
        empty.
    """
    # An empty mapping would compile to "()", which matches the empty string
    # at every position and then fails the lookup with KeyError('').
    if not dict:
        return text

    # Building regex from dict keys
    regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))

    # Substitution
    return regex.sub(lambda mo: dict[mo.group(0)], text)