|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from train import main |
|
# submitit is optional: when it is not installed the script falls back to
# running training on the local machine. Every submitit import lives inside
# the guard so that a missing package reaches the fallback instead of
# raising ImportError at module import time.
LOCAL = False
try:
    import submitit
    from submitit.helpers import Checkpointable
except ImportError:
    print(
        "No submitit package found! Defaulting to executing the script in the local machine"
    )
    # Plain base class so ``Trainer`` can still be defined and run locally.
    Checkpointable = object
    LOCAL = True

# NOTE(review): ``parser`` is presumably this project's own parser.py (it must
# provide ``get_parser()``); it shares its name with the stdlib ``parser``
# module that existed before Python 3.10 -- confirm the local file is the one
# that gets imported in every target environment.
import parser
import json
|
|
|
|
|
class Trainer(Checkpointable):
    """Callable wrapper around ``train.main`` used for local and SLURM runs.

    It derives from submitit's ``Checkpointable``, presumably so that submitit
    can resubmit the same call when a job times out or is preempted (see the
    submitit checkpointing docs).
    """

    def __call__(self, args, slurm=False):
        """Run training with the given arguments.

        Args:
            args: Parsed command-line namespace. ``outdir`` is always read;
                ``dry_run`` is read on the local path and ``port`` on the
                cluster path.
            slurm: True when invoked from inside a submitted SLURM job. It is
                ignored (the local path is taken) when ``LOCAL`` is set.
        """
        outdir = args.outdir
        if not slurm or LOCAL:
            # Single-machine run: no master node to rendezvous with.
            main(args, outdir, master_node="", dry_run=args.dry_run)
            return
        # Inside a SLURM job: the first hostname reported by submitit's job
        # environment is used as the master node.
        head_node = submitit.JobEnvironment().hostnames[0]
        main(args, outdir, master_node=head_node, port=args.port)
|
|
|
|
|
def _apply_json_config(args, config_path):
    """Overlay the top-level keys of a JSON config file onto ``args``.

    Values from the file overwrite same-named attributes already present on
    ``args`` (e.g. command-line values) and add any that are missing.

    Args:
        args: Namespace-like object, updated in place.
        config_path: Path to a JSON file whose top level is an object.
    """
    with open(config_path, encoding="utf-8") as handle:
        config = json.load(handle)
    for key, value in config.items():
        setattr(args, key, value)


def _prepend_data_root(args):
    """Rebase dataset paths under ``args.data_root`` and redirect the output dir.

    ``data``, ``root_feats`` and ``root_nns`` are joined onto ``data_root``
    (``os.path.join`` keeps them unchanged if they are already absolute).
    ``outdir`` is set to ``base_root``.

    Args:
        args: Namespace-like object, updated in place.
    """
    print("Appending data_root to paths")
    for attr in ("data", "root_feats", "root_nns"):
        setattr(args, attr, os.path.join(args.data_root, getattr(args, attr)))
    # NOTE(review): confirm outdir should be redirected to base_root only when
    # data_root is given, rather than unconditionally.
    args.outdir = args.base_root


def _submit_to_slurm(trainer, args):
    """Submit ``trainer(args, slurm=True)`` as a SLURM job through submitit.

    Args:
        trainer: Callable to run inside the job.
        args: Parsed arguments; supplies the log dir, partition, node/GPU
            counts, time limit and job name.

    Returns:
        The submitit job handle.
    """
    # max_num_timeout: how many times submitit may requeue the job after a timeout.
    executor = submitit.SlurmExecutor(folder=args.slurm_logdir, max_num_timeout=60)
    print(args.gpus)
    executor.update_parameters(
        gpus_per_node=args.gpus,
        partition=args.partition,
        constraint="volta32gb",
        nodes=args.nodes,
        # One task (process) per GPU.
        ntasks_per_node=args.gpus,
        cpus_per_task=10,
        mem=256000,  # NOTE(review): presumably MB (SLURM --mem default unit) -- confirm
        time=args.slurm_time,
        job_name=args.exp_name,
        # Ask for the node exclusively when 8 GPUs are requested
        # (presumably a full node).
        exclusive=args.gpus == 8,
    )
    return executor.submit(trainer, args, slurm=True)


if __name__ == "__main__":
    import time

    args = parser.get_parser().parse_args()

    # An empty (or None) json_config means "use command-line values only".
    if args.json_config:
        _apply_json_config(args, args.json_config)
    else:
        print("Not using JSON configuration file!")

    if args.data_root is not None:
        _prepend_data_root(args)

    trainer = Trainer()
    if not args.slurm or LOCAL:
        trainer(args)
    else:
        job = _submit_to_slurm(trainer, args)
        print(job.job_id)

    # NOTE(review): the reason for this pause is not visible here; presumably
    # it lets submission/log output settle before the process exits.
    time.sleep(1)
|
|