# Spaces: Running on L4
import json | |
import os | |
import shutil | |
import random | |
import sys | |
import time | |
from typing import List, Tuple, Optional | |
import Bio.PDB | |
import Bio.SeqUtils | |
import pandas as pd | |
import numpy as np | |
import requests | |
from rdkit import Chem | |
from rdkit.Chem import AllChem | |
# Root directory for all raw inputs and generated outputs.
# NOTE: the value ends with "/", and every f-string below adds another "/",
# so the derived paths contain "//".
BASE_FOLDER = "/tmp/"
# Destination of the cached systems table, the processed models and the per-split JSON files.
OUTPUT_FOLDER = f"{BASE_FOLDER}/processed"
# https://storage.googleapis.com/plinder/2024-06/v2/index/annotation_table.parquet
PLINDER_ANNOTATIONS = f'{BASE_FOLDER}/raw_data/2024-06_v2_index_annotation_table.parquet'
# https://storage.googleapis.com/plinder/2024-06/v2/splits/split.parquet
PLINDER_SPLITS = f'{BASE_FOLDER}/raw_data/2024-06_v2_splits_split.parquet'
# https://console.cloud.google.com/storage/browser/_details/plinder/2024-06/v2/links/kind%3Dapo/links.parquet
PLINDER_LINKED_APO_MAP = f"{BASE_FOLDER}/raw_data/2024-06_v2_links_kind=apo_links.parquet"
# https://console.cloud.google.com/storage/browser/_details/plinder/2024-06/v2/links/kind%3Dpred/links.parquet
PLINDER_LINKED_PRED_MAP = f"{BASE_FOLDER}/raw_data/2024-06_v2_links_kind=pred_links.parquet"
# https://storage.googleapis.com/plinder/2024-06/v2/linked_structures/apo.zip
# Folder holding the linked apo structures, looked up as "<linked_apo_id>.cif".
PLINDER_LINKED_APO_STRUCTURES = f"{BASE_FOLDER}/raw_data/2024-06_v2_linked_structures_apo"
# https://storage.googleapis.com/plinder/2024-06/v2/linked_structures/pred.zip
# Folder holding the linked predicted structures, looked up as "<linked_pred_id>.cif".
PLINDER_LINKED_PRED_STRUCTURES = f"{BASE_FOLDER}/raw_data/2024-06_v2_linked_structures_pred"
# gsutil executable used to download the per-bucket system archives.
GSUTIL_PATH = f"{BASE_FOLDER}/google-cloud-sdk/bin/gsutil"
def get_cached_systems_to_train(recompute=False):
    """Build, or load from cache, the table of PLINDER systems to process.

    The table is assembled from the annotation, split and linked-structure
    parquet files:

    * every input table is de-duplicated on its system id column;
    * annotations are inner-joined with the splits table and restricted to
      the ``train`` / ``val`` / ``test`` splits;
    * ``_bucket_id`` is set to characters 1-2 of ``entry_pdb_id``;
    * ``linked_pred_id`` / ``linked_apo_id`` are left-joined, so systems
      without a linked structure are kept with a missing id;
    * ``cluster_size`` is the number of rows sharing the row's ``cluster``.

    The result is pickled to ``OUTPUT_FOLDER/to_train.pickle`` and that file
    is returned directly on later calls.

    Args:
        recompute: when true, ignore an existing cache file and rebuild it.

    Returns:
        pandas.DataFrame with one row per system.

    Reference log of a full run (the last line is printed by ``main``)::

        full:
        loaded 1357906 409726 163816 433865
        loaded 990260 409726 125818 106411
        joined splits 409726
        Has splits 311008
        unique systems 311008
        split
        train 309140
        test 1036
        val 832
        Name: count, dtype: int64
        Has affinity 36856
        Has affinity by splits split
        train 36598
        test 142
        val 116
        Name: count, dtype: int64
        Total systems before pred 311008
        Total systems after pred 311008
        Has pred 83487
        Has apo 75127
        Has both 51506
        Has either 107108
        columns Index(['system_id', 'entry_pdb_id', 'ligand_binding_affinity',
        'entry_release_date', 'system_pocket_UniProt',
        'system_num_protein_chains', 'system_num_ligand_chains', 'uniqueness',
        'split', 'cluster', 'cluster_for_val_split',
        'system_pass_validation_criteria', 'system_pass_statistics_criteria',
        'system_proper_num_ligand_chains', 'system_proper_pocket_num_residues',
        'system_proper_num_interactions',
        'system_proper_ligand_max_molecular_weight',
        'system_has_binding_affinity', 'system_has_apo_or_pred', '_bucket_id',
        'linked_pred_id', 'linked_apo_id'],
        dtype='object')
        total systems 311008
    """
    output_path = os.path.join(OUTPUT_FOLDER, "to_train.pickle")
    if os.path.exists(output_path) and not recompute:
        # NOTE: unpickling is only safe because this file is written by this function.
        return pd.read_pickle(output_path)

    systems = pd.read_parquet(PLINDER_ANNOTATIONS,
                              columns=['system_id', 'entry_pdb_id', 'ligand_binding_affinity',
                                       'entry_release_date', 'system_pocket_UniProt', 'entry_resolution',
                                       'system_num_protein_chains', 'system_num_ligand_chains'])
    splits = pd.read_parquet(PLINDER_SPLITS)
    linked_pred = pd.read_parquet(PLINDER_LINKED_PRED_MAP)
    linked_apo = pd.read_parquet(PLINDER_LINKED_APO_MAP)
    print("loaded", len(systems), len(splits), len(linked_pred), len(linked_apo))

    # Keep a single row per system (and a single linked structure per system).
    systems = systems.drop_duplicates(subset=['system_id'])
    splits = splits.drop_duplicates(subset=['system_id'])
    linked_pred = linked_pred.drop_duplicates(subset=['reference_system_id'])
    linked_apo = linked_apo.drop_duplicates(subset=['reference_system_id'])
    print("loaded", len(systems), len(splits), len(linked_pred), len(linked_apo))

    # Join splits.
    systems = pd.merge(systems, splits, on='system_id', how='inner')
    print("joined splits", len(systems))
    systems['_bucket_id'] = systems['entry_pdb_id'].str[1:3]

    # Leave only systems that belong to the train/val/test splits.
    systems = systems[systems['split'].isin(['train', 'val', 'test'])]
    print("Has splits", len(systems))
    print("unique systems", systems['system_id'].nunique())
    print(systems["split"].value_counts())
    has_affinity = systems[systems['ligand_binding_affinity'].notna()]
    print("Has affinity", len(has_affinity))
    print("Has affinity by splits", has_affinity['split'].value_counts())
    print("Total systems before pred", len(systems))

    # Left joins: a system is allowed to have no linked structure.
    systems = pd.merge(systems, linked_pred[['reference_system_id', 'id']],
                       left_on='system_id', right_on='reference_system_id',
                       how='left')
    print("Total systems after pred", len(systems))
    systems = systems.rename(columns={'id': 'linked_pred_id'})
    systems = pd.merge(systems, linked_apo[['reference_system_id', 'id']],
                       left_on='system_id', right_on='reference_system_id',
                       how='left')
    systems = systems.rename(columns={'id': 'linked_apo_id'})
    # The second merge suffixes the two join-key copies with _x / _y; neither is needed.
    systems = systems.drop(columns=['reference_system_id_x', 'reference_system_id_y'])

    cluster_sizes = systems["cluster"].value_counts()
    systems["cluster_size"] = systems["cluster"].map(cluster_sizes)

    has_pred = systems['linked_pred_id'].notna()
    has_apo = systems['linked_apo_id'].notna()
    print("Has pred", has_pred.sum())
    print("Has apo", has_apo.sum())
    print("Has both", (has_pred & has_apo).sum())
    print("Has either", (has_pred | has_apo).sum())
    print("columns", systems.columns)

    # The cache folder may not exist yet when this is called outside main().
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    systems.to_pickle(output_path)
    return systems
