Spaces:
Runtime error
Runtime error
from loguru import logger | |
import torch | |
import torch.backends.cudnn as cudnn | |
from yolox.core import Trainer, launch | |
from yolox.exp import get_exp | |
import argparse | |
import random | |
import warnings | |
def make_parser():
    """Build the command-line parser for the YOLOX training script.

    The parser covers experiment selection (``-expn``, ``-n``, ``-f``),
    distributed settings (``--dist-backend``, ``--dist-url``,
    ``--num_machines``, ``--machine_rank``, ``--local_rank``), training
    controls (``-b``, ``-d``, ``--resume``, ``-c``, ``-e``, ``--fp16`` /
    ``--no-fp16``, ``-o``) and a trailing ``opts`` list that collects every
    remaining command-line token.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser("YOLOX train parser")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    # distributed
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--dist-url",
        default=None,
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
    parser.add_argument(
        "-d", "--devices", default=None, type=int, help="device for training"
    )
    parser.add_argument(
        "--local_rank", default=0, type=int, help="local rank for dist training"
    )
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="plz input your expriment description file",
    )
    parser.add_argument(
        "--resume", default=False, action="store_true", help="resume training"
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="checkpoint file")
    parser.add_argument(
        "-e",
        "--start_epoch",
        default=None,
        type=int,
        help="resume training start epoch",
    )
    parser.add_argument(
        "--num_machines", default=1, type=int, help="num of node for training"
    )
    parser.add_argument(
        "--machine_rank", default=0, type=int, help="node rank for multi-node training"
    )
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=True,
        action="store_true",
        help="Adopting mix precision training.",
    )
    # ``fp16`` defaults to True, so the store_true flag above can never switch
    # it off. This companion flag writes False to the same destination, which
    # keeps the default and the existing ``--fp16`` spelling intact.
    parser.add_argument(
        "--no-fp16",
        dest="fp16",
        default=True,
        action="store_false",
        help="Disable mix precision training.",
    )
    parser.add_argument(
        "-o",
        "--occupy",
        dest="occupy",
        default=False,
        action="store_true",
        help="occupy GPU memory first for training.",
    )
    # Everything left on the command line after the known options lands here.
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
def main(exp, args):
    """Optionally seed the random number generators, then run training.

    Args:
        exp: experiment object; its ``seed`` attribute is read here and the
            object is handed to ``Trainer``.
        args: parsed command-line namespace, passed through to ``Trainer``.
    """
    seed = exp.seed
    if seed is not None:
        # Seed both Python's and torch's generators and force cuDNN onto
        # deterministic kernels.
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! You may see unexpected behavior "
            "when restarting from checkpoints."
        )

    # The cuDNN benchmark mode is switched on regardless of seeding.
    # NOTE(review): benchmark mode together with a fixed seed may weaken
    # run-to-run reproducibility -- confirm this is intended.
    cudnn.benchmark = True

    Trainer(exp, args).train()
if __name__ == "__main__":
    # Parse the CLI, build the experiment from the description file / model
    # name, and apply the trailing command-line overrides to it.
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    # Fall back to the experiment's own name when none was given on the CLI.
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    # Use every visible CUDA device unless -d/--devices asks for a count.
    available_gpus = torch.cuda.device_count()
    num_gpu = available_gpus if args.devices is None else args.devices

    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O`` and would give no hint about what went wrong.
    if num_gpu > available_gpus:
        raise ValueError(
            f"Requested {num_gpu} device(s) via -d/--devices, "
            f"but only {available_gpus} CUDA device(s) are available."
        )

    launch(
        main,
        num_gpu,
        args.num_machines,
        args.machine_rank,
        backend=args.dist_backend,
        dist_url=args.dist_url,
        args=(exp, args),
    )