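"""Build a tokenized UniRef dataset for protein language model training.

Reads protein sequences from a FASTA file, wraps them in a Hugging Face
``DatasetDict``, splits into train/val/test, tokenizes each sequence, and
caches the tokenized splits to disk.
"""
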
import os

from Bio import SeqIO
from datasets import Dataset
from datasets.dataset_dict import DatasetDict

from protein_lm.modeling.getters.ptm_dataset import DatasetConfig, train_val_test_split


def read_fasta_file(fasta_file_path, subsample_size):
    """Parse a FASTA file into parallel lists of record ids and sequences.

    If ``subsample_size`` is truthy, stop after that many records.
    """
    ids = []
    seqs = []
    with open(fasta_file_path, "r") as fasta_file:
        for i, record in enumerate(SeqIO.parse(fasta_file, "fasta")):
            if subsample_size and i >= subsample_size:
                break
            ids.append(record.id)
            seqs.append(str(record.seq))
    return {"id": ids, "seq": seqs}


def load_uniref_dataset(seq_dict, config) -> DatasetDict:
    """Wrap the parsed sequences in a DatasetDict and split into train/val/test."""
    ds = Dataset.from_dict(seq_dict)
    ds_dict = DatasetDict({"train": ds})
    return train_val_test_split(ds_dict, config)


def seq2token(batch, tokenizer, sequence_column_name, max_sequence_length):
    """Tokenize a batch of sequences and store the ids in an ``input_ids`` column."""
    # Note: max_sequence_length follows this project's tokenizer signature;
    # a stock Hugging Face tokenizer would take max_length/truncation instead.
    batch["input_ids"] = tokenizer(
        batch[sequence_column_name],
        add_special_tokens=True,
        max_sequence_length=max_sequence_length,
    )
    return batch


def get_uniref_dataset(config: DatasetConfig, tokenizer) -> DatasetDict:
    """Load, split, and tokenize a UniRef dataset, reusing the on-disk cache when present."""
    if config.cache_dir is not None and os.path.exists(config.cache_dir):
        return DatasetDict.load_from_disk(config.cache_dir)

    seq_dict = read_fasta_file(config.dataset_loc, config.subsample_size)
    split_dict = load_uniref_dataset(seq_dict, config)
    split_dict = split_dict.map(
        lambda e: seq2token(
            batch=e,
            tokenizer=tokenizer,
            sequence_column_name="seq",
            max_sequence_length=config.max_sequence_length,
        ),
        batched=True,
    )
    if config.cache_dir is not None:
        os.makedirs(config.cache_dir, exist_ok=True)
        split_dict.save_to_disk(config.cache_dir)
    return split_dict
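

# Minimal usage sketch (illustrative, not part of the original module). The
# config fields mirror those accessed above; DatasetConfig may require
# additional fields (e.g. split fractions), and the AptTokenizer import path
# is an assumption about this repo's tokenizer package.
#
#     from protein_lm.tokenizer import AptTokenizer
#
#     config = DatasetConfig(
#         dataset_loc="data/uniref50.fasta",  # FASTA file with UniRef sequences
#         subsample_size=1000,                # parse only the first 1000 records
#         max_sequence_length=1024,
#         cache_dir="cache/uniref50",         # tokenized splits are cached here
#     )
#     split_dict = get_uniref_dataset(config, tokenizer=AptTokenizer())
#     print(split_dict["train"][0]["input_ids"])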