def create_conformers(smiles, output_path, num_conformers=100, multiplier_samples=1):
    """Embed 3D conformers for a SMILES string and write them to an SDF file.

    ``num_conformers * multiplier_samples`` conformers are requested from
    RDKit's ETKDGv3 embedding and at most the first ``num_conformers`` are
    written, without hydrogens. The output file is created even when the
    embedding yields no conformer, in which case it is empty.

    Args:
        smiles: SMILES string of the molecule.
        output_path: path of the SDF file to write.
        num_conformers: maximum number of conformers written.
        multiplier_samples: oversampling factor for the embedding request.

    Raises:
        ValueError: if ``smiles`` cannot be parsed.
    """
    target_mol = Chem.MolFromSmiles(smiles)
    if target_mol is None:
        # Fail with a clear message instead of an opaque error inside AddHs(None).
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    target_mol = Chem.AddHs(target_mol)

    params = AllChem.ETKDGv3()
    params.numThreads = 0  # Use all available threads
    params.pruneRmsThresh = 0.1  # Pruning threshold for RMSD
    conformer_ids = list(AllChem.EmbedMultipleConfs(
        target_mol, numConfs=num_conformers * multiplier_samples, params=params))

    # Optional (disabled): optimize each conformer with the UFF force field
    # for conf_id in conformer_ids:
    #     AllChem.UFFOptimizeMolecule(target_mol, confId=conf_id)

    # Hydrogens are only needed for the embedding itself.
    target_mol = Chem.RemoveHs(target_mol)

    writer = Chem.SDWriter(output_path)
    try:
        for conf_id in conformer_ids[:num_conformers]:
            writer.write(target_mol, confId=conf_id)
    finally:
        # Always release the file handle, even if writing a conformer fails.
        writer.close()
def do_robust_chain_object_renumber(chain: Bio.PDB.Chain.Chain, new_chain_id: str) -> Optional[Bio.PDB.Chain.Chain]:
    """Move the standard residues of ``chain`` into a new, cleanly numbered chain.

    Only residues that have a CA atom and a recognised one-letter code are
    kept. Residues sharing a residue number (typically insertion codes) are
    spread out by pushing all following residues up by one, and the final
    numbering starts at 1 with blank hetero flag and insertion code.

    The kept residues are detached from ``chain``, so the input is modified.

    Returns:
        The new chain with id ``new_chain_id``, or ``None`` when no residue qualifies.
    """
    kept = []
    for residue in chain.get_residues():
        if "CA" not in residue:
            continue
        if Bio.SeqUtils.seq1(residue.get_resname()) in ("X", "", " "):
            continue
        kept.append(residue)
    if not kept:
        return None

    numbered = [(residue, residue.get_id()[1]) for residue in kept]
    lowest = min(number for _, number in numbered)
    if lowest < 1:
        print("Negative res id", chain, lowest)
        shift = -1 * lowest + 1
        numbered = [(residue, number + shift) for residue, number in numbered]

    # Walk from the last residue backwards; whenever the current residue has
    # the same number as the one handled just before it, everything already
    # handled (i.e. everything after it in the chain) moves up by one.
    resolved = []
    for residue, number in reversed(numbered):
        if resolved and resolved[-1][1] == number:
            # there is a collision, usually an insertion residue
            resolved = [(other, other_number + 1) for other, other_number in resolved]
        resolved.append((residue, number))

    offset = 1 - min(number for _, number in resolved)  # start from 1
    new_chain = Bio.PDB.Chain.Chain(new_chain_id)
    for residue, number in sorted(resolved, key=lambda pair: pair[1]):
        chain.detach_child(residue.id)
        residue.id = (" ", number + offset, " ")
        new_chain.add(residue)
    return new_chain
