import gc
import random
import re
import unicodedata
from copy import deepcopy

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, RandomSampler, TensorDataset
from torchmetrics import functional as fn

class Encoder(nn.Module):
    """Wraps a BERT-style backbone and keeps only the transformer layers listed in `layers`."""

    def __init__(self, layers, freeze_bert, model):
        super().__init__()

        # Dummy parameter used only to query the module's current device.
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Work on a copy so the backbone passed in is left untouched.
        self.model = deepcopy(model)

        # Optionally freeze the backbone so its weights are not updated.
        if freeze_bert:
            for param in self.model.parameters():
                param.requires_grad = False

        # Keep only the requested encoder layers (e.g. a subset of the 12 BERT blocks).
        old_model_encoder = self.model.encoder.layer
        new_model_encoder = nn.ModuleList()

        for i in layers:
            new_model_encoder.append(old_model_encoder[i])

        self.model.encoder.layer = new_model_encoder

    def forward(self, **x):
        return self.model(**x)['pooler_output']

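# A minimal usage sketch (illustrative, not executed): it assumes a Hugging Face
# `transformers` BERT backbone; the checkpoint name below is only an example.
#
#   from transformers import AutoModel, AutoTokenizer
#   backbone = AutoModel.from_pretrained("bert-base-uncased")
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   encoder = Encoder(layers=range(12), freeze_bert=False, model=backbone)
#   enc = tok("some abstract", return_tensors="pt", return_token_type_ids=True)
#   pooled = encoder(**enc)   # pooled [CLS] representation, shape (1, hidden_size)
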
class SLR_Classifier(nn.Module):
    """Binary relevance classifier: BERT encoder -> feature map -> single-logit head."""

    def __init__(self, **data):
        super().__init__()

        # Dummy parameter used only to query the module's current device.
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Weighted BCE loss; `pos_weight` compensates for the scarcity of included studies.
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean',
                                            pos_weight=torch.FloatTensor([data.get("pos_weight", 2.5)]))

        self.Encoder = Encoder(layers=data.get("bert_layers", range(12)),
                               freeze_bert=data.get("freeze_bert", False),
                               model=data.get("model"),
                               )

        # Projects the pooled BERT output to a 200-dimensional feature vector.
        self.feature_map = nn.Sequential(
            nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
            nn.Linear(self.Encoder.model.config.hidden_size, 200),
            nn.Dropout(data.get("drop", 0.5)),
        )

        # Maps the feature vector to a single logit.
        self.classifier = nn.Sequential(
            nn.Tanh(),
            nn.Linear(200, 1)
        )

        # Start the projection layer close to zero so early updates stay small.
        nn.init.normal_(self.feature_map[1].weight, mean=0, std=0.00001)
        nn.init.zeros_(self.feature_map[1].bias)

    def forward(self, input_ids, attention_mask, token_type_ids, labels):
        pooled = self.Encoder(**{"input_ids": input_ids,
                                 "attention_mask": attention_mask,
                                 "token_type_ids": token_type_ids})
        feature = self.feature_map(pooled)
        logit = self.classifier(feature)

        predict = torch.sigmoid(logit)

        loss = self.loss_fn(logit.to(torch.float), labels.to(torch.float).unsqueeze(1))

        return [loss, [feature, logit], predict]

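# A minimal forward-pass sketch (illustrative, not executed), assuming the same
# Hugging Face backbone and tokenizer; names such as `backbone` and `tok` are hypothetical.
#
#   from transformers import AutoModel, AutoTokenizer
#   backbone = AutoModel.from_pretrained("bert-base-uncased")
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   clf = SLR_Classifier(model=backbone, bert_layers=range(12), pos_weight=2.5)
#   enc = tok(["abstract one", "abstract two"], padding="max_length", truncation=True,
#             max_length=32, return_tensors="pt", return_token_type_ids=True)
#   labels = torch.tensor([1, 0])
#   loss, (feature, logit), prob = clf(enc["input_ids"], enc["attention_mask"],
#                                      enc["token_type_ids"], labels)
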
# Text-cleaning substitutions applied to abstracts before tokenization.
# Keys are matched literally by `multiple_replace` (they are re.escape'd there);
# the regex-style entries near the end are also applied as regexes in `treat_text`.
patterns = {
    'CONCLUSIONS AND IMPLICATIONS': '', 'BACKGROUND AND PURPOSE': '',
    'EXPERIMENTAL APPROACH': '', 'KEY RESULTS AEA': '',
    '©': '', '®': '', 'μ': '', '(C)': '',
    'OBJECTIVE:': '', 'MATERIALS AND METHODS:': '', 'SIGNIFICANCE:': '',
    'BACKGROUND:': '', 'RESULTS:': '', 'METHODS:': '', 'CONCLUSIONS:': '',
    'AIM:': '', 'STUDY DESIGN:': '', 'CLINICAL RELEVANCE:': '', 'CONCLUSION:': '',
    'HYPOTHESIS:': '', 'Questions/Purposes:': '', 'Introduction:': '',
    'PURPOSE:': '', 'PATIENTS AND METHODS:': '', 'FINDINGS:': '',
    'INTERPRETATIONS:': '', 'FUNDING:': '', 'PROGRESS:': '', 'CONTEXT:': '',
    'MEASURES:': '', 'DESIGN:': '', 'BACKGROUND AND OBJECTIVES:': '',
    '<p>': '', '</p>': '', '<<ETX>>': '', '+/-': '',
    r'\(.+\)': '', r'\[.+\]': '', r' \d ': '',
    '<': '', '>': '', '- ': '',
    ' +': ' ', ', ,': ',', ',,': ',',
    '%': ' percent', 'per cent': ' percent',
}

# The replacement is applied to lower-cased text, so lower-case the keys as well.
patterns = {x.lower(): y for x, y in patterns.items()}

LABEL_MAP = {'negative': 0, 'not included': 0, '0': 0, 0: 0, 'excluded': 0,
             'positive': 1, 'included': 1, '1': 1, 1: 1}

class SLR_DataSet(Dataset):
    """Torch Dataset that tokenizes one (text, label) row of a pandas DataFrame per item."""

    def __init__(self, treat_text=None, **args):
        self.tokenizer = args.get('tokenizer')
        self.data = args.get('data')
        self.max_seq_length = args.get("max_seq_length", 512)
        self.INPUT_NAME = args.get("input", 'x')
        self.LABEL_NAME = args.get("output", 'y')
        self.treat_text = treat_text

    def encode_text(self, example):
        comment_text = example[self.INPUT_NAME]
        if self.treat_text:
            comment_text = self.treat_text(comment_text)

        # Map the raw label to {0, 1}; unknown or missing labels become -1.
        label_raw = example[self.LABEL_NAME]
        if isinstance(label_raw, str):
            label_raw = label_raw.lower()
        labels = LABEL_MAP.get(label_raw, -1)

        # The abstract is paired with a fixed second segment, as in the original setup.
        encoding = self.tokenizer.encode_plus(
            (comment_text, "It is great text"),
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return tuple((
            encoding["input_ids"].flatten(),
            encoding["attention_mask"].flatten(),
            encoding["token_type_ids"].flatten(),
            torch.tensor([labels], dtype=torch.long)
        ))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        data_row = self.data.reset_index().iloc[index]
        return self.encode_text(data_row)

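# A minimal sketch of wrapping a pandas DataFrame (illustrative, not executed);
# the "text"/"label" column names and the tokenizer checkpoint are assumptions.
#
#   import pandas as pd
#   from transformers import AutoTokenizer
#   df = pd.DataFrame({"text": ["abstract one", "abstract two"],
#                      "label": ["included", "excluded"]})
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   ds = SLR_DataSet(tokenizer=tok, data=df, input="text", output="label", max_seq_length=128)
#   loader = DataLoader(ds, batch_size=2, shuffle=True)
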
class Learner(nn.Module):
    """Reptile-style meta-learner: fine-tunes a copy of the model on each task's
    support set and moves the meta-weights toward the task-adapted weights."""

    def __init__(self, **args):
        super().__init__()

        self.inner_print = args.get('inner_print')
        self.inner_batch_size = args.get('inner_batch_size')
        self.outer_update_lr = args.get('outer_update_lr')
        self.inner_update_lr = args.get('inner_update_lr')
        self.inner_update_step = args.get('inner_update_step')
        self.inner_update_step_eval = args.get('inner_update_step_eval')
        self.model = args.get('model')
        self.device = args.get('device')

        self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
        self.model.train()

    def forward(self, batch_tasks, training=True, valid_train=True):
        """
        batch_tasks = [(support TensorDataset, query TensorDataset, task name), ...]

        Each support/query TensorDataset holds
        (all_input_ids, all_attention_mask, all_segment_ids, all_label_ids).
        """
        task_accs = []
        task_f1 = []
        task_recall = []
        sum_gradients = []
        num_task = len(batch_tasks)
        num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval

        for task_id, task in enumerate(batch_tasks):
            support = task[0]
            query = task[1]
            name = task[2]

            # Adapt a fresh copy of the meta-model on this task's support set.
            fast_model = deepcopy(self.model)
            fast_model.to(self.device)

            inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)

            # Drop the last batch only when it would contain a single example
            # (BatchNorm cannot run on a batch of size 1 in training mode).
            drop_last = (len(support) % self.inner_batch_size == 1)
            support_dataloader = DataLoader(support, sampler=RandomSampler(support),
                                            batch_size=self.inner_batch_size,
                                            drop_last=drop_last)

            fast_model.train()

            if valid_train:
                print('----Task', task_id, ":", name, '----')

            # Inner loop: a few gradient steps on the support set.
            for i in range(0, num_inner_update_step):
                all_loss = []

                for inner_step, batch in enumerate(support_dataloader):
                    batch = tuple(t.to(self.device) for t in batch)
                    input_ids, attention_mask, token_type_ids, label_id = batch

                    loss, _, _ = fast_model(input_ids, attention_mask,
                                            token_type_ids=token_type_ids, labels=label_id)

                    loss.backward()
                    inner_optimizer.step()
                    inner_optimizer.zero_grad()

                    all_loss.append(loss.item())

                    del batch, input_ids, attention_mask, token_type_ids, label_id
                    torch.cuda.empty_cache()

                if valid_train and (i + 1) % self.inner_print == 0:
                    print("Inner Loss: ", np.mean(all_loss))

            fast_model.to(torch.device('cpu'))

            if training:
                # Reptile-style pseudo-gradient: the difference between the meta-weights
                # and the task-adapted weights, accumulated over tasks.
                meta_weights = list(self.model.parameters())
                fast_weights = list(fast_model.parameters())

                for i, (meta_params, fast_params) in enumerate(zip(meta_weights, fast_weights)):
                    gradient = meta_params - fast_params
                    if task_id == 0:
                        sum_gradients.append(gradient)
                    else:
                        sum_gradients[i] += gradient

            if valid_train:
                # Evaluate the adapted model on this task's query set.
                fast_model.to(self.device)
                fast_model.eval()

                with torch.no_grad():
                    query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
                    query_batch = next(iter(query_dataloader))
                    query_batch = tuple(t.to(self.device) for t in query_batch)
                    q_input_ids, q_attention_mask, q_token_type_ids, q_label_id = query_batch

                    _, _, pre_label_id = fast_model(q_input_ids, q_attention_mask,
                                                    q_token_type_ids, labels=q_label_id)

                    pre_label_id = pre_label_id.detach().cpu().squeeze()
                    q_label_id = q_label_id.detach().cpu()

                    # Note: newer torchmetrics releases may require task="binary" here.
                    acc = fn.accuracy(pre_label_id, q_label_id).item()
                    recall = fn.recall(pre_label_id, q_label_id).item()
                    f1 = fn.f1_score(pre_label_id, q_label_id).item()

                    task_accs.append(acc)
                    task_f1.append(f1)
                    task_recall.append(recall)

                fast_model.to(torch.device('cpu'))

                print("\n")
                print("f1:", np.mean(task_f1))
                print("recall:", np.mean(task_recall))

            del fast_model, inner_optimizer
            torch.cuda.empty_cache()

        if training:
            # Average the accumulated pseudo-gradients over tasks and apply one outer step.
            for i in range(0, len(sum_gradients)):
                sum_gradients[i] = sum_gradients[i] / float(num_task)

            for i, params in enumerate(self.model.parameters()):
                params.grad = sum_gradients[i]

            self.outer_optimizer.step()
            self.outer_optimizer.zero_grad()

            del sum_gradients
            gc.collect()
            torch.cuda.empty_cache()

        if valid_train:
            return np.mean(task_accs)
        else:
            return np.array(0)

class MetaTask(Dataset):
    def __init__(self, examples, num_task, k_support, k_query,
                 tokenizer, training=True, max_seq_length=512,
                 treat_text=None, **args):
        """
        :param examples: DataFrame of labelled examples with 'domain', 'text' and 'label' columns
        :param num_task: number of training tasks
        :param k_support: number of support samples per class in each task
        :param k_query: number of query samples per class in each task
        """
        self.examples = examples

        self.num_task = num_task
        self.k_support = k_support
        self.k_query = k_query
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.treat_text = treat_text

        self.create_batch(self.num_task, training)

    def create_batch(self, num_task, training):
        self.supports = []
        self.queries = []
        self.task_names = []
        self.supports_indexs = []
        self.queries_indexs = []
        self.num_task = num_task

        domains = self.examples['domain'].unique()

        # At evaluation time, build exactly one task per domain.
        if not training:
            self.task_names = domains
            num_task = len(self.task_names)
            self.num_task = num_task

        for b in range(num_task):
            total_per_class = self.k_support + self.k_query
            task_size = 2 * self.k_support + 2 * self.k_query

            # During training, each task is drawn from a randomly chosen domain.
            if training:
                domain = random.choice(domains)
                self.task_names.append(domain)
            else:
                domain = self.task_names[b]

            domainExamples = self.examples[self.examples['domain'] == domain]

            # Cap the per-class sample size at the size of the rarest class.
            min_per_class = min(domainExamples['label'].value_counts())
            if total_per_class > min_per_class:
                total_per_class = min_per_class

            selected_examples = domainExamples.groupby("label").sample(total_per_class, replace=False)

            # Stratified split of the selected examples into support and query sets.
            s, q = train_test_split(selected_examples,
                                    stratify=selected_examples["label"],
                                    test_size=2 * self.k_query / task_size,
                                    shuffle=True)

            s = s.sample(frac=1)
            q = q.sample(frac=1)

            # Keep the original indices at evaluation time so predictions can be traced back.
            if not training:
                self.supports_indexs.append(s.index)
                self.queries_indexs.append(q.index)

            self.supports.append(s.to_dict('records'))
            self.queries.append(q.to_dict('records'))

    def create_feature_set(self, examples):
        all_input_ids = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
        all_attention_mask = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
        all_token_type_ids = torch.empty(len(examples), self.max_seq_length, dtype=torch.long)
        all_label_ids = torch.empty(len(examples), dtype=torch.long)

        for _id, e in enumerate(examples):
            all_input_ids[_id], all_attention_mask[_id], all_token_type_ids[_id], all_label_ids[_id] = self.encode_text(e)

        return TensorDataset(
            all_input_ids,
            all_attention_mask,
            all_token_type_ids,
            all_label_ids
        )

    def encode_text(self, example):
        comment_text = example["text"]

        if self.treat_text:
            comment_text = self.treat_text(comment_text)

        labels = LABEL_MAP[example["label"]]

        # The abstract is paired with a fixed second segment, as in SLR_DataSet.
        encoding = self.tokenizer.encode_plus(
            (comment_text, "It is a great text."),
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return tuple((
            encoding["input_ids"].flatten(),
            encoding["attention_mask"].flatten(),
            encoding["token_type_ids"].flatten(),
            torch.tensor([labels], dtype=torch.long)
        ))

    def __getitem__(self, index):
        support_set = self.create_feature_set(self.supports[index])
        query_set = self.create_feature_set(self.queries[index])
        name = self.task_names[index]
        return support_set, query_set, name

    def __len__(self):
        return self.num_task

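# A minimal meta-training loop sketch (illustrative, not executed). It assumes a pandas
# DataFrame `df` with 'domain', 'text' and 'label' columns and a Hugging Face backbone;
# all hyper-parameter values below are placeholders, not recommended settings.
#
#   from transformers import AutoModel, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   backbone = AutoModel.from_pretrained("bert-base-uncased")
#   model = SLR_Classifier(model=backbone, bert_layers=range(12))
#   learner = Learner(model=model, device=torch.device("cpu"),
#                     inner_print=1, inner_batch_size=4,
#                     outer_update_lr=1e-4, inner_update_lr=1e-4,
#                     inner_update_step=2, inner_update_step_eval=2)
#   train_tasks = MetaTask(df, num_task=8, k_support=4, k_query=4,
#                          tokenizer=tok, training=True, max_seq_length=128)
#   task_loader = DataLoader(train_tasks, batch_size=2, collate_fn=lambda batch: batch)
#   for meta_batch in task_loader:
#       acc = learner(meta_batch, training=True)   # one outer (meta) update per batch of tasks
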
class treat_text:
    """Callable text cleaner: Unicode-normalizes, applies the `patterns` replacements,
    then strips bracketed spans, stray digits and repeated punctuation."""

    def __init__(self, patterns):
        self.patterns = patterns

    def __call__(self, text):
        text = unicodedata.normalize("NFKD", str(text))
        text = multiple_replace(self.patterns, text.lower())
        text = re.sub(r'(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )', '', text)
        text = re.sub(r'( +)', ' ', text)
        text = re.sub(r'(, ,)|(,,)', ',', text)
        text = re.sub(r'(%)|(per cent)', ' percent', text)
        return text

def multiple_replace(replacements, text):
    """Replace every literal key of `replacements` found in `text` with its value in a single pass."""

    # Build one alternation regex out of the (escaped) keys.
    regex = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))

    # For each match, look up the matched substring in the replacement table.
    return regex.sub(lambda mo: replacements[mo.group(0)], text)
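

# A quick illustration of the cleaning pipeline (illustrative, not executed);
# the sample sentence is made up.
#
#   cleaner = treat_text(patterns)
#   cleaner("BACKGROUND: Accuracy improved by 12% (p < 0.05).")
#   # -> roughly " accuracy improved by 12 percent ." after header removal,
#   #    bracket stripping and percent normalization.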