zhangzhi's picture
init commit
a476bbf verified
raw
history blame
No virus
666 Bytes
import pandas as pd
from typing import Dict, Literal, Optional
from datasets import Dataset, load_dataset
from datasets.dataset_dict import DatasetDict
from pydantic import BaseModel
from protein_lm.modeling.getters.ptm_dataset import get_ptm_dataset
from protein_lm.modeling.getters.uniref_dataset import get_uniref_dataset
def get_dataset(config_dict: Dict, tokenizer) -> "Dataset":
    """Build the dataset selected by ``config_dict["dataset"]``.

    Args:
        config_dict: Configuration mapping. Its ``"dataset"`` entry picks the
            loader; the whole mapping is forwarded to that loader unchanged.
        tokenizer: Forwarded unchanged to the selected loader.

    Returns:
        Whatever the selected loader returns: ``get_ptm_dataset`` for
        ``"ptm"``, ``get_uniref_dataset`` for ``"uniref50"``.

    Raises:
        KeyError: If ``config_dict`` is a dict without a ``"dataset"`` entry.
        ValueError: If the ``"dataset"`` entry names an unsupported dataset.
    """
    # Read the selector once so every branch and the error message agree.
    dataset_name = config_dict["dataset"]

    if dataset_name == "ptm":
        return get_ptm_dataset(config_dict, tokenizer)
    if dataset_name == "uniref50":
        return get_uniref_dataset(config_dict, tokenizer)

    raise ValueError(f"Invalid dataset {dataset_name}!")