# PIA / benchmark.py
# Author: LeoXing1996 — commit a001281 ("init repo for fg").
import argparse
import os
import shlex
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from time import sleep
"""
Examples:
- Test AD-benchmark:
python benchmark.py --script=inference_ad.py --yaml_dir=configs/ad/
- Test Indomain:
python benchmark.py --script=inference.py --yaml_dir=configs/indomain/
- Test:
python benchmark.py --script=inference.py --yaml_dir=configs/indomain/myprompt --spot=True
- Test AnimateBench:
python benchmark.py --script=inference_new.py --yaml_dir=AnimateBench/config/
"""
parser = argparse.ArgumentParser()
parser.add_argument('--yaml_dir', type=str, default='configs/indomain/myprompt_simple/')
parser.add_argument('--node', type=str, default=None)
parser.add_argument('--script', type=str, default='inference.py')
parser.add_argument('--dreambooth', type=list, default=['toon', 'maj', 'real', 'rc', 'ly'])
# parser.add_argument('--dreambooth', type=list, default=['toon'])
parser.add_argument('--spot', type=bool, default=False)
# Parse the real command line (args=None means sys.argv[1:]) once, at import time.
args = parser.parse_args(args=None)
def run_srun_command(command):
    """Run one shell command line and block until it exits.

    Args:
        command: The complete command line as a single string; it is executed
            through the shell.

    Returns:
        The ``subprocess.CompletedProcess`` of the finished command, so callers
        can inspect ``returncode``.
    """
    # shell=True because the command arrives as one pre-built string.
    result = subprocess.run(command, shell=True)
    if result.returncode != 0:
        # Surface the failure instead of silently dropping the exit status.
        print(f"[benchmark] command exited with code {result.returncode}: {command}",
              file=sys.stderr)
    return result
# Launch one srun job per DreamBooth model and let them run concurrently.
# The pool gets one worker per job so no submission waits behind the
# executor's default worker limit; max(1, ...) because 0 workers is invalid.
with ThreadPoolExecutor(max_workers=max(1, len(args.dreambooth))) as executor:
    futures = []
    for db in args.dreambooth:
        command_parts = ['srun', '-p', 'mm_lol', '--gres=gpu:1']
        if args.spot:
            command_parts.append('--quota=spot')
        if args.node is not None:
            command_parts += ['-w', shlex.quote(args.node)]
        config_path = os.path.join(args.yaml_dir, db + '.yaml')
        # shlex.quote leaves ordinary paths untouched but keeps paths with
        # spaces or shell metacharacters intact as a single argument.
        command_parts += ['python', shlex.quote(args.script),
                          shlex.quote(f'--config={config_path}')]
        command = ' '.join(command_parts)
        futures.append(executor.submit(run_srun_command, command))
        # Stagger submissions by one second (presumably to avoid bursting the
        # scheduler — NOTE(review): confirm the intent).
        sleep(1)
    # Re-raise any exception thrown inside a worker rather than losing it;
    # leaving the `with` block still waits for every remaining job.
    for future in futures:
        future.result()