# PIA / benchmark.py
# Author: LeoXing1996 · commit a001281 ("init repo for fg") · 1.58 kB
import argparse
import os
import shlex
import subprocess
from concurrent.futures import ThreadPoolExecutor
from time import sleep
"""
Examples:
- Test AD-benchmark:
python benchmark.py --script=inference_ad.py --yaml_dir=configs/ad/
- Test Indomain:
python benchmark.py --script=inference.py --yaml_dir=configs/indomain/
- Test the myprompt configs on the spot quota:
python benchmark.py --script=inference.py --yaml_dir=configs/indomain/myprompt --spot=True
- Test AnimateBench:
python benchmark.py --script=inference_new.py --yaml_dir=AnimateBench/config/
"""
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` must not be used for this: ``bool("False")`` is ``True``,
    so ``--spot=False`` would silently turn the spot quota on.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line options. Parsed at import time because the rest of this script
# reads the module-level ``args``.
parser = argparse.ArgumentParser(
    description='Launch one srun inference job per DreamBooth config.',
)
parser.add_argument('--yaml_dir', type=str, default='configs/indomain/myprompt_simple/',
                    help='Directory holding one <dreambooth>.yaml config per model.')
parser.add_argument('--node', type=str, default=None,
                    help='Optional node to pin every job to (passed to srun as -w).')
parser.add_argument('--script', type=str, default='inference.py',
                    help='Inference script each job runs.')
# nargs='+' yields a list of names (e.g. --dreambooth toon real); the previous
# type=list split a single value into characters.
parser.add_argument('--dreambooth', nargs='+', default=['toon', 'maj', 'real', 'rc', 'ly'],
                    help='DreamBooth config names to benchmark, space separated.')
# Accepts --spot, --spot=True and --spot=False.
parser.add_argument('--spot', type=_str2bool, nargs='?', const=True, default=False,
                    help='Submit jobs with the spot quota (--quota=spot).')
args = parser.parse_args()
def run_srun_command(command):
    """Run ``command`` through the shell and wait for it to finish.

    Args:
        command: Complete shell command line, e.g. an ``srun ... python ...`` string.

    Returns:
        The ``subprocess.CompletedProcess`` for the command, so callers can
        inspect ``returncode``. A non-zero exit is also printed, because this
        typically runs in a worker thread where a failure would otherwise
        go unnoticed.
    """
    # shell=True because the command is assembled as a single string.
    result = subprocess.run(command, shell=True)
    if result.returncode != 0:
        print(f'[benchmark] command exited with code {result.returncode}: {command}', flush=True)
    return result
# Launch one srun job per DreamBooth model. Each job runs ``args.script`` on
# <yaml_dir>/<dreambooth>.yaml. The worker threads only wait on srun, so give
# every job its own worker and all of them can be in flight at once.
srun_prefix = ['srun', '-p', 'mm_lol', '--gres=gpu:1']
if args.spot:
    srun_prefix.append('--quota=spot')
if args.node is not None:
    srun_prefix += ['-w', args.node]

futures = []
# The with-block waits for every submitted job before continuing.
with ThreadPoolExecutor(max_workers=max(1, len(args.dreambooth))) as executor:
    for db in args.dreambooth:
        config_path = os.path.join(args.yaml_dir, db + '.yaml')
        command_parts = srun_prefix + ['python', args.script, f'--config={config_path}']
        # Quote each piece because the command is run with shell=True. Plain
        # paths/flags are left unchanged by shlex.quote.
        command = ' '.join(shlex.quote(part) for part in command_parts)
        futures.append(executor.submit(run_srun_command, command))
        # Stagger submissions by one second.
        sleep(1)

# Re-raise any exception from a worker thread instead of silently dropping it.
for future in futures:
    future.result()