import pandas as pd
from typing import Dict, Literal, Optional
from datasets import Dataset, load_dataset
from datasets.dataset_dict import DatasetDict
from pydantic import BaseModel


class DatasetConfig(BaseModel):
    dataset: Literal['ptm', 'uniref50']
    dataset_type: Literal["csv", "huggingface", "fasta"]

    # The path if local, or the huggingface dataset name if huggingface
    dataset_loc: str

    # Sample size to limit to, if any; usually used for debugging
    subsample_size: Optional[int] = None

    # Args for splitting into train/val/test; to be updated once we have more options.
    # Seed used when shuffling before the split
    split_seed: Optional[int] = None
    # Size of the validation dataset
    val_size: int
    # Size of the test dataset
    test_size: int

    # Name of the column that contains the sequence
    sequence_column_name: str
    max_sequence_length: Optional[int] = None
    cache_dir: Optional[str] = None
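
# A minimal sketch of constructing the config directly; the values here are
# illustrative (the __main__ block at the bottom builds an equivalent dict for
# the PTM CSV dataset):
#
#   config = DatasetConfig(
#       dataset="ptm",
#       dataset_type="csv",
#       dataset_loc="path/to/ptm_labels.csv",
#       val_size=100,
#       test_size=50,
#       sequence_column_name="ori_seq",
#   )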


def set_labels(result):
    # Standard causal-LM setup: the labels are a copy of the input_ids.
    result["labels"] = result["input_ids"].copy()
    return result


def train_val_test_split(
    dataset_dict: DatasetDict,
    config: DatasetConfig,
) -> DatasetDict:
    """
    Given a DatasetDict that contains only the "train" split, optionally
    subsamples it and then splits it into up to three splits: "train",
    "val", and "test". The "val" and "test" splits are omitted when their
    configured sizes are 0.
    """
    assert set(dataset_dict.keys()) == {"train"}, (
        f"{train_val_test_split.__name__} expects its input to have the keys "
        f"['train'] but the input has keys {list(dataset_dict.keys())}"
    )
    dataset = dataset_dict["train"]
    val_size = config.val_size
    test_size = config.test_size
    assert isinstance(
        dataset, Dataset
    ), f"Invalid dataset type {type(dataset)}, only datasets.Dataset allowed"
    dataset = dataset.shuffle(seed=config.split_seed)
    if config.subsample_size is not None:
        dataset = dataset.select(range(config.subsample_size))
    valtest_size = val_size + test_size
    if valtest_size > 0:
        # First carve off the combined val+test rows, then split those again.
        train_valtest = dataset.train_test_split(
            test_size=valtest_size,
            shuffle=False,
        )
        split_dict = {
            "train": train_valtest["train"],
        }
        if test_size > 0 and val_size > 0:
            test_val = train_valtest["test"].train_test_split(
                test_size=test_size,
                shuffle=False,
            )
            split_dict["val"] = test_val["train"]
            split_dict["test"] = test_val["test"]
        elif val_size > 0:
            split_dict["val"] = train_valtest["test"]
        else:
            split_dict["test"] = train_valtest["test"]
    else:
        split_dict = {
            "train": dataset,
        }
    split_dataset_dict = DatasetDict(split_dict)
    return split_dataset_dict
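
# For example (hypothetical numbers): with val_size=100 and test_size=50 on a
# 1,000-row "train" split, the result has 850 train / 100 val / 50 test rows;
# with val_size=test_size=0 the input is only shuffled (and optionally subsampled).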


def load_ptm_dataset(df: pd.DataFrame, config: DatasetConfig) -> DatasetDict:
    ds = Dataset.from_pandas(df)
    ds_dict = DatasetDict({"train": ds})
    return train_val_test_split(ds_dict, config)


def create_token_dict_from_dataframe(
    df, seq_col="ori_seq", pos_col="pos", token_col="token"
):
    # Map each original sequence to a {position: PTM token} dict built from the
    # dataframe rows.
    result_dict = {}
    for index, row in df.iterrows():
        ac_id = row[seq_col]
        pos = row[pos_col]
        token = row[token_col]
        if ac_id not in result_dict:
            result_dict[ac_id] = {}
        result_dict[ac_id][pos] = token
    return result_dict
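
# A minimal sketch of the expected CSV layout and the resulting mapping (the
# sequence and token values here are illustrative, not taken from ptm_labels.csv):
#
#   ori_seq   pos  token
#   MKTAYIA   2    <p>
#   MKTAYIA   5    <p>
#
#   create_token_dict_from_dataframe(df) -> {"MKTAYIA": {2: "<p>", 5: "<p>"}}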


def substitute_tokens(sequence_lst, token_dict):
    # Accepts either a list of sequences or a single sequence string.
    if isinstance(sequence_lst, list):
        return [
            _substitute_token(sequence, token_dict[sequence])
            for sequence in sequence_lst
        ]
    elif isinstance(sequence_lst, str):
        return _substitute_token(sequence_lst, token_dict[sequence_lst])


def _substitute_token(sequence, token_dict):
    # Replace the residue at each recorded position with its PTM token.
    result = list(sequence)
    for position, new_token in token_dict.items():
        result[position] = new_token
    return "".join(result)


def construct_ptm_seq(
    batch, tokenizer, ptm_token_dict, sequence_column_name, max_sequence_length
):
    """
    Transform the batch by replacing wild-type residues with their PTM tokens
    and tokenizing the resulting sequences.
    """
    batch['wt_seq'] = batch[sequence_column_name]
    batch[sequence_column_name] = substitute_tokens(
        batch[sequence_column_name], ptm_token_dict
    )
    batch['ptm_seq'] = batch[sequence_column_name]
    batch["input_ids"] = tokenizer(
        batch[sequence_column_name],
        add_special_tokens=True,
        max_sequence_length=max_sequence_length,
    )
    # batch["labels"] = batch["input_ids"]
    return batch
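
# After this transform each example carries the original sequence ("wt_seq"), the
# PTM-substituted sequence ("ptm_seq", also written back to the sequence column),
# and the tokenized "input_ids"; labels can be derived later, e.g. via set_labels.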


def get_ptm_dataset(config_dict: Dict, tokenizer) -> DatasetDict:
    config = DatasetConfig(**config_dict)
    if config.dataset_type == "csv":
        df = pd.read_csv(config.dataset_loc)
        # Build the PTM mapping from the full CSV before dropping duplicate sequences.
        ptm_token_dict = create_token_dict_from_dataframe(df)
        df.drop(df.filter(regex="Unnamed"), axis=1, inplace=True)
        df.drop_duplicates(subset=config.sequence_column_name, inplace=True)
        split_dict = load_ptm_dataset(df, config)
    else:
        raise ValueError(f"Invalid dataset_type {config.dataset_type}!")
    split_dict = split_dict.map(
        lambda e: construct_ptm_seq(
            batch=e,
            tokenizer=tokenizer,
            ptm_token_dict=ptm_token_dict,
            sequence_column_name=config.sequence_column_name,
            max_sequence_length=config.max_sequence_length,
        ),
        batched=True,
        keep_in_memory=True,
    )
    return split_dict


if __name__ == "__main__":
    from protein_lm.tokenizer.tokenizer import PTMTokenizer
    from transformers import DataCollatorWithPadding, default_data_collator

    tokenizer = PTMTokenizer()
    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer, padding="longest", max_length=1024, return_tensors="pt"
    )
    config_dict = {
        "dataset_type": "csv",
        "dataset_loc": "protein_lm/dataset/ptm_labels.csv",
        "subsample_size": None,
        "val_size": 0,
        "test_size": 0,
        "sequence_column_name": "ori_seq",
        "max_sequence_length": None,
        "dataset": "ptm",
    }
    dataset = get_ptm_dataset(config_dict, tokenizer)
    samples = dataset["train"][:8]
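
    # A hedged sketch of applying the collator to the slice above (assumes
    # PTMTokenizer exposes the padding interface DataCollatorWithPadding expects
    # and that "input_ids" is a list of token-id lists; not verified here):
    #
    #   features = [{"input_ids": ids} for ids in samples["input_ids"]]
    #   batch = data_collator(features)
    #   print(batch["input_ids"].shape)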