File size: 2,168 Bytes
753e275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import argparse
import ray
import time

from diffab.tools.relax.openmm_relaxer import run_openmm
from diffab.tools.relax.pyrosetta_relaxer import run_pyrosetta, run_pyrosetta_fixbb
from diffab.tools.relax.base import TaskScanner


@ray.remote(num_gpus=1/8, num_cpus=1)
def run_openmm_remote(task):
    """Execute :func:`run_openmm` on a Ray worker.

    Each invocation reserves one CPU and one eighth of a GPU, so Ray may
    schedule up to eight of these relaxations on a single GPU at once.
    The return value is whatever ``run_openmm`` produces for *task*.
    """
    result = run_openmm(task)
    return result


@ray.remote(num_cpus=1)
def run_pyrosetta_remote(task):
    """Execute :func:`run_pyrosetta` on a Ray worker reserving one CPU.

    The return value is whatever ``run_pyrosetta`` produces for *task*.
    """
    result = run_pyrosetta(task)
    return result


@ray.remote(num_cpus=1)
def run_pyrosetta_fixbb_remote(task):
    """Execute :func:`run_pyrosetta_fixbb` on a Ray worker reserving one CPU.

    The return value is whatever ``run_pyrosetta_fixbb`` produces for *task*.
    """
    result = run_pyrosetta_fixbb(task)
    return result


@ray.remote
def pipeline_openmm_pyrosetta(task):
    """Relax *task* with OpenMM, then pass the result through PyRosetta.

    The OpenMM stage's ObjectRef is handed straight to the PyRosetta stage.
    Ray resolves it to the OpenMM output before PyRosetta starts, so the two
    stages run strictly in sequence. This task blocks on ``ray.get`` until the
    PyRosetta stage finishes, then returns that stage's result.
    """
    openmm_ref = run_openmm_remote.remote(task)
    rosetta_ref = run_pyrosetta_remote.remote(openmm_ref)
    return ray.get(rosetta_ref)


@ray.remote
def pipeline_pyrosetta(task):
    """Relax *task* with PyRosetta only.

    The single stage is dispatched as its own Ray task. This task blocks on
    ``ray.get`` until that stage completes, then returns its result.
    """
    rosetta_ref = run_pyrosetta_remote.remote(task)
    return ray.get(rosetta_ref)


@ray.remote
def pipeline_pyrosetta_fixbb(task):
    """Run the fixed-backbone PyRosetta stage on *task*.

    The stage is dispatched as its own Ray task. This task blocks on
    ``ray.get`` until that stage completes, then returns its result.
    """
    fixbb_ref = run_pyrosetta_fixbb_remote.remote(task)
    return ray.get(fixbb_ref)


# Maps each value accepted by the ``--pipeline`` command-line option to the
# Ray pipeline task that implements it.
pipeline_dict = dict(
    openmm_pyrosetta=pipeline_openmm_pyrosetta,
    pyrosetta=pipeline_pyrosetta,
    pyrosetta_fixbb=pipeline_pyrosetta_fixbb,
)


def main():
    """Continuously scan a results tree and relax every new task with Ray.

    Command-line options:
      --root      directory handed to ``TaskScanner`` (default ``./results``).
      --pipeline  one of the keys of ``pipeline_dict``
                  (default ``openmm_pyrosetta``).

    Never returns. Each iteration scans ``--root``, submits one pipeline task
    per discovered task, and waits for the whole batch to finish, printing
    progress as results arrive. It then sleeps one second before scanning
    again.

    NOTE(review): ``ray.get`` re-raises a failed remote task's exception, so
    an unhandled error inside a pipeline stops this loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='./results')
    # Use ``choices`` instead of a dict-lookup ``type``. argparse does not
    # turn a KeyError from ``type`` into a usage error, so a typo used to
    # crash with a traceback. ``choices`` also lists the valid names in
    # ``--help``.
    parser.add_argument(
        '--pipeline',
        choices=list(pipeline_dict),
        default='openmm_pyrosetta',
    )
    args = parser.parse_args()
    pipeline = pipeline_dict[args.pipeline]

    # Start Ray only once the arguments are known to be valid, so `--help`
    # or a bad option does not needlessly start a Ray cluster.
    ray.init()

    final_pfx = 'fixbb' if pipeline is pipeline_pyrosetta_fixbb else 'rosetta'
    scanner = TaskScanner(args.root, final_postfix=final_pfx)
    while True:
        tasks = scanner.scan()
        futures = [pipeline.remote(t) for t in tasks]
        if futures:
            print(f'Submitted {len(futures)} tasks.')
        # Drain the batch one completion at a time so progress is reported
        # as soon as each pipeline finishes, regardless of submission order.
        while futures:
            done_ids, futures = ray.wait(futures, num_returns=1)
            for done_id in done_ids:
                done_task = ray.get(done_id)
                print(f'Remaining {len(futures)}. Finished {done_task.current_path}')
        time.sleep(1.0)


if __name__ == '__main__':
    main()