"""DDP inference script."""
import os
import time

import esm
import GPUtil
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from experiments import utils as eu
from models.flow_module import FlowModule

# Trade a little float32 matmul precision for speed on Ampere+ GPUs.
torch.set_float32_matmul_precision('high')
log = eu.get_pylogger(__name__)


class Sampler:

    def __init__(self, cfg: DictConfig):
        """Initialize sampler.

        Args:
            cfg: inference config.
        """
        ckpt_path = cfg.inference.ckpt_path

        # Allow new keys to be added to the config at runtime.
        OmegaConf.set_struct(cfg, False)

        self._cfg = cfg
        self._infer_cfg = cfg.inference
        self._samples_cfg = self._infer_cfg.samples

        # Name outputs after the last three path components of the checkpoint,
        # e.g. 'ckpt/run_1/last.ckpt' -> 'ckpt/run_1/last'.
        self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
        self._output_dir = os.path.join(
            self._infer_cfg.output_dir,
            self._ckpt_name,
            self._infer_cfg.name,
        )
        os.makedirs(self._output_dir, exist_ok=True)
        log.info(f'Saving results to {self._output_dir}')

        # Save the resolved config alongside the samples for reproducibility.
        config_path = os.path.join(self._output_dir, 'config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config=self._cfg, f=f)
        log.info(f'Saving inference config to {config_path}')
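
        # The saved config can be reloaded later to reproduce a run, e.g.:
        #
        #   cfg = OmegaConf.load(os.path.join(self._output_dir, 'config.yaml'))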

        # Load the flow model weights and attach the inference settings the
        # module reads during predict_step.
        self._flow_module = FlowModule.load_from_checkpoint(
            checkpoint_path=ckpt_path,
        )
        self._flow_module.eval()
        self._flow_module._infer_cfg = self._infer_cfg
        self._flow_module._samples_cfg = self._samples_cfg
        self._flow_module._output_dir = self._output_dir

        # Load ESMFold for folding sampled structures and move it onto the
        # current CUDA device.
        device = torch.cuda.current_device()
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model = self._folding_model.to(device)
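
        # A usage sketch (not exercised in this script), following the
        # fair-esm README; the sequence below is a hypothetical placeholder:
        #
        #   with torch.no_grad():
        #       pdb_str = self._folding_model.infer_pdb("MKTAYIAKQRQISFVKSHFSRQ")
        #   with open('sample.pdb', 'w') as f:
        #       f.write(pdb_str)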

    def run_sampling(self):
        """Sample structures with DDP across the configured GPUs."""
        # Pick the GPUs with the most free memory, capped at num_gpus.
        devices = GPUtil.getAvailable(
            order='memory', limit=8)[:self._infer_cfg.num_gpus]
        log.info(f'Using devices: {devices}')

        eval_dataset = eu.LengthDataset(self._samples_cfg)
        dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=self._samples_cfg.sample_batch,
            shuffle=False,
            drop_last=False,
        )
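
        # Hedged sketch: eu.LengthDataset is assumed to enumerate
        # (sample_length, sample_id) pairs from the samples config, so each
        # predict batch tells the model how many residues to generate. A
        # minimal stand-in under that assumption (field names hypothetical):
        #
        #   class LengthDataset(torch.utils.data.Dataset):
        #       def __init__(self, samples_cfg):
        #           lengths = range(samples_cfg.min_length, samples_cfg.max_length + 1)
        #           self._pairs = [(n, i) for n in lengths
        #                          for i in range(samples_cfg.samples_per_length)]
        #       def __len__(self):
        #           return len(self._pairs)
        #       def __getitem__(self, idx):
        #           num_res, sample_id = self._pairs[idx]
        #           return torch.tensor([num_res, sample_id])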

        # trainer.predict routes each batch to FlowModule.predict_step, which
        # is expected to write its samples under self._output_dir.
        trainer = Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=devices,
        )
        trainer.predict(self._flow_module, dataloaders=dataloader)


@hydra.main(version_base=None, config_path="../configs", config_name="inference")
def run(cfg: DictConfig) -> None:
    """Entry point: build the sampler and run inference."""
    log.info(f'Starting inference with {cfg.inference.num_gpus} GPUs')
    start_time = time.time()
    sampler = Sampler(cfg)
    sampler.run_sampling()
    elapsed_time = time.time() - start_time
    log.info(f'Finished in {elapsed_time:.2f}s')


if __name__ == '__main__':
    run()
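
# Example invocation (hypothetical module path and override values; the real
# keys live in ../configs/inference.yaml):
#
#   python -m experiments.inference \
#       inference.ckpt_path=ckpt/run_1/last.ckpt \
#       inference.num_gpus=2 \
#       inference.name=my_run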