import torch
import pandas as pd
import lightning.pytorch as pl

from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader


class MembraneDataset(Dataset):
    """Residue-level dataset backed by a CSV with a "Sequence" column.

    Lowercase residues in a sequence mark the positive class (see get_labels).
    """

    def __init__(self, config, data_path):
        self.config = config
        self.data = pd.read_csv(data_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.lm.pretrained_esm)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]["Sequence"]

        # Tokenize the uppercased sequence; the original casing is only used
        # to derive the labels below.
        tokens = self.tokenizer(
            sequence.upper(),
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.config.data.max_seq_len,
        )

        labels = self.get_labels(sequence)

        return {
            "input_ids": tokens['input_ids'],
            "attention_mask": tokens['attention_mask'],
            "labels": labels
        }

    def get_labels(self, sequence):
        """Per-residue binary labels: 1 where the residue is lowercase, else 0,
        padded or truncated to max_seq_len."""
        max_len = self.config.data.max_seq_len

        labels = torch.tensor([1 if residue.islower() else 0 for residue in sequence], dtype=torch.float)

        if len(labels) < max_len:
            # Pad with label_pad_value so padded positions can be ignored by the loss.
            padded_labels = torch.cat(
                [labels, torch.full(size=(max_len - len(labels),), fill_value=self.config.model.label_pad_value)]
            )
        else:
            padded_labels = labels[:max_len]
        return padded_labels
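
    # Illustration of the labeling convention (hypothetical sequence, assuming
    # max_seq_len = 8 and label_pad_value = -100):
    #
    #   "llGAVL"  ->  [1, 1, 0, 0, 0, 0]  ->  padded: [1, 1, 0, 0, 0, 0, -100, -100]
    #
    # Note that labels index raw residues, while ESM tokenizers prepend a <cls>
    # token to input_ids, so downstream code must account for that offset.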


def collate_fn(batch):
    """Stack per-sample tensors into batch tensors of shape (batch_size, max_seq_len),
    squeezing out the leading singleton dim added by return_tensors='pt'."""
    input_ids = torch.stack([item['input_ids'].squeeze(0) for item in batch])
    masks = torch.stack([item['attention_mask'].squeeze(0) for item in batch])
    labels = torch.stack([item['labels'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': masks,
        'labels': labels
    }


class MembraneDataModule(pl.LightningDataModule):
    def __init__(self, config, train_dataset, val_dataset, test_dataset, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size

    def _dataloader(self, dataset):
        # Shared loader construction for all three splits.
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)

    def train_dataloader(self):
        return self._dataloader(self.train_dataset)

    def val_dataloader(self):
        return self._dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._dataloader(self.test_dataset)


def get_datasets(config):
    """Helper to build the train/val/test datasets so main.py can quickly init the data module."""
    train_dataset = MembraneDataset(config, config.data.train)
    val_dataset = MembraneDataset(config, config.data.val)
    test_dataset = MembraneDataset(config, config.data.test)

    return {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset
    }
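

# Minimal smoke-test sketch of how the pieces fit together. The config below only
# mirrors the attributes this module actually reads; the checkpoint name, CSV
# paths, and label_pad_value are placeholders, not the project's real settings.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        lm=SimpleNamespace(pretrained_esm="facebook/esm2_t6_8M_UR50D"),
        data=SimpleNamespace(
            max_seq_len=512,
            batch_size=2,
            train="data/train.csv",
            val="data/val.csv",
            test="data/test.csv",
        ),
        model=SimpleNamespace(label_pad_value=-100.0),
    )

    datasets = get_datasets(config)
    datamodule = MembraneDataModule(
        config,
        train_dataset=datasets["train"],
        val_dataset=datasets["val"],
        test_dataset=datasets["test"],
    )

    batch = next(iter(datamodule.train_dataloader()))
    # Each tensor should come out as (batch_size, max_seq_len).
    print({key: tensor.shape for key, tensor in batch.items()})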