import numpy as np
import pandas as pd
import torch
import os
from torch.nn.functional import softmax
from fuson_plm.utils.logging import log_update
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from abc import ABC, abstractmethod
#----------------------------------------------------------------------------------------------------------------------------------------------------
#### Masking Rate Scheduler base class and sub classes
# abstract base class
class MaskingRateScheduler(ABC):
def __init__(self, total_steps, min_masking_rate, max_masking_rate, last_step=-1):
self.total_steps = total_steps
self.min_masking_rate = min_masking_rate
self.max_masking_rate = max_masking_rate
self.current_step = last_step
def step(self):
self.current_step += 1
def reset(self):
"""Reset the scheduler to its initial state."""
self.current_step = -1
def get_masking_rate(self):
progress = self.current_step / self.total_steps
return self.compute_masking_rate(progress)
@abstractmethod
def compute_masking_rate(self, progress):
"""To be implemented by subclasses for specific increase functions."""
raise NotImplementedError("Subclasses must implement this method.")
class CosineIncreaseMaskingRateScheduler(MaskingRateScheduler):
def compute_masking_rate(self, progress):
# Use a cosine increase function
cosine_increase = 0.5 * (1 - np.cos(np.pi * progress))
return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * cosine_increase
class LogLinearIncreaseMaskingRateScheduler(MaskingRateScheduler):
def compute_masking_rate(self, progress):
# Avoid log(0) by clamping progress to a minimum of a small positive number
progress = max(progress, 1e-10)
log_linear_increase = np.log1p(progress) / np.log1p(1) # Normalizing to keep range in [0, 1]
return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * log_linear_increase
class StepwiseIncreaseMaskingRateScheduler(MaskingRateScheduler):
def __init__(self, total_batches, min_masking_rate, max_masking_rate, num_steps):
super().__init__(total_steps=total_batches, min_masking_rate=min_masking_rate, max_masking_rate=max_masking_rate)
self.num_steps = num_steps
self.batch_interval = total_batches // (num_steps) # Adjusting to ensure max rate is included
self.rate_increment = (max_masking_rate - min_masking_rate) / (num_steps - 1) # Include end rate in the steps
def compute_masking_rate(self, progress):
# Determine the current step based on the number of completed batches
current_step = int(self.current_step / self.batch_interval)
# Cap the step number to `num_steps - 1` to include the max rate at the final step
current_step = min(current_step, self.num_steps - 1)
# Calculate the masking rate for the current step
masking_rate = self.min_masking_rate + current_step * self.rate_increment
return masking_rate
def get_mask_rate_scheduler(scheduler_type="cosine",min_masking_rate=0.15,max_masking_rate=0.40,total_batches=100,total_steps=20):
"""
Initialize the mask rate scheduler and return it
"""
if scheduler_type=="cosine":
return CosineIncreaseMaskingRateScheduler(total_steps=total_batches,
min_masking_rate=min_masking_rate,
max_masking_rate=max_masking_rate)
elif scheduler_type=="loglinear":
return LogLinearIncreaseMaskingRateScheduler(total_steps=total_batches,
min_masking_rate=min_masking_rate,
max_masking_rate=max_masking_rate)
elif scheduler_type=="stepwise":
return StepwiseIncreaseMaskingRateScheduler(total_batches=total_batches,
num_steps=total_steps,
min_masking_rate=min_masking_rate,
max_masking_rate=max_masking_rate)
else:
        raise ValueError("Must specify a valid scheduler_type: cosine, loglinear, or stepwise")
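# Illustrative usage sketch (added for reference; not called by the training pipeline): step the
# scheduler once per batch and read back the current masking rate. The scheduler type and batch
# count below are placeholder assumptions, not values taken from any training config.
def _example_mask_rate_schedule(total_batches=1000):
    scheduler = get_mask_rate_scheduler(scheduler_type="cosine",
                                        min_masking_rate=0.15,
                                        max_masking_rate=0.40,
                                        total_batches=total_batches)
    rates = []
    for _ in range(total_batches):
        scheduler.step()  # advance one batch
        rates.append(scheduler.get_masking_rate())
    return rates  # rises from ~0.15 at the first batch to ~0.40 at the last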
# Adjusted Dataloader for the sequences and probability vectors
class ProteinDataset(Dataset):
def __init__(self, data_path, tokenizer, probability_type, max_length=512):
self.dataframe = pd.read_csv(data_path)
self.tokenizer = tokenizer
self.probability_type=probability_type
self.max_length = max_length
self.set_probabilities()
def __len__(self):
return len(self.dataframe)
def set_probabilities(self):
if self.probability_type=="snp":
self.dataframe = self.dataframe.rename(columns={'snp_probabilities':'probabilities'})
if self.probability_type=="uniform":
self.dataframe['probabilities'] = self.dataframe['sequence'].apply(len).apply(lambda x: ('1,'*x)[0:-1])
# make probabilities into numbers if they aren't already
        if isinstance(self.dataframe['probabilities'].iloc[0], str):
self.dataframe['probabilities'] = self.dataframe['probabilities'].apply(
lambda x: np.array([float(i) for i in x.split(',')])
)
def get_padded_probabilities(self, idx):
'''
        Pads probabilities to max_length if they're too short; truncates them if they're too long.
'''
no_mask_value = int(-1e9) # will be used to make sure CLS and PAD aren't masked
# add a no-mask slot for <CLS>
prob = np.concatenate((
np.array([no_mask_value]),
self.dataframe.iloc[idx]['probabilities']
)
)
# Pad with no_mask_value for everything after the probability vector ends
if len(prob) < self.max_length:
return np.pad(
prob,
(0, self.max_length - len(prob)),
'constant', constant_values=(0,no_mask_value))
# If it's too long, we need to truncate, but we also need to change the last token to an <EOS>.
prob = prob[0:self.max_length-1]
prob = np.concatenate((
prob,
np.array([no_mask_value]),
)
)
return prob
def __getitem__(self, idx):
sequence = self.dataframe.iloc[idx]['sequence']
probability = self.get_padded_probabilities(idx) # extract them
        inputs = self.tokenizer(sequence, return_tensors="pt", padding='max_length', truncation=True, max_length=self.max_length)
inputs = {key: tensor.squeeze(0) for key, tensor in inputs.items()} # Remove batch dimension
return inputs, probability
def get_dataloader(data_path, tokenizer, probability_type='snp', max_length=512, batch_size=8, shuffle=True):
"""
Creates a DataLoader for the dataset.
Args:
data_path (str): Path to the CSV file (train, val, or test).
batch_size (int): Batch size.
shuffle (bool): Whether to shuffle the data.
tokenizer (Tokenizer): tokenizer object for data tokenization
Returns:
DataLoader: DataLoader object.
"""
dataset = ProteinDataset(data_path, tokenizer, probability_type, max_length=max_length)
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
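# Illustrative usage sketch (added for reference; not called by the training pipeline): build a
# dataloader from a CSV of sequences and per-residue probabilities. The checkpoint name and file
# path are assumptions for demonstration only; the training script supplies its own.
def _example_build_dataloader(data_path="splits/train_df.csv"):
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")  # assumed base tokenizer
    train_loader = get_dataloader(data_path, tokenizer, probability_type='snp',
                                  max_length=512, batch_size=8, shuffle=True)
    inputs, probabilities = next(iter(train_loader))  # inputs: dict of (batch, max_length) tensors
    return train_loader, inputs, probabilities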
def check_dataloaders(train_loader, val_loader, test_loader, max_length=512, checkpoint_dir=''):
    log_update(f'\nBuilt train, validation, and test dataloaders')
log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
log_update(f"\tNumber of sequences in the Validation DataLoader: {len(val_loader.dataset)}")
log_update(f"\tNumber of sequences in the Training DataLoader: {len(test_loader.dataset)}")
dataloader_overlaps = check_dataloader_overlap(train_loader, val_loader, test_loader)
if len(dataloader_overlaps)==0: log_update("\tDataloaders are clean (no overlaps)")
else: log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")
# write length ranges to a text file
if not(os.path.exists(f'{checkpoint_dir}/batch_diversity')):
os.mkdir(f'{checkpoint_dir}/batch_diversity')
max_length_violators = []
for name, dataloader in {'train':train_loader, 'val':val_loader, 'test':test_loader}.items():
max_length_followed, length_ranges = check_max_length_and_length_diversity(dataloader, max_length)
        if not max_length_followed:
max_length_violators.append(name)
with open(f'{checkpoint_dir}/batch_diversity/{name}_batch_length_ranges.txt','w') as f:
for tup in length_ranges:
f.write(f'{tup[0]}\t{tup[1]}\n')
if len(max_length_violators)==0: log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
else: log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")
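# Illustrative usage sketch (added for reference; assumed directory layout): run the sanity checks
# on freshly built train/val/test loaders. Note that check_dataloaders only creates the
# batch_diversity subdirectory, so checkpoint_dir itself must already exist.
def _example_check_dataloaders(train_loader, val_loader, test_loader, checkpoint_dir="checkpoints/demo"):
    os.makedirs(checkpoint_dir, exist_ok=True)
    check_dataloaders(train_loader, val_loader, test_loader, max_length=512, checkpoint_dir=checkpoint_dir)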
def check_dataloader_overlap(train_loader, val_loader, test_loader):
"""
Check the data that's about to go into the model. Make sure there is no overlap between train, test, and val
Returns:
"""
train_protein_seqs = set(train_loader.dataset.dataframe['sequence'].unique())
val_protein_seqs = set(val_loader.dataset.dataframe['sequence'].unique())
test_protein_seqs = set(test_loader.dataset.dataframe['sequence'].unique())
tr_va = len(train_protein_seqs.intersection(val_protein_seqs))
tr_te = len(train_protein_seqs.intersection(test_protein_seqs))
va_te = len(val_protein_seqs.intersection(test_protein_seqs))
overlaps = []
if tr_va==tr_te==va_te==0:
return overlaps # data is clean
else:
if tr_va > 0: overlaps.append(f"Train-Val Overlap={tr_va}")
if tr_te > 0: overlaps.append(f"Train-Test Overlap={tr_te}")
if va_te > 0: overlaps.append(f"Val-Test Overlap={va_te}")
return overlaps
def check_max_length_and_length_diversity(dataloader, max_length):
"""
Check if all sequences in the DataLoader conform to the specified max_length,
and return the sequence length ranges within each batch.
Args:
dataloader (DataLoader): The DataLoader object to check.
max_length (int): The maximum allowed sequence length.
Returns:
bool: True if all sequences are within the max_length, False otherwise.
list: A list of tuples representing the min and max sequence lengths in each batch.
"""
length_ranges = []
all_within_max_length = True
for batch_idx, (inputs, _) in enumerate(dataloader):
input_ids = inputs['input_ids']
# Calculate the actual lengths of sequences in this batch
actual_lengths = (input_ids != dataloader.dataset.tokenizer.pad_token_id).sum(dim=1)
min_length = actual_lengths.min().item()
max_length_in_batch = actual_lengths.max().item()
# Check for max length violation
if max_length_in_batch > max_length:
all_within_max_length = False
# Store the length range for this batch
length_ranges.append((min_length, max_length_in_batch))
return all_within_max_length, length_ranges
def check_max_length_in_dataloader(dataloader, max_length):
"""
Check if all sequences in the DataLoader conform to the specified max_length.
Args:
dataloader (DataLoader): The DataLoader object to check.
max_length (int): The maximum allowed sequence length.
Returns:
bool: True if all sequences are within the max_length, False otherwise.
"""
for batch_idx, (inputs, _) in enumerate(dataloader):
input_ids = inputs['input_ids']
# Check if any sequence length exceeds max_length
if input_ids.size(1) > max_length:
return False
return True
def batch_sample_mask_tokens_with_probabilities(inputs, probabilities, tokenizer: AutoTokenizer, mask_percentage=0.15):
"""
"""
#print('the batch sample method was called!')
labels = inputs["input_ids"].detach().clone()
labels[labels != tokenizer.mask_token_id] = -100 # Set labels for unmasked tokens to -100
# Iterate over each sequence and its corresponding probabilities in the batch
for idx in range(inputs["input_ids"].size(0)): # Assuming the first dimension is batch size
input_ids = inputs["input_ids"][idx]
prob = probabilities[idx]
        cls_token_index = (input_ids == tokenizer.cls_token_id).nonzero(as_tuple=False)[0].item()
        eos_token_index = (input_ids == tokenizer.eos_token_id).nonzero(as_tuple=False)[0].item()
seq_length = eos_token_index - (cls_token_index+1)
assert prob.shape[0] == input_ids.shape[0]
# Normalize probabilities using softmax
prob = softmax(prob, dim=0).cpu().numpy() # move to CPU for numpy
        assert abs(1 - prob.sum()) < 1e-6  # probabilities should sum to 1 after softmax
# Calculate the number of tokens to mask
num_tokens_to_mask = int(mask_percentage * seq_length)
# Choose indices to mask based on the probability distribution
mask_indices = np.random.choice(input_ids.shape[0], size=num_tokens_to_mask, replace=False, p=prob)
attention_mask_1_indices = np.arange(0, eos_token_index+1, 1)
# Mask the selected indices and set the corresponding labels
labels[idx, mask_indices] = input_ids[mask_indices].detach().clone()
input_ids[mask_indices] = tokenizer.mask_token_id
inputs["attention_mask"][idx] = torch.zeros_like(input_ids)
inputs["attention_mask"][idx][attention_mask_1_indices] = 1 # just added this to try and update the attention mask....
# Update the input_ids in the inputs dictionary
inputs["input_ids"][idx] = input_ids
inputs["labels"] = labels
return inputs
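# Illustrative end-to-end sketch (added for reference; not called by this module). Assumes an
# ESM-style masked-LM `model`, plus a dataloader and scheduler built as in the sketches above;
# none of these objects are constructed here, and the single optimizer-free step is for clarity only.
def _example_masking_step(model, train_loader, scheduler, tokenizer, device="cpu"):
    model.train()
    for inputs, probabilities in train_loader:
        scheduler.step()
        mask_rate = scheduler.get_masking_rate()
        # mask tokens according to per-residue probabilities at the scheduled rate
        inputs = batch_sample_mask_tokens_with_probabilities(inputs, probabilities, tokenizer,
                                                             mask_percentage=mask_rate)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)  # HF masked-LM computes the loss against the -100-filled labels
        outputs.loss.backward()
        return outputs.loss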