import pandas as pd
from typing import Dict, Literal, Optional

from datasets import Dataset
from datasets.dataset_dict import DatasetDict
from pydantic import BaseModel


class DatasetConfig(BaseModel):
    dataset: Literal["ptm", "uniref50"]
    dataset_type: Literal["csv", "huggingface", "fasta"]

    # Location of the dataset (e.g. a CSV path for dataset_type "csv").
    dataset_loc: str

    # If set, subsample the dataset to this many examples before splitting.
    subsample_size: Optional[int] = None

    # Args for splitting into train, val, test;
    # to be updated once we have more options.
    split_seed: Optional[int] = None
    val_size: int
    test_size: int

    # Name of the column that holds the protein sequence.
    sequence_column_name: str

    max_sequence_length: Optional[int] = None
    cache_dir: Optional[str] = None


def set_labels(result):
    # For causal language modeling the labels are simply a copy of the
    # input ids; the model shifts them by one position when computing loss.
    result["labels"] = result["input_ids"].copy()
    return result


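# A minimal, hypothetical usage sketch (not invoked in this module): once a
# dataset has been tokenized into an "input_ids" column, `set_labels` can be
# mapped over it to add the "labels" column expected by causal LM trainers:
#
#   tokenized_ds = tokenized_ds.map(set_labels, batched=True)

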
def train_val_test_split(
    dataset_dict: DatasetDict,
    config: DatasetConfig,
) -> DatasetDict:
    """
    Given a dataset dictionary that contains only the split "train",
    optionally subsamples it and then splits it into up to three splits,
    "train", "val", and "test"; the "val" and "test" splits are omitted
    when the corresponding configured sizes are 0.
    """
    assert set(dataset_dict.keys()) == {"train"}, (
        f"{train_val_test_split.__name__} expects its input to have the keys "
        f"['train'] but the input has keys {list(dataset_dict.keys())}"
    )

    dataset = dataset_dict["train"]

    val_size = config.val_size
    test_size = config.test_size

    assert isinstance(
        dataset, Dataset
    ), f"Invalid dataset type {type(dataset)}, only datasets.Dataset allowed"

    # Shuffle once up front; the train_test_split calls below keep
    # shuffle=False so that a fixed seed yields reproducible splits.
    dataset = dataset.shuffle(seed=config.split_seed)

    if config.subsample_size is not None:
        dataset = dataset.select(range(config.subsample_size))

    valtest_size = val_size + test_size

    if valtest_size > 0:
        train_valtest = dataset.train_test_split(
            test_size=valtest_size,
            shuffle=False,
        )
        split_dict = {
            "train": train_valtest["train"],
        }
        if test_size > 0 and val_size > 0:
            # Split the held-out portion again into val and test.
            val_test = train_valtest["test"].train_test_split(
                test_size=test_size,
                shuffle=False,
            )
            split_dict["val"] = val_test["train"]
            split_dict["test"] = val_test["test"]
        elif val_size > 0:
            split_dict["val"] = train_valtest["test"]
        else:
            split_dict["test"] = train_valtest["test"]
    else:
        split_dict = {
            "train": dataset,
        }

    split_dataset_dict = DatasetDict(split_dict)
    return split_dataset_dict


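# Illustrative behavior (hypothetical numbers): with a 100-example "train"
# split, val_size=10, and test_size=10, the call
#
#   train_val_test_split(DatasetDict({"train": ds}), config)
#
# returns a DatasetDict with 80/10/10 examples in "train"/"val"/"test";
# with val_size=0 and test_size=0 the input comes back (shuffled) as a
# single "train" split.

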
def load_ptm_dataset(df: pd.DataFrame, config: DatasetConfig) -> DatasetDict:
    ds = Dataset.from_pandas(df)
    ds_dict = DatasetDict({"train": ds})
    return train_val_test_split(ds_dict, config)


def create_token_dict_from_dataframe(
    df, seq_col="ori_seq", pos_col="pos", token_col="token"
):
    """
    Build a nested mapping {sequence: {position: PTM token}} from a
    dataframe with one row per (sequence, position, token) annotation.
    """
    result_dict = {}

    for _, row in df.iterrows():
        sequence = row[seq_col]
        pos = row[pos_col]
        token = row[token_col]

        if sequence not in result_dict:
            result_dict[sequence] = {}

        result_dict[sequence][pos] = token

    return result_dict


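# Illustrative shape of the returned mapping (hypothetical rows): a dataframe
# with the rows ("MKTAY", 1, "x") and ("MKTAY", 3, "y") collapses into
#
#   {"MKTAY": {1: "x", 3: "y"}}
#
# i.e. one position -> PTM-token dict per unique sequence.

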
def substitute_tokens(sequence_lst, token_dict):
    if isinstance(sequence_lst, list):
        return [
            _substitute_token(sequence, token_dict[sequence])
            for sequence in sequence_lst
        ]
    elif isinstance(sequence_lst, str):
        return _substitute_token(sequence_lst, token_dict[sequence_lst])
    raise TypeError(
        f"Expected a list of sequences or a single sequence string, "
        f"got {type(sequence_lst)}"
    )


def _substitute_token(sequence, token_dict):
    # Replace individual residues with their PTM tokens; the annotated
    # positions are used directly as 0-based list indices.
    result = list(sequence)
    for position, new_token in token_dict.items():
        result[position] = new_token
    return "".join(result)


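# Worked example (hypothetical inputs):
#
#   _substitute_token("MKT", {1: "x"})  ->  "MxT"
#
# since position 1 of the residue list ["M", "K", "T"] is overwritten
# before the list is joined back into a string.

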
def construct_ptm_seq(
    batch, tokenizer, ptm_token_dict, sequence_column_name, max_sequence_length
):
    """
    Transform a batch so that annotated residues are replaced with their
    PTM tokens, keep both the wild-type and the PTM sequence around, and
    tokenize the resulting PTM sequence.
    """
    batch["wt_seq"] = batch[sequence_column_name]
    batch[sequence_column_name] = substitute_tokens(
        batch[sequence_column_name], ptm_token_dict
    )
    batch["ptm_seq"] = batch[sequence_column_name]
    batch["input_ids"] = tokenizer(
        batch[sequence_column_name],
        add_special_tokens=True,
        max_sequence_length=max_sequence_length,
    )

    return batch


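# After mapping `construct_ptm_seq` over a dataset, each example carries
# three views of the sequence (illustrative values):
#
#   wt_seq:    "MKT..."  original wild-type sequence
#   ptm_seq:   "MxT..."  sequence with PTM tokens substituted in
#   input_ids: the tokenizer's ids for ptm_seq
#
# Note that the `sequence_column_name` column itself is overwritten with
# the PTM sequence, so ptm_seq duplicates it for convenience.

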
def get_ptm_dataset(config_dict: Dict, tokenizer) -> DatasetDict:
    config = DatasetConfig(**config_dict)

    if config.dataset_type == "csv":
        df = pd.read_csv(config.dataset_loc)
        # Collect all (position -> token) annotations per sequence before
        # deduplicating, since one sequence may span multiple rows.
        ptm_token_dict = create_token_dict_from_dataframe(df)
        # Drop stray index columns written by pandas (e.g. "Unnamed: 0").
        df.drop(columns=df.filter(regex="Unnamed").columns, inplace=True)
        df.drop_duplicates(subset=config.sequence_column_name, inplace=True)
        split_dict = load_ptm_dataset(df, config)
    else:
        raise ValueError(f"Invalid dataset_type {config.dataset_type}!")

    split_dict = split_dict.map(
        lambda e: construct_ptm_seq(
            batch=e,
            tokenizer=tokenizer,
            ptm_token_dict=ptm_token_dict,
            sequence_column_name=config.sequence_column_name,
            max_sequence_length=config.max_sequence_length,
        ),
        batched=True,
        keep_in_memory=True,
    )

    return split_dict


if __name__ == "__main__":
    from protein_lm.tokenizer.tokenizer import PTMTokenizer
    from transformers import DataCollatorWithPadding

    tokenizer = PTMTokenizer()
    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer, padding="longest", max_length=1024, return_tensors="pt"
    )
    config_dict = {
        "dataset_type": "csv",
        "dataset_loc": "protein_lm/dataset/ptm_labels.csv",
        "subsample_size": None,
        "val_size": 0,
        "test_size": 0,
        "sequence_column_name": "ori_seq",
        "max_sequence_length": None,
        "dataset": "ptm",
    }
    dataset = get_ptm_dataset(config_dict, tokenizer)
    samples = dataset["train"][:8]
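    # Hypothetical collation step (assumes PTMTokenizer implements the
    # padding interface that DataCollatorWithPadding expects): turn the
    # column-oriented slice into per-example dicts, then pad the batch to
    # its longest sequence.
    features = [{"input_ids": ids} for ids in samples["input_ids"]]
    batch = data_collator(features)
    print(batch["input_ids"].shape)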
|