"""Ten Species Dataset. |
|
|
|
Load dataset from HF; tokenize 'on-the-fly' |
|
""" |
|
|
|
import random

import datasets
import torch
import transformers


# Maps each nucleotide to its complement; case is preserved and the ambiguous
# base "N" maps to itself.
STRING_COMPLEMENT_MAP = {
    "A": "T", "C": "G", "G": "C", "T": "A",
    "a": "t", "c": "g", "g": "c", "t": "a",
    "N": "N", "n": "n",
}


def coin_flip(p=0.5):
    """Flip a (potentially weighted) coin; returns True with probability `p`."""
    return random.random() < p


def string_reverse_complement(seq):
    """Reverse complement a DNA sequence."""
    rev_comp = ""
    for base in seq[::-1]:
        if base in STRING_COMPLEMENT_MAP:
            rev_comp += STRING_COMPLEMENT_MAP[base]
        # Characters without an entry in the map (e.g., other IUPAC ambiguity
        # codes) are kept unchanged.
        else:
            rev_comp += base
    return rev_comp


class TenSpeciesDataset(torch.utils.data.Dataset):
    """Ten Species Dataset.

    Tokenization happens on the fly.
    """

    def __init__(
            self,
            split: str,
            tokenizer: transformers.PreTrainedTokenizer,
            max_length: int = 1024,
            rc_aug: bool = False,
            add_special_tokens: bool = False,
            dataset=None):
        if dataset is None:
            # Download the pre-chunked dataset and carve a held-out split out
            # of it, so `split` can be either 'train' or 'test'.
            dataset = datasets.load_dataset(
                'yairschiff/ten_species',
                split='train',
                chunk_length=max_length,
                overlap=0,
                trust_remote_code=True)
            self.dataset = dataset.train_test_split(
                test_size=0.05, seed=42)[split]
        else:
            self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.rc_aug = rc_aug  # Reverse-complement augmentation.
        self.add_special_tokens = add_special_tokens

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Returns a tokenized sequence and its species label."""
        seq = self.dataset[idx]['sequence']
        # Optionally replace the sequence with its reverse complement
        # (augmentation applied with probability 0.5).
        if self.rc_aug and coin_flip():
            seq = string_reverse_complement(seq)
        # Tokenize on the fly, padding/truncating to a fixed length.
        seq = self.tokenizer(
            seq,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            add_special_tokens=self.add_special_tokens,
            return_attention_mask=True)

        input_ids = torch.LongTensor(seq['input_ids'])
        attention_mask = torch.LongTensor(seq['attention_mask'])

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'species_label': torch.LongTensor([
                self.dataset[idx]['species_label']]).squeeze(),
        }
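

# Usage sketch: wraps the train split in a DataLoader. This is illustrative
# only; the tokenizer checkpoint below is a placeholder, and any DNA
# character-level `transformers.PreTrainedTokenizer` should work the same way.
if __name__ == "__main__":
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "path/to/dna-char-tokenizer",  # Placeholder, not a real checkpoint id.
        trust_remote_code=True)
    train_ds = TenSpeciesDataset(
        split="train",
        tokenizer=tokenizer,
        max_length=1024,
        rc_aug=True)
    loader = torch.utils.data.DataLoader(train_ds, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    # Each batch holds fixed-length token ids, an attention mask, and integer
    # species labels.
    print(batch["input_ids"].shape,
          batch["attention_mask"].shape,
          batch["species_label"].shape)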