File size: 1,577 Bytes
a001281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
import os
import shlex
import subprocess
from concurrent.futures import ThreadPoolExecutor
from time import sleep

"""
Examples:
    - Test AD-benchmark:
        python benchmark.py --script=inference_ad.py --yaml_dir=configs/ad/
    
    - Test Indomain:
        python benchmark.py --script=inference.py --yaml_dir=configs/indomain/

    - Test In-domain on the spot quota (adds `--quota=spot` to srun):
        python benchmark.py --script=inference.py --yaml_dir=configs/indomain/myprompt --spot=True
    
    - Test AnimateBench:
        python benchmark.py --script=inference_new.py --yaml_dir=AnimateBench/config/

"""

def _str2bool(value):
    """Parse a textual command-line boolean such as "True"/"False".

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for any
    non-empty string (including "False"), so the text is parsed explicitly.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument('--yaml_dir', type=str, default='configs/indomain/myprompt_simple/',
                    help='directory containing one <name>.yaml per DreamBooth model')
parser.add_argument('--node', type=str, default=None,
                    help='optional node name passed to srun as -w')
parser.add_argument('--script', type=str, default='inference.py',
                    help='inference script to run for each config')
# nargs='+' so "--dreambooth toon maj" yields ['toon', 'maj']; the previous
# type=list split a single value into characters (['t', 'o', 'o', 'n']).
parser.add_argument('--dreambooth', type=str, nargs='+',
                    default=['toon', 'maj', 'real', 'rc', 'ly'],
                    help='DreamBooth config names (without .yaml)')
# Accepts "--spot", "--spot=True" and "--spot=False" (type=bool treated
# "False" as True).
parser.add_argument('--spot', type=_str2bool, nargs='?', const=True, default=False,
                    help='submit with srun --quota=spot')
args = parser.parse_args()

def run_srun_command(command):
    """Run one shell command line (e.g. an ``srun ...`` launch) and wait for it.

    Args:
        command: Complete command line. It is executed through the shell, so
            the caller is responsible for quoting any interpolated values.

    Returns:
        The ``subprocess.CompletedProcess`` for the command, so callers can
        inspect ``returncode``. A non-zero exit status is also printed,
        because otherwise a failed job would go unnoticed by the launcher.
    """
    result = subprocess.run(command, shell=True, check=False)
    if result.returncode != 0:
        print(f"Command failed with return code {result.returncode}: {command}", flush=True)
    return result

# Launch one srun job per DreamBooth config, in parallel threads.
# One worker per config: each run_srun_command call blocks until its srun
# finishes, so fewer workers would serialize the extra launches.
# The context manager waits for every job before the script continues.
futures = []
with ThreadPoolExecutor(max_workers=max(1, len(args.dreambooth))) as executor:
    for db in args.dreambooth:
        command = "srun -p mm_lol --gres=gpu:1 "
        if args.spot:
            command += "--quota=spot "
        if args.node is not None:
            command += f"-w {shlex.quote(args.node)} "
        config_path = os.path.join(args.yaml_dir, db + '.yaml')
        # The command runs through the shell, so quote interpolated values.
        command += f"python {shlex.quote(args.script)} --config={shlex.quote(config_path)}"

        futures.append((db, executor.submit(run_srun_command, command)))
        # Stagger submissions by one second, as in the original launcher.
        sleep(1)

# Surface exceptions raised inside worker threads; unchecked futures
# would otherwise swallow them silently.
for db, future in futures:
    error = future.exception()
    if error is not None:
        print(f"[{db}] launch raised {error!r}", flush=True)