d22cs051's picture
retrying pushing the code
8273cb9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
from omegaconf import OmegaConf
from mmpt.utils import recursive_config, overwrite_dir
from mmpt_cli.localjob import LocalJob
class JobLauncher(object):
    """Resolve which job class handles a yaml job file and submit it."""

    # job key -> job class; the key is the part of a task/job type that
    # precedes the first "_".
    JOB_CONFIG = {
        "local": LocalJob,
    }

    def __init__(self, yaml_file):
        """Remember `yaml_file` and derive the default job key from it.

        The key falls back to "local" when the loaded config has no
        `task_type`.

        Raises:
            ValueError: if `yaml_file` does not end with ".yaml".
        """
        self.yaml_file = yaml_file
        if not yaml_file.endswith(".yaml"):
            raise ValueError("unknown extension of job file:", yaml_file)

        task_type = recursive_config(yaml_file).task_type
        self.job_key = (
            "local" if task_type is None else task_type.split("_")[0]
        )

    def __call__(self, job_type=None, dryrun=False):
        """Instantiate the job class for the current job key and submit it.

        A non-None `job_type` overrides the key derived at construction time
        (the override is stored on the instance). Returns whatever the job's
        `submit()` returns.
        """
        if job_type is not None:
            self.job_key = job_type.split("_")[0]
        print("[JobLauncher] job_key", self.job_key)

        job_cls = JobLauncher.JOB_CONFIG[self.job_key]
        return job_cls(
            self.yaml_file, job_type=job_type, dryrun=dryrun
        ).submit()
class Pipeline(object):
    """A job description loaded from a yaml config (or a bare python file).

    Indexing a pipeline yields the `JobLauncher`s of one stage; `len()` is
    the number of stages.
    """

    def __init__(self, fn):
        """Load a job config and write the generated per-task configs to disk.

        Three cases, decided in order:
          * `fn` ends with ".py": treated as a python command; the file
            itself is the only stage and nothing is loaded.
          * the loaded config has no `base_dir`: a single-file job; the file
            itself is the only stage.
          * otherwise: a project config. `project_dir` / `run_dir` are set,
            the stages listed under `run_task` are collected into
            `run_yamls`, and every task config derived from `task_group` is
            saved under `project_dir`.
        """
        if fn.endswith(".py"):
            # a python command.
            self.backend = "python"
            self.run_yamls = [fn]
            return

        job_config = recursive_config(fn)
        if job_config.base_dir is None:  # single file job config.
            self.run_yamls = [fn]
            return

        self.project_dir = os.path.join("projects", job_config.project_dir)
        self.run_dir = os.path.join("runs", job_config.project_dir)

        # Always define run_yamls so len()/indexing do not fail with an
        # AttributeError when the config has no `run_task`.
        self.run_yamls = []
        if job_config.run_task is not None:
            for stage in job_config.run_task:
                # each stage can have multiple tasks running in parallel.
                if OmegaConf.is_list(stage):
                    self.run_yamls.append([
                        os.path.join(self.project_dir, task_file)
                        for task_file in stage
                    ])
                else:
                    self.run_yamls.append(
                        os.path.join(self.project_dir, stage))

        configs_to_save = self._overwrite_task(job_config)
        self._save_configs(configs_to_save)

    def __getitem__(self, idx):
        """Return the stage at `idx` as a list of `JobLauncher`s.

        A stage holding a single file still comes back as a one-element list.
        """
        yaml_files = self.run_yamls[idx]
        if isinstance(yaml_files, list):
            return [JobLauncher(yaml_file) for yaml_file in yaml_files]
        return [JobLauncher(yaml_files)]

    def __len__(self):
        """Return the number of stages."""
        return len(self.run_yamls)

    def _save_configs(self, configs_to_save: dict):
        """Write each config in `configs_to_save` (path -> config) as yaml."""
        os.makedirs(self.project_dir, exist_ok=True)
        for config_file, config in configs_to_save.items():
            print("saving", config_file)
            OmegaConf.save(config=config, f=config_file)

    def _overwrite_task(self, job_config):
        """Build the per-task configs described by `job_config.task_group`.

        For every group, each file in its `task_list` is loaded from the
        base project directory, merged with the group's remaining keys, and
        passed to `overwrite_dir` (presumably to retarget paths from the
        base run dir to this project's run dir -- confirm against
        mmpt.utils). Groups without a `task_list` are reported and skipped.

        Returns:
            dict mapping the destination path under `project_dir` to the
            resulting config.
        """
        configs_to_save = {}
        self.base_project_dir = os.path.join("projects", job_config.base_dir)
        self.base_run_dir = os.path.join("runs", job_config.base_dir)

        for config_sets in job_config.task_group:
            overwrite_config = job_config.task_group[config_sets]
            # we don't want this added to a final config.
            task_list = overwrite_config.pop("task_list", None)
            if not task_list:
                # Nothing to generate for this group; skipping avoids
                # iterating over None after the warning.
                print(
                    "[warning]",
                    config_sets,
                    "has no task_list specified.")
                continue

            for config_file in task_list:
                config_file_path = os.path.join(
                    self.base_project_dir, config_file)
                config = recursive_config(config_file_path)
                # overwrite it.
                if overwrite_config:
                    config = OmegaConf.merge(config, overwrite_config)
                overwrite_dir(config, self.run_dir, basedir=self.base_run_dir)
                save_file_path = os.path.join(self.project_dir, config_file)
                configs_to_save[save_file_path] = config
        return configs_to_save
def main(args):
    """Build a pipeline per comma-separated entry of `args.yamls` and launch.

    Only pipelines without a `project_dir` attribute are launched, and only
    their first stage. An empty `args.jobtype` means "no forced job type".
    """
    forced_job_type = args.jobtype or None

    # parse multiple pipelines; all of them are constructed before any job
    # is launched.
    pipelines = [Pipeline(path) for path in args.yamls.split(",")]

    for pipeline in pipelines:
        if hasattr(pipeline, "project_dir"):
            continue
        for launcher in pipeline[0]:
            launcher(job_type=forced_job_type, dryrun=args.dryrun)
if __name__ == "__main__":
    # Command-line entry point: one positional argument holding a
    # comma-separated list of job files, plus two optional switches.
    cli = argparse.ArgumentParser()
    cli.add_argument("yamls", type=str)
    cli.add_argument(
        "--dryrun",
        help="run config and prepare to submit without launch the job.",
        action="store_true",
    )
    cli.add_argument(
        "--jobtype",
        help="force to run jobs as specified.",
        default="",
        type=str,
    )
    main(cli.parse_args())