# P2DFlow/experiments/inference_se3_flows.py
"""DDP inference script."""
import os
import time
import numpy as np
import hydra
import torch
import GPUtil
import sys
from pytorch_lightning import Trainer
from omegaconf import DictConfig, OmegaConf
from experiments import utils as eu
from models.flow_module import FlowModule
import re
from typing import Optional
import subprocess
from biotite.sequence.io import fasta
from data import utils as du
from analysis import metrics
import pandas as pd
import esm
import shutil
import biotite.structure.io as bsio
torch.set_float32_matmul_precision('high')
log = eu.get_pylogger(__name__)


class Sampler:

    def __init__(self, cfg: DictConfig):
        """Initialize sampler.

        Args:
            cfg: inference config.
        """
        ckpt_path = cfg.inference.ckpt_path

        # Set up config. Struct mode is disabled so inference-time keys can be
        # added freely. (The config stored in the checkpoint could instead be
        # merged in via OmegaConf.merge; here the inference config is used as-is.)
        OmegaConf.set_struct(cfg, False)
        self._cfg = cfg
        self._infer_cfg = cfg.inference
        self._samples_cfg = self._infer_cfg.samples

        # Set up directories to write results to, named after the checkpoint
        # and the inference run.
        self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
        self._output_dir = os.path.join(
            self._infer_cfg.output_dir,
            self._ckpt_name,
            self._infer_cfg.name,
        )
        os.makedirs(self._output_dir, exist_ok=True)
        log.info(f'Saving results to {self._output_dir}')
        config_path = os.path.join(self._output_dir, 'config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config=self._cfg, f=f)
        log.info(f'Saving inference config to {config_path}')

        # Read checkpoint and initialize module.
        self._flow_module = FlowModule.load_from_checkpoint(
            checkpoint_path=ckpt_path,
        )
        self._flow_module.eval()
        self._flow_module._infer_cfg = self._infer_cfg
        self._flow_module._samples_cfg = self._samples_cfg
        self._flow_module._output_dir = self._output_dir

        # Load ESMFold on the current device for folding sampled sequences.
        # (Alternative: pick the least-loaded GPUs via GPUtil.getAvailable.)
        devices = [torch.cuda.current_device()]
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model = self._folding_model.to(devices[-1])
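
    # The unused imports and the commented-out `sampler.eval_test()` call in
    # the original file suggest a self-consistency evaluation step. The method
    # below is a minimal sketch of such a step, not part of the original
    # module: `fold_and_score` is a hypothetical name, and only the documented
    # ESMFold API is used (`infer_pdb` returns a PDB string whose B-factor
    # column stores per-residue pLDDT).
    def fold_and_score(self, sequence: str, pdb_path: str) -> float:
        """Fold a sequence with ESMFold and return its mean pLDDT."""
        with torch.no_grad():
            pdb_str = self._folding_model.infer_pdb(sequence)
        with open(pdb_path, 'w') as f:
            f.write(pdb_str)
        # pLDDT is written into the B-factor column of the PDB output.
        struct = bsio.load_structure(pdb_path, extra_fields=['b_factor'])
        return float(struct.b_factor.mean())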

    def run_sampling(self):
        # Predict on the current GPU. (Alternative: select several devices
        # with GPUtil.getAvailable(order='memory') for multi-GPU DDP.)
        devices = [torch.cuda.current_device()]
        log.info(f'Using devices: {devices}')
        eval_dataset = eu.LengthDataset(self._samples_cfg)
        dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=self._samples_cfg.sample_batch,
            shuffle=False,
            drop_last=False,
        )
        trainer = Trainer(
            accelerator='gpu',
            strategy='ddp',
            devices=devices,
        )
        trainer.predict(self._flow_module, dataloaders=dataloader)
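

# Hedged sketch (not part of the original file): `trainer.predict` drives the
# module's `predict_step` hook, so for single-process debugging the same loop
# can be run by hand, assuming FlowModule handles its own device placement and
# result writing:
#
#     for i, batch in enumerate(dataloader):
#         flow_module.predict_step(batch, i)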


@hydra.main(version_base=None, config_path='../configs', config_name='inference')
def run(cfg: DictConfig) -> None:
    log.info(f'Starting inference with {cfg.inference.num_gpus} GPUs')
    start_time = time.time()
    sampler = Sampler(cfg)
    sampler.run_sampling()
    # sampler.eval_test()
    elapsed_time = time.time() - start_time
    log.info(f'Finished in {elapsed_time:.2f}s')


if __name__ == '__main__':
    run()
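
# Usage sketch (the checkpoint path and run name below are assumptions, not
# from the original file): with the hydra config at ../configs/inference.yaml,
# a typical launch overrides keys using hydra's dotted syntax, e.g.
#
#   python experiments/inference_se3_flows.py \
#       inference.ckpt_path=weights/last.ckpt \
#       inference.name=my_run \
#       inference.samples.sample_batch=1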