import torch
import pandas as pd
import lightning.pytorch as pl

from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader


class MembraneDataset(Dataset):
    """Per-residue dataset: reads sequences from a CSV and tokenizes them with the
    pretrained ESM tokenizer; lowercase residues are labeled 1, all others 0."""

    def __init__(self, config, data_path):
        self.config = config
        self.data = pd.read_csv(data_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.lm.pretrained_esm)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]["Sequence"]

        tokens = self.tokenizer(
            sequence.upper(),
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.config.data.max_seq_len,
        )

        labels = self.get_labels(sequence)

        return {
            "input_ids": tokens['input_ids'],
            "attention_mask": tokens['attention_mask'],
            "labels": labels
        }

    def get_labels(self, sequence):
        max_len = self.config.data.max_seq_len

        # Create per-residue labels
        labels = torch.tensor([1 if residue.islower() else 0 for residue in sequence], dtype=torch.float)
        
        if len(labels) < max_len: # Padding if sequence shorter than tokenizer truncation length
            padded_labels = torch.cat(
                [labels, torch.full(size=(max_len - len(labels),), fill_value=self.config.model.label_pad_value)]
            )
        else: # Truncation otherwise
            padded_labels = labels[:max_len]
        return padded_labels


def collate_fn(batch):
    # The tokenizer output carries a leading batch dimension of 1; squeeze it before stacking
    input_ids = torch.stack([item['input_ids'].squeeze(0) for item in batch])
    masks = torch.stack([item['attention_mask'].squeeze(0) for item in batch])
    labels = torch.stack([item['labels'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': masks,
        'labels': labels
    }


class MembraneDataModule(pl.LightningDataModule):
    def __init__(self, config, train_dataset, val_dataset, test_dataset, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)
    

def get_datasets(config):
    """Helper to build the train/val/test datasets so the data module can be initialized quickly in main.py."""
    train_dataset = MembraneDataset(config, config.data.train)
    val_dataset = MembraneDataset(config, config.data.val)
    test_dataset = MembraneDataset(config, config.data.test)

    return {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset
    }
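

# Minimal usage sketch (not part of the original module): shows how the pieces above
# could be wired together. The config schema here is an assumption — only the fields
# actually read above (lm.pretrained_esm, data.*, model.label_pad_value) are included,
# the CSV paths and label_pad_value are hypothetical, and OmegaConf is used purely for
# attribute-style access; the real project may build its config differently in main.py.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.create({
        "lm": {"pretrained_esm": "facebook/esm2_t6_8M_UR50D"},
        "data": {
            "train": "data/train.csv",   # hypothetical paths
            "val": "data/val.csv",
            "test": "data/test.csv",
            "max_seq_len": 1024,
            "batch_size": 8,
        },
        "model": {"label_pad_value": -100},  # assumed ignore value for padded positions
    })

    datasets = get_datasets(config)
    datamodule = MembraneDataModule(
        config,
        train_dataset=datasets["train"],
        val_dataset=datasets["val"],
        test_dataset=datasets["test"],
    )

    # Pull one batch and inspect shapes: input_ids/attention_mask/labels
    # should all be (batch_size, max_seq_len).
    batch = next(iter(datamodule.train_dataloader()))
    print({key: value.shape for key, value in batch.items()})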