def robust_renumber_protein(pdb_path: str, output_path: str):
    """Write a renumbered copy of a single-model structure as a PDB file.

    The input may be a ``.pdb`` or ``.cif`` file. Each chain is passed through
    ``do_robust_chain_object_renumber`` and chains that end up empty are
    skipped. Chain ids are handed out by position from a fixed 62-character
    alphabet, so a skipped chain still consumes its letter and chains beyond
    the 62nd are not written.

    Raises:
        ValueError: if the file extension is neither ``.pdb`` nor ``.cif``.
        AssertionError: if the structure has more than one model.
    """
    if pdb_path.endswith(".pdb"):
        structure = Bio.PDB.PDBParser(QUIET=True).get_structure("original_pdb", pdb_path)
    elif pdb_path.endswith(".cif"):
        structure = Bio.PDB.MMCIFParser().get_structure("original_pdb", pdb_path)
    else:
        raise ValueError("Unknown file type", pdb_path)

    models = list(structure)
    assert len(models) == 1, "can't extract if more than one model"
    source_model = models[0]

    available_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
    renumbered_model = Bio.PDB.Model.Model(0)
    # Materialise the chain list first: renumbering detaches residues from the source chains.
    for source_chain, new_id in zip(list(source_model.get_chains()), available_ids):
        renumbered_chain = do_robust_chain_object_renumber(source_chain, new_id)
        if renumbered_chain is not None:
            renumbered_model.add(renumbered_chain)

    output_structure = Bio.PDB.Structure.Structure("renumbered_pdb")
    output_structure.add(renumbered_model)
    writer = Bio.PDB.PDBIO()
    writer.set_structure(output_structure)
    writer.save(output_path)
def _get_extra(extra_to_save: int, res_before: List[int], res_after: List[int]) -> set: | |
take_from_before = random.randint(0, extra_to_save) | |
take_from_after = extra_to_save - take_from_before | |
if take_from_before > len(res_before): | |
take_from_after = extra_to_save - len(res_before) | |
take_from_before = len(res_before) | |
if take_from_after > len(res_after): | |
take_from_before = extra_to_save - len(res_after) | |
take_from_after = len(res_after) | |
extra_to_add = set() | |
if take_from_before > 0: | |
extra_to_add.update(res_before[-take_from_before:]) | |
extra_to_add.update(res_after[:take_from_after]) | |
return extra_to_add | |
def crop_protein_cont(gt_pdb_path: str, ligand_pos: np.ndarray, output_path: str, max_length: int,
                      distance_threshold: float):
    """Crop a protein to contiguous per-chain segments around the ligand and save it as PDB.

    Pocket residues are the non-hetero residues with at least one atom closer
    than ``distance_threshold`` to any ligand position. For every chain that
    has pocket residues, all CA-containing residues between its first and last
    pocket residue are kept. The remaining budget (``max_length`` minus the
    number of ligand positions minus the residues kept so far) is split evenly
    between those chains and spent on residues directly before/after the
    segment, with the before/after split chosen at random by ``_get_extra``.

    Residues with a non-blank hetero flag or insertion code are always dropped.

    Args:
        gt_pdb_path: PDB file to crop.
        ligand_pos: ligand coordinates; the first axis is counted as ligand size.
        output_path: where the cropped PDB is written.
        max_length: total budget shared by ligand positions and residues.
        distance_threshold: pocket cut-off distance.

    Returns:
        Number of residues selected, or -1 when there is no pocket residue or
        the contiguous segments alone exceed the budget (nothing is written
        in either case).
    """
    protein = Chem.MolFromPDBFile(gt_pdb_path, sanitize=False)
    ligand_size = ligand_pos.shape[0]
    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
    gt_model = next(iter(pdb_parser.get_structure("gt_pdb", gt_pdb_path)))
    all_res_ids_by_chain = {chain.id: sorted([res.id[1] for res in chain.get_residues() if "CA" in res])
                            for chain in gt_model.get_chains()}

    protein_pos = protein.GetConformer().GetPositions()
    protein_atoms = list(protein.GetAtoms())
    assert len(protein_pos) == len(protein_atoms), f"Positions and atoms mismatch in {gt_pdb_path}"

    # Distance from every protein atom to its nearest ligand position.
    inter_dists = ligand_pos[:, np.newaxis, :] - protein_pos[np.newaxis, :, :]
    inter_dists = np.sqrt((inter_dists ** 2).sum(-1))
    min_inter_dist_per_protein_atom = inter_dists.min(axis=0)

    res_to_save_count = max_length - ligand_size
    used_protein_idx = np.where(min_inter_dist_per_protein_atom < distance_threshold)[0]
    pocket_residues_by_chain = {}
    for idx in used_protein_idx:
        res = protein_atoms[idx].GetPDBResidueInfo()
        if res.GetIsHeteroAtom():
            continue
        pocket_residues_by_chain.setdefault(res.GetChainId(), set()).add(res.GetResidueNumber())
    if not pocket_residues_by_chain:
        print("No pocket residues found")
        return -1

    # A set keeps the membership test in the final filtering loop O(1) per residue.
    complete_pocket = set()
    extended_pocket_per_chain = {}
    pocket_bounds_per_chain = {}
    for chain_id, pocket_residues in pocket_residues_by_chain.items():
        min_pocket_res = min(pocket_residues)
        max_pocket_res = max(pocket_residues)
        pocket_bounds_per_chain[chain_id] = (min_pocket_res, max_pocket_res)
        extended_pocket_per_chain[chain_id] = {res_id for res_id in all_res_ids_by_chain[chain_id]
                                               if min_pocket_res <= res_id <= max_pocket_res}
        complete_pocket.update((chain_id, res_id) for res_id in extended_pocket_per_chain[chain_id])

    total_res_ids = sum(len(res_ids) for res_ids in all_res_ids_by_chain.values())
    total_pocket_res = sum(len(res_ids) for res_ids in pocket_residues_by_chain.values())
    if len(complete_pocket) > res_to_save_count:
        print(f"Too many residues all: {total_res_ids} pocket:{total_pocket_res} {len(complete_pocket)} "
              f"(ligand size: {ligand_size})")
        return -1

    extra_to_save = res_to_save_count - len(complete_pocket)
    # Divide the spare budget evenly between the chains that have a pocket.
    extra_to_save_per_chain = extra_to_save // len(extended_pocket_per_chain)
    for chain_id in extended_pocket_per_chain:
        # Bounds are computed once per chain. Using the pocket bounds selects
        # the same flanking residues as the extended segment's bounds, because
        # the segment already holds every CA residue inside the pocket range,
        # and it stays defined even when the segment itself is empty.
        min_pocket_res, max_pocket_res = pocket_bounds_per_chain[chain_id]
        res_before = [res_id for res_id in all_res_ids_by_chain[chain_id] if res_id < min_pocket_res]
        res_after = [res_id for res_id in all_res_ids_by_chain[chain_id] if res_id > max_pocket_res]
        extra_to_add = _get_extra(extra_to_save_per_chain, res_before, res_after)
        complete_pocket.update((chain_id, res_id) for res_id in extra_to_add)

    total_extended_res = sum(len(res_ids) for res_ids in extended_pocket_per_chain.values())
    print(f"Found valid pocket all: {total_res_ids} pocket:{total_pocket_res} {total_extended_res} "
          f"{len(complete_pocket)} (ligand size: {ligand_size}) extra: {extra_to_save}")

    # Collect first, detach afterwards: the model must not change while it is iterated.
    res_to_remove = [
        res for res in gt_model.get_residues()
        if (res.parent.id, res.id[1]) not in complete_pocket
        or res.id[0].strip() != "" or res.id[2].strip() != ""
    ]
    for res in res_to_remove:
        gt_model[res.parent.id].detach_child(res.id)

    io = Bio.PDB.PDBIO()
    io.set_structure(gt_model)
    io.save(output_path)
    return len(complete_pocket)
