|
import click |
|
import os |
|
import sys |
|
import importlib |
|
import importlib.util |
|
import json |
|
from click.core import Context, Option |
|
|
|
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__ |
|
from ding.framework import Parallel |
|
from ding.entry.cli_parsers import PLATFORM_PARSERS |
|
|
|
|
|
def print_version(ctx: Context, param: Option, value: bool) -> None:
    """Click callback for ``-v/--version``: echo package metadata and stop.

    Nothing happens when the flag was not given or when click is running in
    resilient-parsing mode; otherwise the title/version line and the author
    line are printed and the click context is exited.
    """
    if value and not ctx.resilient_parsing:
        title_line = '{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__)
        author_line = 'Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__)
        for line in (title_line, author_line):
            click.echo(line)
        ctx.exit()
|
|
|
|
|
# Shared click context settings: make ``-h`` an alias of ``--help``.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
|
|
|
|
|
# Console entry point of the ``ditask`` command. Each option below reaches
# ``_cli_ditask`` as a keyword argument named after the option (dashes become
# underscores, e.g. ``--parallel-workers`` -> ``parallel_workers``).
# NOTE: documented with comments rather than a docstring on purpose -- click
# would show a docstring as this command's ``--help`` text.
@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    # ``print_version`` echoes the package metadata and exits the context when the flag is set.
    callback=print_version,
    # Not forwarded to the command function; it only triggers the callback.
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option('-p', '--package', type=str, help="Your code package path, could be a directory or a zip file.")
@click.option('--parallel-workers', type=int, default=1, help="Parallel worker number, default: 1")
@click.option(
    '--protocol',
    type=click.Choice(["tcp", "ipc"]),
    default="tcp",
    help="Network protocol in parallel mode, default: tcp"
)
# Comma-separated integers; ``_cli_ditask`` substitutes 50515 when this is omitted.
@click.option(
    "--ports",
    type=str,
    help="The port addresses that the tasks listen to, e.g. 50515,50516, default: k8s, local: 50515, slurm: 15151"
)
# Comma-separated addresses; split and stripped by ``_cli_ditask``.
@click.option("--attach-to", type=str, help="The addresses to connect to.")
@click.option("--address", type=str, help="The address to listen to (without port).")
# Comma-separated; ``_cli_ditask`` turns it into a set of stripped strings.
@click.option("--labels", type=str, help="Labels.")
# Comma-separated integers.
@click.option("--node-ids", type=str, help="Candidate node ids.")
@click.option(
    "--topology",
    type=click.Choice(["alone", "mesh", "star"]),
    default="alone",
    help="Network topology, default: alone."
)
# JSON configuration handed to ``_parse_platform_args`` together with ``--platform``.
@click.option("--platform-spec", type=str, help="Platform specific configure.")
# Must be a key of ``PLATFORM_PARSERS``; when set, arguments are rewritten by that parser first.
@click.option("--platform", type=str, help="Platform type: slurm, k8s.")
@click.option("--mq-type", type=str, default="nng", help="Class type of message queue, i.e. nng, redis.")
@click.option("--redis-host", type=str, help="Redis host.")
@click.option("--redis-port", type=int, help="Redis port.")
# ``module.function``; when omitted the module name is derived from the package
# basename and the function ``main`` is used.
@click.option("-m", "--main", type=str, help="Main function of entry module.")
@click.option("--startup-interval", type=int, default=1, help="Start up interval between each task.")
# Underscore spelling kept deliberately; per its help text this exists for PyTorch DDP
# compatibility. ``_cli_ditask`` accepts it but does not pass it to ``Parallel.runner``.
@click.option("--local_rank", type=int, default=0, help="Compatibility with PyTorch DDP")
def cli_ditask(*args, **kwargs):
    # Thin wrapper around the undecorated implementation; ``_cli_ditask`` also
    # re-invokes itself after platform parsing, which presumably is why the
    # logic lives outside the click-decorated function.
    return _cli_ditask(*args, **kwargs)
