# MemDLM/src/guidance/dataloader.py
import torch
import pandas as pd
import lightning.pytorch as pl
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader


class MembraneDataset(Dataset):
    """Protein sequence dataset with per-residue binary labels derived from residue case
    (lowercase -> 1, uppercase -> 0), read from a CSV file with a "Sequence" column."""

    def __init__(self, config, data_path):
        self.config = config
        self.data = pd.read_csv(data_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.lm.pretrained_esm)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]["Sequence"]
        tokens = self.tokenizer(
            sequence.upper(),
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.config.data.max_seq_len,
        )
        labels = self.get_labels(sequence)
        return {
            "input_ids": tokens['input_ids'],
            "attention_mask": tokens['attention_mask'],
            "labels": labels
        }

    def get_labels(self, sequence):
        max_len = self.config.data.max_seq_len
        # Per-residue labels: 1 for lowercase residues, 0 for uppercase ones
        labels = torch.tensor([1 if residue.islower() else 0 for residue in sequence], dtype=torch.float)
        if len(labels) < max_len:
            # Pad with the label padding value so labels match the tokenizer's max length
            padded_labels = torch.cat(
                [labels, torch.full(size=(max_len - len(labels),), fill_value=self.config.model.label_pad_value)]
            )
        else:
            # Truncate to match the truncated token sequence
            padded_labels = labels[:max_len]
        return padded_labels


def collate_fn(batch):
    """Stack per-example tensors into batch tensors, dropping the tokenizer's extra batch dimension."""
    input_ids = torch.stack([item['input_ids'].squeeze(0) for item in batch])
    masks = torch.stack([item['attention_mask'].squeeze(0) for item in batch])
    labels = torch.stack([item['labels'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': masks,
        'labels': labels
    }


class MembraneDataModule(pl.LightningDataModule):
    def __init__(self, config, train_dataset, val_dataset, test_dataset, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)


def get_datasets(config):
    """Helper to build the train/val/test datasets so the data module can be set up quickly in main.py."""
    train_dataset = MembraneDataset(config, config.data.train)
    val_dataset = MembraneDataset(config, config.data.val)
    test_dataset = MembraneDataset(config, config.data.test)
    return {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset
    }
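

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It shows how the
# pieces above fit together; the config field values below (CSV paths, ESM
# checkpoint name, max_seq_len, batch_size, label_pad_value) are illustrative
# assumptions, not values taken from the project's actual config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Stand-in config exposing only the attributes this module reads.
    config = SimpleNamespace(
        lm=SimpleNamespace(pretrained_esm="facebook/esm2_t6_8M_UR50D"),  # assumed ESM checkpoint
        data=SimpleNamespace(
            train="data/train.csv",  # assumed CSVs with a "Sequence" column
            val="data/val.csv",
            test="data/test.csv",
            max_seq_len=1024,
            batch_size=8,
        ),
        # Assumed pad value for label positions beyond the sequence length
        # (float so it concatenates cleanly with the float residue labels).
        model=SimpleNamespace(label_pad_value=-100.0),
    )

    datasets = get_datasets(config)
    datamodule = MembraneDataModule(
        config,
        train_dataset=datasets["train"],
        val_dataset=datasets["val"],
        test_dataset=datasets["test"],
    )

    # Each tensor in the batch has shape (batch_size, max_seq_len).
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["labels"].shape)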