def crop_protein_simple(gt_pdb_path: str, ligand_pos: np.ndarray, output_path: str, max_length: int):
    """Crop a protein to the residues nearest the ligand and save it as PDB.

    Protein atoms are visited from nearest to farthest from the ligand and
    the residue of each non-hetero atom is collected until the number of
    collected residues reaches ``max_length`` minus the number of ligand
    positions. Residues with a non-blank hetero flag or insertion code are
    always dropped from the output.

    Returns:
        Number of residues collected, or -1 (nothing written) when none was found.
    """
    rd_protein = Chem.MolFromPDBFile(gt_pdb_path, sanitize=False)
    ligand_size = ligand_pos.shape[0]
    residue_budget = max_length - ligand_size
    gt_model = next(iter(Bio.PDB.PDBParser(QUIET=True).get_structure("gt_pdb", gt_pdb_path)))

    atom_coords = rd_protein.GetConformer().GetPositions()
    atoms = list(rd_protein.GetAtoms())
    assert len(atom_coords) == len(atoms), f"Positions and atoms mismatch in {gt_pdb_path}"

    # Distance from every protein atom to its nearest ligand position.
    deltas = ligand_pos[:, np.newaxis, :] - atom_coords[np.newaxis, :, :]
    nearest_ligand_dist = np.sqrt((deltas ** 2).sum(-1)).min(axis=0)

    selected_by_chain = {}
    total_found = 0
    for atom_idx in np.argsort(nearest_ligand_dist):
        info = atoms[atom_idx].GetPDBResidueInfo()
        if info.GetIsHeteroAtom():
            continue
        selected_by_chain.setdefault(info.GetChainId(), set()).add(info.GetResidueNumber())
        total_found = sum([len(res_ids) for res_ids in selected_by_chain.values()])
        if total_found >= residue_budget:
            break
    print("saved with simple", total_found)
    if not selected_by_chain:
        print("No pocket residues found")
        return -1

    # Collect first, detach afterwards: the model must not change while it is iterated.
    discarded = []
    for residue in gt_model.get_residues():
        keep = (residue.id[1] in selected_by_chain.get(residue.parent.id, set())
                and residue.id[0].strip() == "" and residue.id[2].strip() == "")
        if not keep:
            discarded.append(residue)
    for residue in discarded:
        gt_model[residue.parent.id].detach_child(residue.id)

    writer = Bio.PDB.PDBIO()
    writer.set_structure(gt_model)
    writer.save(output_path)
    return total_found
def cif_to_pdb(cif_path: str, pdb_path: str):
    """Parse the mmCIF file at ``cif_path`` and write it back out in PDB format."""
    parsed = Bio.PDB.MMCIFParser().get_structure("s_cif", cif_path)
    writer = Bio.PDB.PDBIO()
    writer.set_structure(parsed)
    writer.save(pdb_path)
def get_chain_object_to_seq(chain: Bio.PDB.Chain.Chain) -> str:
    """Return the one-letter sequence of ``chain`` indexed by residue number.

    Position ``i`` of the result (1-based) holds the residue numbered ``i``;
    numbers with no CA-containing residue are filled with ``"X"``. Residues
    numbered below 1 are not represented. An empty string is returned for a
    chain without any CA-containing residue.
    """
    residue_by_number = {}
    for residue in chain.get_residues():
        if "CA" in residue:
            residue_by_number[residue.get_id()[1]] = residue
    if not residue_by_number:
        print("skipping empty chain", chain.get_id())
        return ""

    letters = []
    for position in range(1, max(residue_by_number) + 1):
        if position in residue_by_number:
            letters.append(Bio.SeqUtils.seq1(residue_by_number[position].get_resname()))
        else:
            letters.append("X")
    return "".join(letters)
