| import random | |
| def random_change_augmentation(aas, cfg): | |
| residue_tokens = ("A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y") | |
| stracture_aware_tokens = ("a", "c", "d", "e", "f", "g", "h", "i", "k", "l", "m", "n", "p", "q", "r", "s", "t", "v", "w", "y") | |
| length = len(aas) | |
| swap_indices = random.sample( | |
| range(length), int(length * cfg.random_change_ratio) | |
| ) | |
| new_aas = "" | |
| for i, aa in enumerate(aas): | |
| if i in swap_indices: | |
| if aas[i] in residue_tokens: | |
| new_aas += random.choice(residue_tokens) | |
| elif aas[i] in stracture_aware_tokens: | |
| new_aas += random.choice(stracture_aware_tokens) | |
| else: | |
| new_aas += aa | |
| return new_aas | |
| def mask_augmentation(aas, cfg): | |
| length = len(aas) | |
| swap_indices = random.sample( | |
| range(0, length // cfg.token_length), | |
| int(length // cfg.token_length * cfg.mask_ratio), | |
| ) | |
| for ith in swap_indices: | |
| aas = ( | |
| aas[: ith * cfg.token_length] | |
| + "@" * cfg.token_length | |
| + aas[(ith + 1) * cfg.token_length :] | |
| ) | |
| aas = aas.replace("@@", "<mask>").replace("@", "<mask>") | |
| return aas | |
| def random_delete_augmentation(aas, cfg): | |
| length = len(aas) | |
| swap_indices = random.sample( | |
| range(0, length // cfg.token_length), | |
| int(length // cfg.token_length * cfg.random_delete_ratio), | |
| ) | |
| for ith in swap_indices: | |
| aas = ( | |
| aas[: ith * cfg.token_length] | |
| + "@" * cfg.token_length | |
| + aas[(ith + 1) * cfg.token_length :] | |
| ) | |
| aas = aas.replace("@@", "").replace("@", "") | |
| return aas | |
| def truncate_augmentation(aas, cfg): | |
| length = len(aas) | |
| if length > cfg.max_length: | |
| diff = length - cfg.max_length | |
| start = random.randint(0, diff) | |
| return aas[start : start + cfg.max_length] | |
| else: | |
| return aas |