|
|
|
|
|
def _parse_platform_args(platform: str, platform_spec: str, all_args: dict):
    """Resolve platform-specific launch arguments for ``_cli_ditask``.

    Arguments:
        - platform: Name of the target platform; must be a key of ``PLATFORM_PARSERS``.
        - platform_spec: Optional platform configuration, either a path to a ``.json``
          file or an inline JSON string. Falsy values are passed to the parser unchanged.
        - all_args: The keyword arguments ``_cli_ditask`` was called with. The
          ``platform`` and ``platform_spec`` entries are dropped before the rest is
          forwarded to the parser; the dict itself is not modified.

    Returns:
        The value returned by the platform parser; the caller unpacks it as keyword
        arguments for ``_cli_ditask``.

    Raises:
        SystemExit: ``platform_spec`` cannot be read/decoded, or ``platform`` is unknown.
        Exception: Anything the platform parser raises is reported and re-raised.
    """
    if platform_spec:
        try:
            # ``splitext`` returns ``(root, ext)`` and ``ext`` keeps its leading dot;
            # comparing the whole tuple to "json" (as before) could never match.
            if os.path.splitext(platform_spec)[1].lower() == ".json":
                with open(platform_spec, encoding="utf-8") as f:
                    platform_spec = json.load(f)
            else:
                platform_spec = json.loads(platform_spec)
        except (OSError, ValueError) as e:  # json.JSONDecodeError subclasses ValueError
            click.echo("platform_spec is not a valid json! ({})".format(e))
            sys.exit(1)
    if platform not in PLATFORM_PARSERS:
        click.echo("platform type is invalid! type: {}".format(platform))
        sys.exit(1)

    # Build a filtered copy instead of popping: no KeyError if a key is absent and
    # the caller's dict is left untouched.
    parser_kwargs = {k: v for k, v in all_args.items() if k not in ("platform", "platform_spec")}
    try:
        parsed_args = PLATFORM_PARSERS[platform](platform_spec, **parser_kwargs)
    except Exception as e:
        click.echo("error when parse platform spec configure: {}".format(e))
        raise
    return parsed_args
|
|
|
|
|
def _split_csv(value) -> list:
    """Split a comma-separated string into whitespace-stripped items.

    Non-string iterables (e.g. a list already produced by a platform parser) are
    accepted as-is; their string items are stripped as well.
    """
    items = value.split(",") if isinstance(value, str) else list(value)
    return [item.strip() if isinstance(item, str) else item for item in items]


def _load_main_func(package: str, main: str):
    """Make ``package`` importable, import the entry module and return its entry function.

    Arguments:
        - package: Directory or zip file holding the user code; the current working
          directory is used when empty.
        - main: ``"module.function"`` path of the entry function. When ``None`` the module
          name is the package basename without extension and the function is ``main``.

    Returns:
        The entry function object.

    Raises:
        ValueError: ``main`` is not of the form ``module.function``.
    """
    if not package:
        package = os.getcwd()
    sys.path.append(package)
    if main is None:
        # abspath normalizes "." and trailing separators, which would otherwise
        # produce an empty or "." module name.
        mod_name = os.path.basename(os.path.abspath(package))
        mod_name, _ = os.path.splitext(mod_name)
        func_name = "main"
    else:
        mod_name, sep, func_name = main.rpartition(".")
        if not sep or not mod_name or not func_name:
            raise ValueError("--main must look like 'module.function', got {!r}".format(main))
    # Also put the directory of the top-level module inside the package on sys.path.
    root_mod_name = mod_name.split(".", 1)[0]
    sys.path.append(os.path.join(package, root_mod_name))
    mod = importlib.import_module(mod_name)
    return getattr(mod, func_name)


def _cli_ditask(
    package: str,
    main: str,
    parallel_workers: int,
    protocol: str,
    ports: str,
    attach_to: str,
    address: str,
    labels: str,
    node_ids: str,
    topology: str,
    mq_type: str,
    redis_host: str,
    redis_port: int,
    startup_interval: int,
    local_rank: int = 0,
    platform: str = None,
    platform_spec: str = None,
):
    """Import the user's entry function and run it through ``Parallel.runner``.

    When ``platform`` is set, the arguments are first rewritten by
    ``_parse_platform_args`` and this function is re-invoked with the result.
    Comma-separated ``ports``/``node_ids`` become ints (a single port is passed as a
    scalar), ``attach_to`` becomes a list and ``labels`` a set of stripped strings.
    ``local_rank`` is accepted for CLI compatibility but not used here.

    Raises:
        ValueError: malformed ``main`` or non-integer ports / node ids.
    """
    # Must stay the first statement: it snapshots exactly the call arguments so
    # they can be forwarded to the platform parser.
    all_args = locals()
    if platform:
        parsed_args = _parse_platform_args(platform, platform_spec, all_args)
        # NOTE(review): assumes the parser result does not set ``platform`` again,
        # otherwise this would recurse indefinitely -- confirm against PLATFORM_PARSERS.
        return _cli_ditask(**parsed_args)

    main_func = _load_main_func(package, main)

    ports = ports or 50515
    if not isinstance(ports, int):
        ports = [int(p) for p in _split_csv(ports)]
        # One port is passed as a scalar, several as a list.
        ports = ports[0] if len(ports) == 1 else ports
    if attach_to:
        attach_to = _split_csv(attach_to)
    if labels:
        labels = set(_split_csv(labels))
    if node_ids and not isinstance(node_ids, int):
        node_ids = [int(i) for i in _split_csv(node_ids)]

    Parallel.runner(
        n_parallel_workers=parallel_workers,
        ports=ports,
        protocol=protocol,
        topology=topology,
        attach_to=attach_to,
        address=address,
        labels=labels,
        node_ids=node_ids,
        mq_type=mq_type,
        redis_host=redis_host,
        redis_port=redis_port,
        startup_interval=startup_interval
    )(main_func)
|
|