"""Inference utilities: checkpoint discovery/loading, FASTA parsing, timing
bookkeeping, and conversion of model outputs to RDKit molecules / PDB files."""
import json
import logging
import os
import re
import time
from typing import List, Tuple

import numpy
import torch
from rdkit import Chem

from dockformerpp.model.model import AlphaFold
from dockformerpp.utils import residue_constants, protein
from dockformerpp.utils.consts import POSSIBLE_ATOM_TYPES, POSSIBLE_BOND_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES

logging.basicConfig()
# Use the module name (PEP 282 convention) rather than the file path.
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)


def count_models_to_evaluate(model_checkpoint_path):
    """Return the number of comma-separated checkpoint paths (0 if None/empty)."""
    if not model_checkpoint_path:
        return 0
    return len(model_checkpoint_path.split(","))


def get_model_basename(model_path):
    """Return the checkpoint filename without directory or extension."""
    return os.path.splitext(os.path.basename(os.path.normpath(model_path)))[0]


def make_output_directory(output_dir, model_name, multiple_model_mode):
    """Create (if needed) and return the predictions directory.

    In multiple-model mode each model gets its own subdirectory so outputs
    from different checkpoints do not overwrite each other.
    """
    if multiple_model_mode:
        prediction_dir = os.path.join(output_dir, "predictions", model_name)
    else:
        prediction_dir = os.path.join(output_dir, "predictions")
    os.makedirs(prediction_dir, exist_ok=True)
    return prediction_dir


def get_latest_checkpoint(checkpoint_dir):
    """Return the path of the most recently created ``.ckpt`` file in
    ``checkpoint_dir``, or None if the directory or checkpoints are missing."""
    if not os.path.exists(checkpoint_dir):
        return None
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
    if not checkpoints:
        return None
    latest_checkpoint = max(
        checkpoints,
        key=lambda x: os.path.getctime(os.path.join(checkpoint_dir, x)),
    )
    return os.path.join(checkpoint_dir, latest_checkpoint)


def load_models_from_command_line(config, model_device, model_checkpoint_path, output_dir):
    """Yield ``(model, output_directory)`` pairs, one per checkpoint path.

    ``model_checkpoint_path`` is a comma-separated list of checkpoint files.

    Raises:
        ValueError: if no checkpoint path was supplied.
        FileNotFoundError: if a listed checkpoint file does not exist.
    """
    # Validate up front instead of after the loop, so the error surfaces
    # before any model construction or file loading is attempted.
    if not model_checkpoint_path:
        raise ValueError("model_checkpoint_path must be specified.")

    multiple_model_mode = count_models_to_evaluate(model_checkpoint_path) > 1
    if multiple_model_mode:
        logger.info("evaluating multiple models")

    for path in model_checkpoint_path.split(","):
        model = AlphaFold(config)
        model = model.eval()
        checkpoint_basename = get_model_basename(path)
        # Explicit check instead of `assert`, which is stripped under -O.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Model checkpoint not found at {path}")
        # map_location="cpu" lets GPU-saved checkpoints load on any host;
        # the model is moved to model_device explicitly below.
        d = torch.load(path, map_location="cpu")
        if "ema" in d:
            # The public weights have had this done to them already
            d = d["ema"]["params"]
        model.load_state_dict(d)
        model = model.to(model_device)
        logger.info(f"Loaded Model parameters at {path}...")
        output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
        yield model, output_directory


def parse_fasta(data):
    """Parse FASTA-formatted text into parallel lists of (tags, sequences).

    Each tag is truncated at the first non-word character of its header line.
    """
    data = re.sub(r'>$', '', data, flags=re.M)
    lines = [
        l.replace('\n', '')
        for prot in data.split('>')
        for l in prot.strip().split('\n', 1)
    ][1:]
    tags, seqs = lines[::2], lines[1::2]
    tags = [re.split(r'\W| \|', t)[0] for t in tags]
    return tags, seqs