def get_sequence_from_pdb(pdb_path: str) -> Tuple[str, List[int]]:
    """Return the chain sequences of a PDB file joined by 20-residue ``"X"`` linkers.

    Returns:
        A tuple of the joined sequence and the list of individual chain
        sequence lengths, in chain order.
    """
    structure = Bio.PDB.PDBParser(QUIET=True).get_structure("original_pdb", pdb_path)
    chain_sequences = []
    chain_lengths = []
    for chain in structure.get_chains():
        chain_sequence = get_chain_object_to_seq(chain)
        chain_sequences.append(chain_sequence)
        chain_lengths.append(len(chain_sequence))
    linker = "X" * 20
    return linker.join(chain_sequences), chain_lengths
from Bio import PDB | |
from Bio import pairwise2 | |
def extract_sequence(chain):
    """Return the one-letter sequence of ``chain`` and the residues it was built from.

    Residues whose name does not map to a known one-letter code are skipped,
    so the returned string and list always have the same length and order.
    """
    letters = []
    kept_residues = []
    for residue in chain.get_residues():
        letter = Bio.SeqUtils.seq1(residue.get_resname())
        if letter not in ('X', "", " "):
            letters.append(letter)
            kept_residues.append(residue)
    return "".join(letters), kept_residues
def map_residues(alignment, residues_gt, residues_pred):
    """Pair up the residues that face each other in a pairwise alignment.

    Args:
        alignment: object with equally long gapped strings ``seqA`` (ground
            truth) and ``seqB`` (prediction), where ``'-'`` marks a gap.
        residues_gt: residues behind the non-gap characters of ``seqA``, in order.
        residues_pred: residues behind the non-gap characters of ``seqB``, in order.

    Returns:
        List of ``(res_gt, res_pred)`` tuples, one per alignment column in
        which neither sequence has a gap.
    """
    idx_gt = 0
    idx_pred = 0
    mapping = []
    for aa_gt, aa_pred in zip(alignment.seqA, alignment.seqB):
        res_gt = None
        res_pred = None
        if aa_gt != '-':
            res_gt = residues_gt[idx_gt]
            idx_gt += 1
        if aa_pred != '-':
            res_pred = residues_pred[idx_pred]
            idx_pred += 1
        # Compare against None explicitly: a residue object may be falsy
        # (container types are falsy when empty) and must still be paired.
        if res_gt is not None and res_pred is not None:
            mapping.append((res_gt, res_pred))
    return mapping
class ResidueSelect(PDB.Select):
    """Selector for ``PDBIO.save`` that accepts only an explicit collection of residues."""

    def __init__(self, residues_to_select):
        # Stored as a set so that accept_residue is a hash lookup.
        self.residues_to_select = {*residues_to_select}

    def accept_residue(self, residue):
        """Return True when ``residue`` is one of the selected residues."""
        is_selected = residue in self.residues_to_select
        return is_selected
def align_gt_and_input(gt_pdb_path, input_pdb_path, output_gt_path, output_input_path):
    """Reduce a ground-truth and an input structure to their mutually aligned residues.

    Every ground-truth chain is globally aligned (``pairwise2.align.globalxx``)
    against each input chain that has not been matched yet, and the
    best-scoring one is assigned to it. Only residues that face each other in
    the chosen alignments are written to the two output files.

    Both structures are fully parsed before anything is written, so an output
    path may be the same file as the corresponding input path.

    Args:
        gt_pdb_path: ground-truth PDB file.
        input_pdb_path: input (apo / predicted) PDB file.
        output_gt_path: where the filtered ground truth is written.
        output_input_path: where the filtered input structure is written.
    """
    parser = PDB.PDBParser(QUIET=True)
    gt_structure = parser.get_structure('gt', gt_pdb_path)
    pred_structure = parser.get_structure('pred', input_pdb_path)

    # Candidate sequences do not depend on the ground-truth chain: extract them once.
    pred_candidates = []
    for chain_pred in pred_structure.get_chains():
        seq_pred, residues_pred = extract_sequence(chain_pred)
        pred_candidates.append((chain_pred, seq_pred, residues_pred))
    used_pred_indices = set()

    matched_residues_gt = []
    matched_residues_pred = []
    total_mapping_size = 0
    for chain_gt in gt_structure.get_chains():
        seq_gt, residues_gt = extract_sequence(chain_gt)
        best_alignment = None
        best_pred_index = None
        best_score = -1
        best_residues_pred = None
        # Find the best matching chain in pred
        for pred_index, (chain_pred, seq_pred, residues_pred) in enumerate(pred_candidates):
            print("checking", chain_pred.get_id(), chain_gt.get_id())
            if pred_index in used_pred_indices:
                continue
            print(seq_gt)
            print(seq_pred)
            alignments = pairwise2.align.globalxx(seq_gt, seq_pred, one_alignment_only=True)
            if not alignments:
                continue
            print("checking2", chain_pred.get_id(), chain_gt.get_id())
            alignment = alignments[0]
            if alignment.score > best_score:
                best_score = alignment.score
                best_alignment = alignment
                best_pred_index = pred_index
                best_residues_pred = residues_pred

        if best_alignment is None:
            print(f"No matching chain found for chain {chain_gt.get_id()}")
            continue
        mapping = map_residues(best_alignment, residues_gt, best_residues_pred)
        total_mapping_size += len(mapping)
        # Each input chain can be assigned to one ground-truth chain only.
        used_pred_indices.add(best_pred_index)
        for res_gt, res_pred in mapping:
            matched_residues_gt.append(res_gt)
            matched_residues_pred.append(res_pred)
    print(f"Total mapping size: {total_mapping_size}")

    # Write new PDB files with only matched residues
    io = PDB.PDBIO()
    io.set_structure(gt_structure)
    io.save(output_gt_path, ResidueSelect(matched_residues_gt))
    io.set_structure(pred_structure)
    io.save(output_input_path, ResidueSelect(matched_residues_pred))
