import os
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd
import torch
from torch.nn.functional import softmax
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

from fuson_plm.utils.logging import log_update


class MaskingRateScheduler(ABC):
    """Base class for schedulers that increase the masking rate from min_masking_rate to max_masking_rate over total_steps."""
    def __init__(self, total_steps, min_masking_rate, max_masking_rate, last_step=-1):
        self.total_steps = total_steps
        self.min_masking_rate = min_masking_rate
        self.max_masking_rate = max_masking_rate
        self.current_step = last_step

    def step(self):
        """Advance the scheduler by one step (typically one batch)."""
        self.current_step += 1

    def reset(self):
        """Reset the scheduler to its initial state."""
        self.current_step = -1

    def get_masking_rate(self):
        """Return the masking rate for the current progress (current_step / total_steps)."""
        progress = self.current_step / self.total_steps
        return self.compute_masking_rate(progress)

    @abstractmethod
    def compute_masking_rate(self, progress):
        """To be implemented by subclasses for specific increase functions."""
        raise NotImplementedError("Subclasses must implement this method.")


class CosineIncreaseMaskingRateScheduler(MaskingRateScheduler):
    def compute_masking_rate(self, progress):
        # Rises smoothly from 0 at progress=0 to 1 at progress=1 along a half-cosine.
        cosine_increase = 0.5 * (1 - np.cos(np.pi * progress))
        return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * cosine_increase


class LogLinearIncreaseMaskingRateScheduler(MaskingRateScheduler):
    def compute_masking_rate(self, progress):
        # Clamp progress: current_step starts at -1, so progress can be negative before the first step().
        progress = max(progress, 1e-10)
        # log(1 + progress) / log(2) increases from 0 at progress=0 to 1 at progress=1.
        log_linear_increase = np.log1p(progress) / np.log1p(1)
        return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * log_linear_increase


class StepwiseIncreaseMaskingRateScheduler(MaskingRateScheduler):
    def __init__(self, total_batches, min_masking_rate, max_masking_rate, num_steps):
        super().__init__(total_steps=total_batches, min_masking_rate=min_masking_rate, max_masking_rate=max_masking_rate)
        self.num_steps = num_steps
        # Number of batches spent at each plateau, and the rate jump between consecutive plateaus.
        self.batch_interval = total_batches // num_steps
        self.rate_increment = (max_masking_rate - min_masking_rate) / (num_steps - 1)

    def compute_masking_rate(self, progress):
        # Which plateau the current batch falls in, capped at the final plateau.
        current_step = int(self.current_step / self.batch_interval)
        current_step = min(current_step, self.num_steps - 1)
        masking_rate = self.min_masking_rate + current_step * self.rate_increment
        return masking_rate
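

# Illustrative sketch (hypothetical helper, not part of the training pipeline): shows how the
# stepwise scheduler plateaus. With total_batches=100, num_steps=4, and rates 0.15 -> 0.40,
# the rate holds at 0.15 while current_step is 0-24, then jumps to ~0.233, ~0.317, and 0.40
# for the remaining quarters of the run. The numbers here are example values, not the
# project's actual configuration.
def _demo_stepwise_masking_rates(total_batches=100, num_steps=4):
    scheduler = StepwiseIncreaseMaskingRateScheduler(
        total_batches=total_batches,
        min_masking_rate=0.15,
        max_masking_rate=0.40,
        num_steps=num_steps,
    )
    rates = []
    for _ in range(total_batches):
        scheduler.step()
        rates.append(round(scheduler.get_masking_rate(), 3))
    return rates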


def get_mask_rate_scheduler(scheduler_type="cosine", min_masking_rate=0.15, max_masking_rate=0.40, total_batches=100, total_steps=20):
    """
    Initialize the requested masking-rate scheduler and return it.

    total_steps is only used by the stepwise scheduler, where it sets the number of plateaus;
    the cosine and log-linear schedulers anneal continuously over total_batches.
    """
    if scheduler_type=="cosine":
        return CosineIncreaseMaskingRateScheduler(total_steps=total_batches,
                                                  min_masking_rate=min_masking_rate,
                                                  max_masking_rate=max_masking_rate)
    elif scheduler_type=="loglinear":
        return LogLinearIncreaseMaskingRateScheduler(total_steps=total_batches,
                                                     min_masking_rate=min_masking_rate,
                                                     max_masking_rate=max_masking_rate)
    elif scheduler_type=="stepwise":
        return StepwiseIncreaseMaskingRateScheduler(total_batches=total_batches,
                                                    num_steps=total_steps,
                                                    min_masking_rate=min_masking_rate,
                                                    max_masking_rate=max_masking_rate)
    else:
        raise ValueError("Must specify a valid scheduler_type: cosine, loglinear, stepwise")
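

# Illustrative sketch (hypothetical helper): comparing the three scheduler types halfway
# through a run. The rate range, batch count, and total_steps below are placeholders, not
# the project's settings. Qualitatively, the cosine schedule sits near the middle of the
# rate range at this point, the log-linear schedule is already above it, and the stepwise
# schedule depends on its plateau width.
def _demo_compare_schedulers(total_batches=100):
    rates_at_midpoint = {}
    for scheduler_type in ("cosine", "loglinear", "stepwise"):
        scheduler = get_mask_rate_scheduler(scheduler_type=scheduler_type,
                                            min_masking_rate=0.15,
                                            max_masking_rate=0.40,
                                            total_batches=total_batches,
                                            total_steps=20)
        for _ in range(total_batches // 2):
            scheduler.step()
        rates_at_midpoint[scheduler_type] = scheduler.get_masking_rate()
    return rates_at_midpoint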


class ProteinDataset(Dataset):
    def __init__(self, data_path, tokenizer, probability_type, max_length=512):
        self.dataframe = pd.read_csv(data_path)
        self.tokenizer = tokenizer
        self.probability_type = probability_type
        self.max_length = max_length

        self.set_probabilities()

    def __len__(self):
        return len(self.dataframe)

    def set_probabilities(self):
        # "snp" uses the precomputed per-residue probabilities from the CSV;
        # "uniform" assigns every residue an equal (unnormalized) weight of 1.
        if self.probability_type == "snp":
            self.dataframe = self.dataframe.rename(columns={'snp_probabilities': 'probabilities'})
        if self.probability_type == "uniform":
            self.dataframe['probabilities'] = self.dataframe['sequence'].apply(len).apply(lambda x: ('1,' * x)[0:-1])

        # Probabilities stored as comma-separated strings are parsed into float arrays.
        if isinstance(self.dataframe['probabilities'].iloc[0], str):
            self.dataframe['probabilities'] = self.dataframe['probabilities'].apply(
                lambda x: np.array([float(i) for i in x.split(',')])
            )
    def get_padded_probabilities(self, idx):
        '''
        Pad the per-residue probabilities to max_length if they are too short; truncate them if they are too long.
        Positions that must never be masked (CLS, EOS, padding) get a large negative value so that a softmax over the vector assigns them ~0 probability.
        '''
        no_mask_value = int(-1e9)

        # Prepend the no-mask value for the CLS token that the tokenizer adds at position 0.
        prob = np.concatenate((
            np.array([no_mask_value]),
            self.dataframe.iloc[idx]['probabilities']
        )
        )

        # Short sequences: right-pad with the no-mask value (covers the EOS and padding positions).
        if len(prob) < self.max_length:
            return np.pad(
                prob,
                (0, self.max_length - len(prob)),
                'constant', constant_values=(0, no_mask_value))

        # Long sequences: truncate and reserve the final position for EOS.
        prob = prob[0:self.max_length - 1]
        prob = np.concatenate((
            prob,
            np.array([no_mask_value]),
        )
        )
        return prob

    def __getitem__(self, idx):
        sequence = self.dataframe.iloc[idx]['sequence']
        probability = self.get_padded_probabilities(idx)
        inputs = self.tokenizer(sequence, return_tensors="pt", padding='max_length', truncation=True, max_length=self.max_length)
        # Squeeze out the batch dimension added by return_tensors="pt"; the DataLoader re-batches.
        inputs = {key: tensor.squeeze(0) for key, tensor in inputs.items()}
        return inputs, probability


def get_dataloader(data_path, tokenizer, probability_type='snp', max_length=512, batch_size=8, shuffle=True):
    """
    Creates a DataLoader for the dataset.

    Args:
        data_path (str): Path to the CSV file (train, val, or test).
        tokenizer (Tokenizer): Tokenizer object for data tokenization.
        probability_type (str): 'snp' to use precomputed per-residue probabilities, 'uniform' for equal weights.
        max_length (int): Maximum tokenized sequence length.
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle the data.

    Returns:
        DataLoader: DataLoader object.
    """
    dataset = ProteinDataset(data_path, tokenizer, probability_type, max_length=max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
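

# Illustrative sketch (hypothetical helper): building a training DataLoader. The checkpoint
# name and CSV path below are placeholders rather than values fixed by this module, and the
# ESM-2 tokenizer is only an assumption (it matches the CLS=0 / EOS=2 token ids used by
# batch_sample_mask_tokens_with_probabilities below).
def _demo_build_dataloader(data_path="train.csv"):
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    train_loader = get_dataloader(data_path,
                                  tokenizer,
                                  probability_type='snp',
                                  max_length=512,
                                  batch_size=8,
                                  shuffle=True)
    # Each batch is (inputs, probabilities): inputs['input_ids'] and inputs['attention_mask']
    # have shape (batch_size, max_length), and probabilities has shape (batch_size, max_length).
    return train_loader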


def check_dataloaders(train_loader, val_loader, test_loader, max_length=512, checkpoint_dir=''):
    log_update(f'\nBuilt train, validation, and test dataloaders')
    log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Validation DataLoader: {len(val_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Test DataLoader: {len(test_loader.dataset)}")
    dataloader_overlaps = check_dataloader_overlap(train_loader, val_loader, test_loader)
    if len(dataloader_overlaps)==0: log_update("\tDataloaders are clean (no overlaps)")
    else: log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")

    # Write per-batch sequence-length ranges so batch diversity can be inspected later.
    os.makedirs(f'{checkpoint_dir}/batch_diversity', exist_ok=True)

    max_length_violators = []
    for name, dataloader in {'train': train_loader, 'val': val_loader, 'test': test_loader}.items():
        max_length_followed, length_ranges = check_max_length_and_length_diversity(dataloader, max_length)
        if not max_length_followed:
            max_length_violators.append(name)

        with open(f'{checkpoint_dir}/batch_diversity/{name}_batch_length_ranges.txt', 'w') as f:
            for tup in length_ranges:
                f.write(f'{tup[0]}\t{tup[1]}\n')

    if len(max_length_violators)==0: log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
    else: log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")


def check_dataloader_overlap(train_loader, val_loader, test_loader):
    """
    Check the data that's about to go into the model. Make sure there is no overlap between the train, validation, and test sequences.

    Returns:
        list: strings describing any overlaps found (e.g. "Train-Val Overlap=3"); empty if the splits are disjoint.
    """
    train_protein_seqs = set(train_loader.dataset.dataframe['sequence'].unique())
    val_protein_seqs = set(val_loader.dataset.dataframe['sequence'].unique())
    test_protein_seqs = set(test_loader.dataset.dataframe['sequence'].unique())

    tr_va = len(train_protein_seqs.intersection(val_protein_seqs))
    tr_te = len(train_protein_seqs.intersection(test_protein_seqs))
    va_te = len(val_protein_seqs.intersection(test_protein_seqs))

    overlaps = []
    if tr_va==tr_te==va_te==0:
        return overlaps
    else:
        if tr_va > 0: overlaps.append(f"Train-Val Overlap={tr_va}")
        if tr_te > 0: overlaps.append(f"Train-Test Overlap={tr_te}")
        if va_te > 0: overlaps.append(f"Val-Test Overlap={va_te}")
        return overlaps


def check_max_length_and_length_diversity(dataloader, max_length):
    """
    Check if all sequences in the DataLoader conform to the specified max_length,
    and return the sequence length ranges within each batch.

    Args:
        dataloader (DataLoader): The DataLoader object to check.
        max_length (int): The maximum allowed sequence length.

    Returns:
        bool: True if all sequences are within the max_length, False otherwise.
        list: A list of tuples representing the min and max sequence lengths in each batch.
    """
    length_ranges = []
    all_within_max_length = True

    for batch_idx, (inputs, _) in enumerate(dataloader):
        input_ids = inputs['input_ids']

        # Count non-padding tokens per sequence to get the true tokenized lengths.
        actual_lengths = (input_ids != dataloader.dataset.tokenizer.pad_token_id).sum(dim=1)
        min_length = actual_lengths.min().item()
        max_length_in_batch = actual_lengths.max().item()

        if max_length_in_batch > max_length:
            all_within_max_length = False

        length_ranges.append((min_length, max_length_in_batch))

    return all_within_max_length, length_ranges


def check_max_length_in_dataloader(dataloader, max_length):
    """
    Check if all sequences in the DataLoader conform to the specified max_length.

    Args:
        dataloader (DataLoader): The DataLoader object to check.
        max_length (int): The maximum allowed sequence length.

    Returns:
        bool: True if all sequences are within the max_length, False otherwise.
    """
    for batch_idx, (inputs, _) in enumerate(dataloader):
        input_ids = inputs['input_ids']

        if input_ids.size(1) > max_length:
            return False

    return True


def batch_sample_mask_tokens_with_probabilities(inputs, probabilities, tokenizer: AutoTokenizer, mask_percentage=0.15):
    """
    Mask a batch of tokenized sequences in place, sampling the positions to mask
    from the per-residue probabilities instead of uniformly at random.

    Args:
        inputs (dict): Tokenized batch with 'input_ids' and 'attention_mask' of shape (batch_size, max_length).
        probabilities (Tensor): Per-position masking weights of shape (batch_size, max_length); special and padding positions carry a large negative value so softmax assigns them ~0 probability.
        tokenizer (AutoTokenizer): Tokenizer providing mask_token_id.
        mask_percentage (float): Fraction of each sequence's residues to mask.

    Returns:
        dict: The inputs dict with masked 'input_ids', a rebuilt 'attention_mask', and MLM 'labels' (-100 everywhere except masked positions).
    """
    # No input token is the mask token yet, so this initializes every label to -100 (ignored by the loss);
    # the original token ids are restored at the masked positions inside the loop below.
    labels = inputs["input_ids"].detach().clone()
    labels[labels != tokenizer.mask_token_id] = -100

    for idx in range(inputs["input_ids"].size(0)):
        input_ids = inputs["input_ids"][idx]
        prob = probabilities[idx]

        # Token ids 0 and 2 are the CLS and EOS tokens (this assumes the ESM-2 vocabulary).
        cls_token_index = (input_ids == 0).nonzero(as_tuple=False)[0].item()
        eos_token_index = (input_ids == 2).nonzero(as_tuple=False)[0].item()
        seq_length = eos_token_index - (cls_token_index + 1)

        assert prob.shape[0] == input_ids.shape[0]

        # Softmax turns the weights into a distribution; the no-mask sentinel values (-1e9) collapse to ~0.
        prob = softmax(prob, dim=0).cpu().numpy()
        assert abs(1 - prob.sum()) < 1e-6

        num_tokens_to_mask = int(mask_percentage * seq_length)

        # Sample distinct positions to mask according to the distribution.
        mask_indices = np.random.choice(input_ids.shape[0], size=num_tokens_to_mask, replace=False, p=prob)
        attention_mask_1_indices = np.arange(0, eos_token_index + 1, 1)

        # Record the true tokens at the masked positions, then overwrite them with the mask token.
        labels[idx, mask_indices] = input_ids[mask_indices].detach().clone()
        input_ids[mask_indices] = tokenizer.mask_token_id

        # Attend to everything from CLS through EOS; padding stays masked out.
        inputs["attention_mask"][idx] = torch.zeros_like(input_ids)
        inputs["attention_mask"][idx][attention_mask_1_indices] = 1

        inputs["input_ids"][idx] = input_ids

    inputs["labels"] = labels
    return inputs
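

# Illustrative sketch (hypothetical helper, not invoked anywhere): how the pieces above are
# typically wired together for one epoch of masked-language-model training. The tokenizer
# checkpoint, file path, and hyperparameter values are placeholder assumptions, and the model
# forward/backward pass is elided.
def _demo_masking_pipeline(data_path="train.csv"):
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    train_loader = get_dataloader(data_path, tokenizer, probability_type='snp',
                                  max_length=512, batch_size=8, shuffle=True)
    scheduler = get_mask_rate_scheduler(scheduler_type="cosine",
                                        min_masking_rate=0.15,
                                        max_masking_rate=0.40,
                                        total_batches=len(train_loader))
    for inputs, probabilities in train_loader:
        scheduler.step()
        mask_rate = scheduler.get_masking_rate()
        batch = batch_sample_mask_tokens_with_probabilities(inputs, probabilities, tokenizer,
                                                            mask_percentage=mask_rate)
        # `batch` now holds masked input_ids, a rebuilt attention_mask, and MLM labels;
        # pass it to the model here (e.g. loss = model(**batch).loss).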