nas / PFMBench /src /data /msa_dataset.py
yuccaaa's picture
Add files using upload-large-folder tool
9627ce0 verified
import torch
import os, glob
import pandas as pd
import numpy as np
from src.data.esm.sdk.api import ESMProtein
from src.utils.utils import pmap_multi
def read_data(aa_seq, pdb_path, target, unique_id, position=None, msa_aas=None):
try:
if unique_id is None:
unique_id = str(hash(aa_seq))
if position:
position = torch.tensor([int(ele) for ele in position.split(",")])
# name = target
if pdb_path:
structure = ESMProtein.from_pdb(pdb_path)
coordinates = structure.coordinates
else:
coordinates = None
return {
'name': target,
'seq': aa_seq,
'X': coordinates,
'label': torch.tensor([0]),
'unique_id': unique_id,
'pdb_path': pdb_path,
'position': position,
'msa_aas': msa_aas
}
except:
return None
class MSADataset(torch.utils.data.Dataset):
def __init__(
self,
msa_csv_path: str,
type: str = "center"
):
self.msa_csv_path = msa_csv_path
msa_df = pd.read_csv(self.msa_csv_path)
# msa_df = msa_df.drop_duplicates(subset="aa_seq", keep="first")
if type == "center":
msa_df = msa_df[msa_df["type"] == type]
else:
center_df = msa_df[msa_df["type"] == "center"]
id_seq_dict = dict(zip(center_df['unique_id'], center_df['aa_seq_ori']))
msa_df = msa_df[msa_df["type"] == type]
def compute_position_diff(x):
seq_with_gap = x["aa_seq_ori"]
center_seq = id_seq_dict[x["target"]]
pos_diff, msa_pos_aa = [], []
center_idx = 0
for i, aa in enumerate(seq_with_gap):
if aa == "-":
continue
if center_idx >= len(center_seq):
break
if aa.upper() != center_seq[center_idx].upper():
pos_diff.append(str(i))
msa_pos_aa.append(aa.upper())
center_idx += 1
return ",".join(pos_diff)
def compute_aa_diff(x):
seq_with_gap = x["aa_seq_ori"]
center_seq = id_seq_dict[x["target"]]
pos_diff, msa_pos_aa = [], []
center_idx = 0
for i, aa in enumerate(seq_with_gap):
if aa == "-":
continue
if center_idx >= len(center_seq):
break
if aa.upper() != center_seq[center_idx].upper():
pos_diff.append(str(i))
msa_pos_aa.append(aa.upper())
center_idx += 1
return "".join(msa_pos_aa)
def compute_aa_diff(x):
seq_with_gap = x["aa_seq_ori"]
center_seq = id_seq_dict[x["target"]]
pos_diff, msa_pos_aa = [], []
center_idx = 0
for i, aa in enumerate(seq_with_gap):
if aa == "-":
continue
if center_idx >= len(center_seq):
break
if aa.upper() != center_seq[center_idx].upper():
pos_diff.append(str(i))
msa_pos_aa.append(aa.upper())
center_idx += 1
return "".join(msa_pos_aa)
msa_df["position"] = msa_df.apply(compute_position_diff, axis=1)
msa_df["msa_aas"] = msa_df.apply(compute_aa_diff, axis=1)
path_list = []
for i in range(len(msa_df)):
path_list.append(
(
msa_df.iloc[i].get('aa_seq'),
msa_df.iloc[i].get('pdb_path'),
msa_df.iloc[i].get('target'),
msa_df.iloc[i].get('unique_id'),
msa_df.iloc[i].get('position'),
msa_df.iloc[i].get('msa_aas'),
)
)
self.data = pmap_multi(read_data, path_list, n_jobs=8)
self.data = [d for d in self.data if d is not None]
def __len__(self):
return len(self.data)
def __getitem__(
self,
idx
):
return self.data[idx]
if __name__ == "__main__":
msa_data_center = MSADataset(
msa_csv_path = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/zeroshot/msa/msa_samples_zeroshot_w_pdb.csv",
type="center"
)
print(f"length of msa dataset: {len(msa_data_center)}...")
msa_data_msa = MSADataset(
msa_csv_path = "/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/zeroshot/msa/msa_samples_zeroshot_w_pdb.csv",
type="msa"
)
print(f"length of msa dataset: {len(msa_data_msa)}...")