def update_timings(timing_dict, output_file=os.path.join(os.getcwd(), "timings.json")):
    """Merge ``timing_dict`` into the JSON timings file, creating or
    overwriting a corrupt file as needed. Returns the output file path."""
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            try:
                timings = json.load(f)
            except json.JSONDecodeError:
                logger.info(f"Overwriting non-standard JSON in {output_file}.")
                timings = {}
    else:
        timings = {}
    timings.update(timing_dict)
    with open(output_file, "w") as f:
        json.dump(timings, f)
    return output_file


def run_model(model, batch, tag, output_dir):
    """Run a no-grad forward pass on ``batch``, record the wall-clock
    inference time under ``tag`` in ``output_dir/timings.json``, and return
    the model output."""
    with torch.no_grad():
        logger.info(f"Running inference for {tag}...")
        t = time.perf_counter()
        out = model(batch)
        inference_time = time.perf_counter() - t
        logger.info(f"Inference time: {inference_time}")
        update_timings({tag: {"inference": inference_time}},
                       os.path.join(output_dir, "timings.json"))
    return out


def get_molecule_from_output(atoms_atype: List[int],
                             atom_chiralities: List[int],
                             atom_charges: List[int],
                             bonds: List[Tuple[int, int, int]],
                             atom_positions):
    """Build an RDKit molecule from per-atom indices into the POSSIBLE_*
    lookup tables, a bond list, and 3D coordinates.

    Args:
        atoms_atype: index into POSSIBLE_ATOM_TYPES per atom.
        atom_chiralities: index into POSSIBLE_CHIRALITIES per atom.
        atom_charges: index into POSSIBLE_CHARGES per atom.
        bonds: (atom1, atom2, bond_type_idx) triples.
        atom_positions: (N, 3) array-like of float coordinates.

    Returns:
        A Chem.RWMol with one conformer holding the given positions.
    """
    mol = Chem.RWMol()
    assert len(atoms_atype) == len(atom_chiralities) == len(atom_charges) == len(atom_positions)

    for atype_idx, chirality_idx, charge_idx in zip(atoms_atype, atom_chiralities, atom_charges):
        new_atom = Chem.Atom(POSSIBLE_ATOM_TYPES[atype_idx])
        new_atom.SetChiralTag(POSSIBLE_CHIRALITIES[chirality_idx])
        new_atom.SetFormalCharge(POSSIBLE_CHARGES[charge_idx])
        mol.AddAtom(new_atom)

    # Add bonds
    for atom1, atom2, bond_type_idx in bonds:
        bond_type = POSSIBLE_BOND_TYPES[bond_type_idx]
        mol.AddBond(int(atom1), int(atom2), bond_type)

    # Set atom positions. asarray accepts both numpy arrays (as the original
    # `.astype(float)` required) and plain lists of tuples.
    conf = Chem.Conformer(len(atoms_atype))
    for i, pos in enumerate(numpy.asarray(atom_positions, dtype=float)):
        conf.SetAtomPosition(i, pos)
    mol.AddConformer(conf)

    return mol


def save_output_structure(aatype, residue_index, chain_index, plddt,
                          final_atom_protein_positions, final_atom_mask, output_path):
    """Write the predicted protein structure to ``output_path`` as a PDB file,
    storing per-residue pLDDT in the B-factor column of every atom."""
    # Broadcast per-residue pLDDT across all atom types for the B-factor field.
    plddt_b_factors = numpy.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    unrelaxed_protein = protein.from_prediction(
        aatype=aatype,
        residue_index=residue_index,
        chain_index=chain_index,
        atom_mask=final_atom_mask,
        atom_positions=final_atom_protein_positions,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=False,
    )

    with open(output_path, 'w') as fp:
        fp.write(protein.to_pdb(unrelaxed_protein))

    # Log via the module logger for consistency with the rest of the file.
    logger.info("Output written to %s", output_path)