def validate_matching_input_gt(gt_pdb_path, input_pdb_path):
    """Check that two PDB files list the same residue names in the same order.

    Returns:
        The number of residues when both files agree, otherwise -1 (a message
        describing the first mismatch is printed).
    """
    gt_residues = list(PDB.PDBParser().get_structure('gt', gt_pdb_path).get_residues())
    input_residues = list(PDB.PDBParser().get_structure('input', input_pdb_path).get_residues())

    gt_count = len(gt_residues)
    input_count = len(input_residues)
    if gt_count != input_count:
        print(f"Residue count mismatch: {gt_count} vs {input_count}")
        return -1

    for res_gt, res_input in zip(gt_residues, input_residues):
        name_gt = res_gt.get_resname()
        name_input = res_input.get_resname()
        if name_gt != name_input:
            print(f"Residue name mismatch: {name_gt} vs {name_input}")
            return -1
    return input_count
def _split_esm_chains(pdb_path, chain_lengths):
    """Split the single folded chain in ``pdb_path`` back into one chain per original chain, in place.

    The folded sequence is laid out as ``chain_1 + 20 linker residues +
    chain_2 + ...``. This relies on the residues in the file being numbered
    consecutively from 1 (TODO confirm against the ESM Atlas output format).
    """
    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
    pdb_model = next(iter(pdb_parser.get_structure("original_pdb", pdb_path)))
    esm_residues = list(next(pdb_model.get_chains()).get_residues())
    chain_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[:len(chain_lengths)]
    new_model = Bio.PDB.Model.Model(0)
    start_ind = 1
    for chain_length, chain_id in zip(chain_lengths, chain_ids):
        # end_ind is exclusive: residue start_ind + chain_length already belongs to the linker.
        end_ind = start_ind + chain_length
        new_chain = Bio.PDB.Chain.Chain(chain_id)
        for res in esm_residues:
            if start_ind <= res.id[1] < end_ind:
                new_chain.add(res)
        new_model.add(new_chain)
        start_ind = end_ind + 20  # 20 is the gap in esm
    io = Bio.PDB.PDBIO()
    io.set_structure(new_model)
    io.save(pdb_path)


