from typing import Dict, Literal, Optional

import pandas as pd
from datasets import Dataset, load_dataset
from datasets.dataset_dict import DatasetDict
from pydantic import BaseModel

class DatasetConfig(BaseModel):
    dataset: Literal["ptm", "uniref50"]
    dataset_type: Literal["csv", "huggingface", "fasta"]

    # The path if local, or the Hugging Face dataset name if huggingface
    dataset_loc: str

    # sample size to limit to, if any, usually for debugging
    subsample_size: Optional[int] = None

    # Args for splitting into train, val, test;
    # to be updated once we have more options

    # split seed
    split_seed: Optional[int] = None

    # size of validation dataset
    val_size: int

    # size of test dataset
    test_size: int

    # name of the column that contains the sequence
    sequence_column_name: str

    max_sequence_length: Optional[int] = None
    cache_dir: Optional[str] = None

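
# A minimal sketch of building a config by hand, assuming a local CSV with an
# "ori_seq" sequence column (hypothetical sizes):
#
#     config = DatasetConfig(
#         dataset="ptm",
#         dataset_type="csv",
#         dataset_loc="protein_lm/dataset/ptm_labels.csv",
#         val_size=100,
#         test_size=100,
#         sequence_column_name="ori_seq",
#     )
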
def set_labels(result):
    # For causal language modeling the labels are simply a copy of the input
    # ids; Hugging Face causal LM models shift them internally for the loss.
    result["labels"] = result["input_ids"].copy()
    return result

def train_val_test_split(
    dataset_dict: DatasetDict,
    config: DatasetConfig,
) -> DatasetDict:
    """
    Given a DatasetDict that only contains the split "train", optionally
    subsamples it and then splits it into up to three splits: "train",
    "val", and "test". The "val" and "test" splits are omitted when their
    specified sizes are 0.
    """
    assert set(dataset_dict.keys()) == {"train"}, (
        f"{train_val_test_split.__name__} expects its input to have the keys "
        f"['train'] but the input has keys {list(dataset_dict.keys())}"
    )
    dataset = dataset_dict["train"]
    val_size = config.val_size
    test_size = config.test_size
    assert isinstance(
        dataset, Dataset
    ), f"Invalid dataset type {type(dataset)}, only datasets.Dataset allowed"

    # Shuffle once up front so that the splits below can use shuffle=False
    # and still be random yet reproducible via the seed.
    dataset = dataset.shuffle(seed=config.split_seed)
    if config.subsample_size is not None:
        dataset = dataset.select(range(config.subsample_size))

    valtest_size = val_size + test_size
    if valtest_size > 0:
        # First carve off a combined val+test chunk, then split that chunk.
        train_valtest = dataset.train_test_split(
            test_size=valtest_size,
            shuffle=False,
        )
        split_dict = {
            "train": train_valtest["train"],
        }
        if test_size > 0 and val_size > 0:
            test_val = train_valtest["test"].train_test_split(
                test_size=test_size,
                shuffle=False,
            )
            split_dict["val"] = test_val["train"]
            split_dict["test"] = test_val["test"]
        elif val_size > 0:
            split_dict["val"] = train_valtest["test"]
        else:
            split_dict["test"] = train_valtest["test"]
    else:
        split_dict = {
            "train": dataset,
        }
    return DatasetDict(split_dict)

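
# For example, with val_size=100 and test_size=100 on a 10,000-row input, the
# result has splits of 9,800 / 100 / 100 rows; with both sizes 0, only the
# shuffled (and optionally subsampled) "train" split is returned.
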
def load_ptm_dataset(df: pd.DataFrame, config: DatasetConfig) -> DatasetDict:
    ds = Dataset.from_pandas(df)
    ds_dict = DatasetDict({"train": ds})
    return train_val_test_split(ds_dict, config)

def create_token_dict_from_dataframe(
    df, seq_col="ori_seq", pos_col="pos", token_col="token"
):
    """
    Build a nested mapping {sequence: {position: token}} from a dataframe in
    which each row records one PTM: the original sequence, the modified
    position, and the replacement token.
    """
    result_dict = {}
    for _, row in df.iterrows():
        ac_id = row[seq_col]
        pos = row[pos_col]
        token = row[token_col]
        if ac_id not in result_dict:
            result_dict[ac_id] = {}
        result_dict[ac_id][pos] = token
    return result_dict

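
# Sketch of the resulting structure (hypothetical sequences and PTM token
# strings; the real tokens come from the labels CSV):
#
#     {
#         "MKTAYIAK": {2: "<ptm_T>", 6: "<ptm_A>"},
#         "GVLKEYG": {0: "<ptm_G>"},
#     }
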
def substitute_tokens(sequence_lst, token_dict):
    if isinstance(sequence_lst, list):
        return [
            _substitute_token(sequence, token_dict[sequence])
            for sequence in sequence_lst
        ]
    elif isinstance(sequence_lst, str):
        return _substitute_token(sequence_lst, token_dict[sequence_lst])
    raise TypeError(f"Expected list or str, got {type(sequence_lst)}")


def _substitute_token(sequence, token_dict):
    # Replace the residue at each recorded (0-indexed) position with its
    # PTM token.
    result = list(sequence)
    for position, new_token in token_dict.items():
        result[position] = new_token
    return "".join(result)

def construct_ptm_seq(
    batch, tokenizer, ptm_token_dict, sequence_column_name, max_sequence_length
):
    """
    Apply a transform to the batch that replaces residues with their PTM
    tokens, keeping the wild-type sequence in "wt_seq" and the modified
    sequence in "ptm_seq", then tokenizes the result into "input_ids".
    """
    batch["wt_seq"] = batch[sequence_column_name]
    batch[sequence_column_name] = substitute_tokens(
        batch[sequence_column_name], ptm_token_dict
    )
    batch["ptm_seq"] = batch[sequence_column_name]
    batch["input_ids"] = tokenizer(
        batch[sequence_column_name],
        add_special_tokens=True,
        max_sequence_length=max_sequence_length,
    )
    # batch["labels"] = batch["input_ids"]
    return batch

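
# Per-batch effect, sketched with hypothetical values: a batch containing
# {"ori_seq": ["MKT"]} under ptm_token_dict {"MKT": {1: "p"}} comes back with
# wt_seq ["MKT"], ori_seq and ptm_seq ["MpT"], plus tokenized input_ids.
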
def get_ptm_dataset(config_dict: Dict, tokenizer) -> DatasetDict:
    config = DatasetConfig(**config_dict)
    if config.dataset_type == "csv":
        df = pd.read_csv(config.dataset_loc)
        ptm_token_dict = create_token_dict_from_dataframe(df)
        # Drop pandas index artifacts and duplicate sequences before splitting
        df.drop(df.filter(regex="Unnamed"), axis=1, inplace=True)
        df.drop_duplicates(subset=config.sequence_column_name, inplace=True)
        split_dict = load_ptm_dataset(df, config)
    else:
        raise ValueError(f"Invalid dataset_type {config.dataset_type}!")
    split_dict = split_dict.map(
        lambda e: construct_ptm_seq(
            batch=e,
            tokenizer=tokenizer,
            ptm_token_dict=ptm_token_dict,
            sequence_column_name=config.sequence_column_name,
            max_sequence_length=config.max_sequence_length,
        ),
        batched=True,
        keep_in_memory=True,
    )
    return split_dict

if __name__ == "__main__":
    from protein_lm.tokenizer.tokenizer import PTMTokenizer
    from transformers import DataCollatorWithPadding, default_data_collator

    tokenizer = PTMTokenizer()
    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer, padding="longest", max_length=1024, return_tensors="pt"
    )
    config_dict = {
        "dataset_type": "csv",
        "dataset_loc": "protein_lm/dataset/ptm_labels.csv",
        "subsample_size": None,
        "val_size": 0,
        "test_size": 0,
        "sequence_column_name": "ori_seq",
        "max_sequence_length": None,
        "dataset": "ptm",
    }
    dataset = get_ptm_dataset(config_dict, tokenizer)
    samples = dataset["train"][:8]