import random

import numpy as np
import torch
from torch.utils.data import Dataset

from augmentation import (
    mask_augmentation,
    random_change_augmentation,
    random_delete_augmentation,
    truncate_augmentation,
)

# The 20 standard amino acids plus a trailing space, which the one-hot
# encoder below uses as its padding character.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY "


def tokenize_input(cfg, text):
    """Tokenize a sequence into fixed-length LongTensors via cfg.tokenizer."""
    inputs = cfg.tokenizer(
        text,
        add_special_tokens=True,
        max_length=cfg.max_length,
        padding="max_length",
        truncation=True,
        return_offsets_mapping=False,
        return_attention_mask=True,
    )
    # Convert every returned field (input_ids, attention_mask, ...) to a LongTensor.
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs
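

def _example_tokenize_input():
    # Illustrative sketch only: cfg is a plain namespace here, and
    # "facebook/esm2_t6_8M_UR50D" is an assumed checkpoint, not necessarily
    # the one this project trains with.
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    cfg = SimpleNamespace(
        tokenizer=AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D"),
        max_length=16,
    )
    inputs = tokenize_input(cfg, "MKTAYI")
    # Every field is a 1-D LongTensor of length cfg.max_length.
    assert inputs["input_ids"].shape == (16,)
    return inputs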


def one_hot_encoding(aa, amino_acids, cfg):
    """One-hot encode a sequence, truncating/padding it to cfg.max_length."""
    # Pad short sequences with spaces so the row count always equals max_length.
    aa = aa[: cfg.max_length].ljust(cfg.max_length, " ")
    one_hot = np.zeros((len(aa), len(amino_acids)))
    for i, a in enumerate(aa):
        # Characters outside the alphabet (e.g. those of "<pad>") stay all-zero rows.
        if a in amino_acids:
            one_hot[i, amino_acids.index(a)] = 1
    return one_hot


def one_hot_encode_input(text, cfg):
    inputs = one_hot_encoding(text, AMINO_ACIDS, cfg)
    return torch.tensor(inputs, dtype=torch.float)
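

def _example_one_hot_encode_input():
    # Illustrative sketch: a minimal config stub; only `max_length` is read
    # by the one-hot path.
    from types import SimpleNamespace

    cfg = SimpleNamespace(max_length=8)
    encoded = one_hot_encode_input("ACDX", cfg)
    # Shape is (max_length, 21): 20 amino acids plus the space pad column.
    assert encoded.shape == (8, 21)
    # "X" is outside the alphabet, so its row is all zeros.
    assert encoded[3].sum() == 0
    return encoded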


class PLTNUMDataset(Dataset):
    def __init__(self, cfg, df, train=True):
        self.df = df
        self.cfg = cfg
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        data = self.df.iloc[idx]
        aas = self._adjust_sequence_length(data[self.cfg.sequence_col])
        if self.train:
            aas = self._apply_augmentation(aas)
        # "__" marks the join point inserted by the "both" strategy below.
        aas = aas.replace("__", "<pad>")
        inputs = tokenize_input(self.cfg, aas)
        if "target" in data:
            return inputs, torch.tensor(data["target"], dtype=torch.float32)
        # Inference rows carry no label; return NaN as a placeholder target.
        return inputs, np.nan

    def _adjust_sequence_length(self, aas):
        # Reserve two positions for the tokenizer's special tokens.
        max_length = (self.cfg.max_length - 2) * self.cfg.token_length
        if len(aas) > max_length:
            if self.cfg.used_sequence == "left":
                return aas[:max_length]
            elif self.cfg.used_sequence == "right":
                return aas[-max_length:]
            elif self.cfg.used_sequence == "both":
                # Keep the N- and C-terminal halves, joined by a "__" marker
                # that __getitem__ later rewrites to "<pad>".
                half_max_len = max_length // 2
                return aas[:half_max_len] + "__" + aas[-half_max_len:]
            elif self.cfg.used_sequence == "internal":
                # Keep a centered window of the sequence.
                offset = (len(aas) - max_length) // 2
                return aas[offset : offset + max_length]
        return aas

    def _apply_augmentation(self, aas):
        # Each augmentation is gated by its own ratio and/or draw probability.
        if self.cfg.random_change_ratio > 0:
            aas = random_change_augmentation(aas, self.cfg)
        if (
            random.random() <= self.cfg.random_delete_prob
        ) and self.cfg.random_delete_ratio > 0:
            aas = random_delete_augmentation(aas, self.cfg)
        if (random.random() <= self.cfg.mask_prob) and self.cfg.mask_ratio > 0:
            aas = mask_augmentation(aas, self.cfg)
        if random.random() <= self.cfg.truncate_augmentation_prob:
            aas = truncate_augmentation(aas, self.cfg)
        return aas
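

def _example_pltnum_dataset():
    # Illustrative sketch (assumptions: pandas is available, the config is a
    # plain namespace, and the checkpoint name is a stand-in). train=False
    # skips augmentation, so the augmentation probabilities are not needed;
    # train=True would additionally require them on cfg.
    from types import SimpleNamespace

    import pandas as pd
    from transformers import AutoTokenizer

    cfg = SimpleNamespace(
        tokenizer=AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D"),
        max_length=16,
        token_length=1,
        used_sequence="left",
        sequence_col="sequence",
    )
    df = pd.DataFrame({"sequence": ["MKTAYIAKQR"], "target": [0.7]})
    dataset = PLTNUMDataset(cfg, df, train=False)
    inputs, target = dataset[0]
    # inputs is a dict of 1-D LongTensors; target is a float32 scalar tensor.
    return inputs, target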


class LSTMDataset(Dataset):
    def __init__(self, cfg, df, train=True):
        self.df = df
        self.cfg = cfg
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        data = self.df.iloc[idx]
        aas = data[self.cfg.sequence_col]
        aas = self._adjust_sequence_length(aas)
        # Note: the characters of "<pad>" fall outside the one-hot alphabet
        # and therefore encode as all-zero rows.
        aas = aas.replace("__", "<pad>")
        inputs = one_hot_encode_input(aas, self.cfg)
        return inputs, torch.tensor(data["target"], dtype=torch.float32)

    def _adjust_sequence_length(self, aas):
        # Reserve two positions for the tokenizer's special tokens, mirroring
        # PLTNUMDataset so both datasets see identically trimmed sequences.
        max_length = (self.cfg.max_length - 2) * self.cfg.token_length
        if len(aas) > max_length:
            if self.cfg.used_sequence == "left":
                return aas[:max_length]
            elif self.cfg.used_sequence == "right":
                return aas[-max_length:]
            elif self.cfg.used_sequence == "both":
                half_max_len = max_length // 2
                return aas[:half_max_len] + "__" + aas[-half_max_len:]
            elif self.cfg.used_sequence == "internal":
                offset = (len(aas) - max_length) // 2
                return aas[offset : offset + max_length]
        return aas
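

def _example_lstm_dataset():
    # Illustrative sketch: the LSTM path only needs the sequence-length
    # settings, since it bypasses the tokenizer in favour of one-hot encoding.
    from types import SimpleNamespace

    import pandas as pd

    cfg = SimpleNamespace(
        max_length=16,
        token_length=1,
        used_sequence="both",
        sequence_col="sequence",
    )
    df = pd.DataFrame(
        {"sequence": ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIE"], "target": [0.3]}
    )
    dataset = LSTMDataset(cfg, df)
    inputs, target = dataset[0]
    # inputs: (max_length, 21) float one-hot matrix; target: float32 scalar.
    assert inputs.shape == (16, 21)
    return inputs, target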