"""DDP inference script."""
import os
import time
import numpy as np
import hydra
import torch
import GPUtil
import sys

from pytorch_lightning import Trainer
from omegaconf import DictConfig, OmegaConf
from experiments import utils as eu
from models.flow_module import FlowModule
import re
from typing import Optional
import subprocess
from biotite.sequence.io import fasta
from data import utils as du
from analysis import metrics
import pandas as pd
import esm
import shutil
import biotite.structure.io as bsio


# Trade float32 matmul precision for speed (allows TF32 on Ampere+ GPUs).
torch.set_float32_matmul_precision('high')
log = eu.get_pylogger(__name__)


class Sampler:

    def __init__(self, cfg: DictConfig):
        """Initialize sampler.

        Args:
            cfg: inference config.
        """
        ckpt_path = cfg.inference.ckpt_path

        # Allow new keys to be added to the config at runtime.
        OmegaConf.set_struct(cfg, False)

        self._cfg = cfg
        self._infer_cfg = cfg.inference
        self._samples_cfg = self._infer_cfg.samples

        # Set up directories to write results to. The run is named after the
        # last three path components of the checkpoint.
        self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
        self._output_dir = os.path.join(
            self._infer_cfg.output_dir,
            self._ckpt_name,
            self._infer_cfg.name,
        )
        os.makedirs(self._output_dir, exist_ok=True)
        log.info(f'Saving results to {self._output_dir}')
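        # Persist the exact config used for this run for reproducibility.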
        config_path = os.path.join(self._output_dir, 'config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config=self._cfg, f=f)
        log.info(f'Saving inference config to {config_path}')

        # Read checkpoint and initialize module.
        self._flow_module = FlowModule.load_from_checkpoint(
            checkpoint_path=ckpt_path,
        )
        self._flow_module.eval()
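        # Attach inference settings to the module; these are presumably read
        # by FlowModule during prediction (attribute names follow the
        # module's internals).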
        self._flow_module._infer_cfg = self._infer_cfg
        self._flow_module._samples_cfg = self._samples_cfg
        self._flow_module._output_dir = self._output_dir

        # Load ESMFold on the current CUDA device.
        device = torch.cuda.current_device()
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model = self._folding_model.to(device)

    def run_sampling(self):
        """Sample with a DDP Trainer on the current CUDA device."""
        devices = [torch.cuda.current_device()]
        log.info(f"Using devices: {devices}")

        eval_dataset = eu.LengthDataset(self._samples_cfg)
        dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=self._samples_cfg.sample_batch,
            shuffle=False,
            drop_last=False,
        )

        # With strategy="ddp", Lightning shards the predict dataloader across
        # ranks, so each process handles a subset of the samples.
        trainer = Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=devices,
        )
        trainer.predict(self._flow_module, dataloaders=dataloader)


@hydra.main(version_base=None, config_path="../configs", config_name="inference")
def run(cfg: DictConfig) -> None:
    log.info(f'Starting inference with {cfg.inference.num_gpus} GPUs')
    start_time = time.time()
    sampler = Sampler(cfg)
    sampler.run_sampling()
    elapsed_time = time.time() - start_time
    log.info(f'Finished in {elapsed_time:.2f}s')


if __name__ == '__main__':
    run()
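
# Example invocation (the script name, checkpoint path, and override values
# below are hypothetical; the available keys are defined by
# ../configs/inference.yaml):
#
#   python inference.py \
#       inference.ckpt_path=ckpts/example_run/last.ckpt \
#       inference.name=test_sweep \
#       inference.samples.sample_batch=4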