def prepare_system(row, system_folder, output_models_folder, output_jsons_folder, should_overwrite=False):
    """Generate all training files for one PLINDER system.

    Steps:
      1. Renumber the receptor and crop it around the ligands (contiguous
         crop first, nearest-residue crop as fallback), 350 tokens at most.
      2. Obtain an input structure, trying in order: linked apo structure,
         linked predicted structure, ESMFold via the ESM Atlas API (only when
         the joined sequence has at most 400 characters), and finally a copy
         of the cropped ground truth.
      3. Reduce ground truth and input to their aligned residues and verify
         that they match.
      4. Copy the ligand SDF files, derive their SMILES and create one
         reference conformer per ligand (falling back to the ground-truth
         pose when no conformer could be generated).
      5. Write the JSON description of the system.

    Args:
        row: systems-table row of the system (see ``get_cached_systems_to_train``).
        system_folder: extracted PLINDER folder of the system.
        output_models_folder: destination of structure and ligand files.
        output_jsons_folder: destination of the JSON description.
        should_overwrite: regenerate even when the JSON file already exists.

    Returns:
        A status string: ``"success"``, ``"Already exists"``, ``"No receptor"``,
        ``"Failed to load ligand"`` or ``"Failed to crop protein"``.

    Raises:
        AssertionError: if the aligned input and ground truth do not match.
        Any exception raised by the parsing / cropping / alignment helpers.
    """
    system_id = row['system_id']
    output_json_path = os.path.join(output_jsons_folder, f"{system_id}.json")
    if os.path.exists(output_json_path) and not should_overwrite:
        return "Already exists"

    plinder_gt_pdb_path = os.path.join(system_folder, "receptor.pdb")
    plinder_gt_ligand_paths = []
    plinder_gt_ligands_folder = os.path.join(system_folder, "ligand_files")
    gt_output_path = os.path.join(output_models_folder, f"{system_id}_gt.pdb")
    gt_output_relative_path = "plinder_models/" + f"{system_id}_gt.pdb"
    tmp_input_path = os.path.join(output_models_folder, f"tmp_{system_id}_input.pdb")
    protein_input_path = os.path.join(output_models_folder, f"{system_id}_input.pdb")
    protein_input_relative_path = "plinder_models/" + f"{system_id}_input.pdb"

    print("Copying ground truth files")
    if not os.path.exists(plinder_gt_pdb_path):
        print("no receptor", plinder_gt_pdb_path)
        return "No receptor"

    # The fallback chain below keys on the existence of tmp_input_path, so a
    # leftover from an interrupted run must not be mistaken for a fresh result.
    if os.path.exists(tmp_input_path):
        os.remove(tmp_input_path)

    tmp_gt_pdb_path = os.path.join(output_models_folder, f"tmp_{system_id}_gt.pdb")
    robust_renumber_protein(plinder_gt_pdb_path, tmp_gt_pdb_path)
    try:
        ligand_pos_list = []
        for ligand_file in os.listdir(plinder_gt_ligands_folder):
            if not ligand_file.endswith(".sdf"):
                continue
            ligand_path = os.path.join(plinder_gt_ligands_folder, ligand_file)
            plinder_gt_ligand_paths.append(ligand_path)
            loaded_ligand = Chem.MolFromMolFile(ligand_path)
            # MolFromMolFile returns None on failure; check before touching the molecule.
            if loaded_ligand is None:
                print("failed to load", ligand_path)
                return "Failed to load ligand"
            ligand_pos_list.append(loaded_ligand.GetConformer().GetPositions())

        # Crop ground truth protein, also removes insertion codes
        ligand_pos = np.concatenate(ligand_pos_list, axis=0)
        res_count_in_protein = crop_protein_cont(tmp_gt_pdb_path, ligand_pos, gt_output_path, max_length=350,
                                                 distance_threshold=5)
        if res_count_in_protein == -1:
            print("Failed to crop protein continously, using simple crop")
            if crop_protein_simple(tmp_gt_pdb_path, ligand_pos, gt_output_path, max_length=350) == -1:
                # Neither crop wrote gt_output_path; nothing downstream can work.
                return "Failed to crop protein"
    finally:
        # The renumbered full-length receptor is only needed for cropping.
        if os.path.exists(tmp_gt_pdb_path):
            os.remove(tmp_gt_pdb_path)

    # Generate input protein structure
    input_protein_source = None
    if pd.notna(row["linked_apo_id"]):
        apo_pdb_path = os.path.join(PLINDER_LINKED_APO_STRUCTURES, f"{row['linked_apo_id']}.cif")
        try:
            robust_renumber_protein(apo_pdb_path, tmp_input_path)
            input_protein_source = "apo"
            print("Using input apo", row['linked_apo_id'])
        except Exception as e:
            print("Problem with apo", e, row["linked_apo_id"], apo_pdb_path)

    if not os.path.exists(tmp_input_path) and pd.notna(row["linked_pred_id"]):
        pred_pdb_path = os.path.join(PLINDER_LINKED_PRED_STRUCTURES, f"{row['linked_pred_id']}.cif")
        try:
            robust_renumber_protein(pred_pdb_path, tmp_input_path)
            input_protein_source = "pred"
            print("Using input pred", row['linked_pred_id'])
        except Exception as e:
            # Best effort, like the apo branch: report and fall through to the next source.
            print("Problem with pred", e, row["linked_pred_id"], pred_pdb_path)

    if not os.path.exists(tmp_input_path):
        print("No linked structure found, running ESM")
        url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
        sequence, chain_lengths = get_sequence_from_pdb(gt_output_path)
        if len(sequence) <= 400:
            try:
                response = requests.post(url, data=sequence)
                response.raise_for_status()
                with open(tmp_input_path, "w") as f:
                    f.write(response.text)
                # divide to chains
                if len(chain_lengths) > 1:
                    _split_esm_chains(tmp_input_path, chain_lengths)
                input_protein_source = "esm"
                print("Using input ESM")
            except requests.exceptions.RequestException as e:
                print(f"An error occurred in ESM: {e}")
        else:
            print("Sequence too long for ESM")

    if not os.path.exists(tmp_input_path):
        print("Using input GT")
        shutil.copyfile(gt_output_path, tmp_input_path)
        input_protein_source = "gt"

    align_gt_and_input(gt_output_path, tmp_input_path, gt_output_path, protein_input_path)
    protein_size = validate_matching_input_gt(gt_output_path, protein_input_path)
    assert protein_size > -1, "Failed to validate matching input and gt"
    os.remove(tmp_input_path)

    rel_gt_lig_paths = []
    rel_ref_lig_paths = []
    input_smiles = []
    for i, ligand_path in enumerate(sorted(plinder_gt_ligand_paths)):
        gt_ligand_output_path = os.path.join(output_models_folder, f"{system_id}_ligand_gt_{i}.sdf")
        rel_gt_lig_paths.append(f"plinder_models/{system_id}_ligand_gt_{i}.sdf")
        shutil.copyfile(ligand_path, gt_ligand_output_path)
        loaded_ligand = Chem.MolFromMolFile(gt_ligand_output_path)
        input_smiles.append(Chem.MolToSmiles(loaded_ligand))

        ref_ligand_output_path = os.path.join(output_models_folder, f"{system_id}_ligand_ref_{i}.sdf")
        rel_ref_lig_paths.append(f"plinder_models/{system_id}_ligand_ref_{i}.sdf")
        create_conformers(input_smiles[-1], ref_ligand_output_path, num_conformers=1)
        # An empty file means no conformer could be embedded.
        if os.path.getsize(ref_ligand_output_path) == 0:
            print("Empty ref ligand, copying from gt", ref_ligand_output_path)
            shutil.copyfile(gt_ligand_output_path, ref_ligand_output_path)

    affinity = row["ligand_binding_affinity"]
    if not pd.notna(affinity):
        affinity = None

    json_data = {
        "input_structure": protein_input_relative_path,
        "gt_structure": gt_output_relative_path,
        "gt_sdf_list": rel_gt_lig_paths,
        "input_smiles_list": input_smiles,
        "resolution": row.fillna(99)["entry_resolution"],
        "release_year": row["entry_release_date"],
        "affinity": affinity,
        "protein_seq_len": protein_size,
        "uniprot": row["system_pocket_UniProt"],
        "ligand_num_atoms": ligand_pos.shape[0],
        "cluster": row["cluster"],
        "cluster_size": row["cluster_size"],
        "input_protein_source": input_protein_source,
        "ref_sdf_list": rel_ref_lig_paths,
        "pdb_id": row["system_id"],
    }
    with open(output_json_path, "w") as json_file:
        json_file.write(json.dumps(json_data, indent=4))
    return "success"
# use linked structures | |
# input_structure_to_use = None | |
# apo_linked_structure = os.path.join(linked_structures_folder, "apo", system_id) | |
# pred_linked_structure = os.path.join(linked_structures_folder, "pred", system_id) | |
# if os.path.exists(apo_linked_structure): | |
# for folder in os.listdir(apo_linked_structure): | |
# if not os.path.isdir(os.path.join(pred_linked_structure, folder)): | |
# continue | |
# for filename in os.listdir(os.path.join(apo_linked_structure, folder)): | |
# if filename.endswith(".cif"): | |
# input_structure_to_use = os.path.join(apo_linked_structure, folder, filename) | |
# break | |
# if input_structure_to_use: | |
# break | |
# print(system_id, "found apo", input_structure_to_use) | |
# elif os.path.exists(pred_linked_structure): | |
# for folder in os.listdir(pred_linked_structure): | |
# if not os.path.isdir(os.path.join(pred_linked_structure, folder)): | |
# continue | |
# for filename in os.listdir(os.path.join(pred_linked_structure, folder)): | |
# if filename.endswith(".cif"): | |
# input_structure_to_use = os.path.join(pred_linked_structure, folder, filename) | |
# break | |
# if input_structure_to_use: | |
# break | |
# print(system_id, "found pred", input_structure_to_use) | |
# else: | |
# print(system_id, "no linked structure found") | |
# return "No linked structure found" | |
def main(prefix_bucket_id: str = "*"):
    """Process every PLINDER system, one download bucket at a time.

    For each bucket (``_bucket_id`` of the systems table) the system archive
    is downloaded with gsutil and unzipped into a temporary folder, every
    system of the bucket is passed to ``prepare_system``, one
    ``bucket,system,status`` line per system is appended to the generation
    info CSV, and the temporary folder is removed again.

    Args:
        prefix_bucket_id: only buckets whose id starts with this prefix are
            processed; ``"*"`` processes all buckets. A prefix also gets its
            own info file, ``plinder_generation_info_<prefix>.csv``.
    """
    # Imported locally so this function does not depend on the module import block.
    import subprocess

    def run_tool(args):
        """Run an external command without a shell; report failure instead of raising."""
        try:
            completed = subprocess.run(args, check=False)
        except OSError as e:
            print("Failed to run", args[0], e)
            return False
        if completed.returncode != 0:
            print("Command failed", args[0], completed.returncode)
            return False
        return True

    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    systems = get_cached_systems_to_train()
    print("total systems", len(systems))
    print("clusters", systems["cluster"].value_counts())
    print("splits", systems["split"].value_counts())
    val_or_test = systems[(systems["split"] == "val") | (systems["split"] == "test")]
    print("validation or test", len(val_or_test))

    output_models_folder = os.path.join(OUTPUT_FOLDER, "plinder_models")
    output_train_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_train")
    output_val_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_val")
    output_test_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_test")
    output_info = os.path.join(OUTPUT_FOLDER, "plinder_generation_info.csv")
    if prefix_bucket_id != "*":
        output_info = os.path.join(OUTPUT_FOLDER, f"plinder_generation_info_{prefix_bucket_id}.csv")
    os.makedirs(output_models_folder, exist_ok=True)
    os.makedirs(output_train_jsons_folder, exist_ok=True)
    os.makedirs(output_val_jsons_folder, exist_ok=True)
    os.makedirs(output_test_jsons_folder, exist_ok=True)
    split_to_folder = {
        "train": output_train_jsons_folder,
        "val": output_val_jsons_folder,
        "test": output_test_jsons_folder
    }

    with open(output_info, "a+") as output_info_file:
        for bucket_id, bucket_systems in systems.groupby('_bucket_id', sort=True):
            if prefix_bucket_id != "*" and not str(bucket_id).startswith(prefix_bucket_id):
                continue
            print("Starting bucket", bucket_id, len(bucket_systems))
            print(len(bucket_systems), bucket_systems["system_num_ligand_chains"].value_counts())

            tmp_output_models_folder = os.path.join(OUTPUT_FOLDER, f"tmp_{bucket_id}")
            os.makedirs(tmp_output_models_folder, exist_ok=True)
            try:
                # Argument lists instead of shell strings: no quoting or injection issues.
                # A failed download/unzip is reported; the systems of the bucket
                # are then recorded with the status prepare_system reports.
                run_tool([GSUTIL_PATH, "-m", "cp", "-r",
                          f"gs://plinder/2024-06/v2/systems/{bucket_id}.zip", tmp_output_models_folder])
                systems_folder = os.path.join(tmp_output_models_folder, "systems")
                run_tool(["unzip", "-o", os.path.join(tmp_output_models_folder, f"{bucket_id}.zip"),
                          "-d", systems_folder])

                for _, row in bucket_systems.iterrows():
                    print("doing", row['system_id'], row["system_num_protein_chains"],
                          row["system_num_ligand_chains"])
                    system_folder = os.path.join(systems_folder, row['system_id'])
                    try:
                        success = prepare_system(row, system_folder, output_models_folder,
                                                 split_to_folder[row["split"]])
                        print("done", row['system_id'], success)
                        output_info_file.write(f"{bucket_id},{row['system_id']},{success}\n")
                    except Exception as e:
                        # One broken system must not stop the whole run.
                        print("Failed", row['system_id'], e)
                        output_info_file.write(f"{bucket_id},{row['system_id']},Failed\n")
                    # Keep the progress log usable if the run is killed.
                    output_info_file.flush()
            finally:
                # The extracted bucket is large; always free it before the next one.
                shutil.rmtree(tmp_output_models_folder)
if __name__ == '__main__':
    # Optional first command-line argument: bucket id prefix to process ("*" = every bucket).
    prefix_bucket_id = sys.argv[1] if len(sys.argv) > 1 else "*"
    main(prefix_bucket_id)