HarryLee committed on
Commit 3b57708
1 parent: ef8be5c

Upload train.py

Files changed (1)
  1. train.py +523 -0
train.py ADDED
@@ -0,0 +1,523 @@
+ #!/usr/bin/env python3 -u
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ """
+ Train a new model on one or across multiple GPUs.
+ """
+
+ import argparse
+ import logging
+ import math
+ import os
+ import sys
+ from typing import Dict, Optional, Any, List, Tuple, Callable
+
+ # We need to setup root logger before importing any fairseq libraries.
+ logging.basicConfig(
+     format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
+     datefmt="%Y-%m-%d %H:%M:%S",
+     level=os.environ.get("LOGLEVEL", "INFO").upper(),
+     stream=sys.stdout,
+ )
+ logger = logging.getLogger("fairseq_cli.train")
+
+ import numpy as np
+ import torch
+ from fairseq import (
+     # checkpoint_utils,
+     options,
+     quantization_utils,
+     tasks,
+     utils,
+ )
+ from fairseq.data import iterators
+ from fairseq.data.plasma_utils import PlasmaStore
+ from fairseq.dataclass.configs import FairseqConfig
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+ from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
+ from fairseq.file_io import PathManager
+ from fairseq.logging import meters, metrics, progress_bar
+ from fairseq.model_parallel.megatron_trainer import MegatronTrainer
+ # from fairseq.trainer import Trainer
+ from omegaconf import DictConfig, OmegaConf
+
+ from utils import checkpoint_utils
+ from trainer import Trainer
+
+
+ def main(cfg: FairseqConfig) -> None:
+     if isinstance(cfg, argparse.Namespace):
+         cfg = convert_namespace_to_omegaconf(cfg)
+
+     utils.import_user_module(cfg.common)
+
+     if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
+         # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
+         logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
+
+     assert (
+         cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
+     ), "Must specify batch size either with --max-tokens or --batch-size"
+     metrics.reset()
+
+     if cfg.common.log_file is not None:
+         handler = logging.FileHandler(filename=cfg.common.log_file)
+         logger.addHandler(handler)
+
+     np.random.seed(cfg.common.seed)
+     utils.set_torch_seed(cfg.common.seed)
+
+     if distributed_utils.is_master(cfg.distributed_training):
+         checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
+
+     # Print args
+     logger.info(cfg)
+
+     if cfg.checkpoint.write_checkpoints_asynchronously:
+         try:
+             import iopath  # noqa: F401
+         except ImportError:
+             logging.exception(
+                 "Asynchronous checkpoint writing is specified but iopath is "
+                 "not installed: `pip install iopath`"
+             )
+             return
+
+     # Setup task, e.g., translation, language modeling, etc.
+     task = tasks.setup_task(cfg.task)
+
+     assert cfg.criterion, "Please specify criterion to train a model"
+
+     # Build model and criterion
+     if cfg.distributed_training.ddp_backend == "fully_sharded":
+         with fsdp_enable_wrap(cfg.distributed_training):
+             model = fsdp_wrap(task.build_model(cfg.model))
+     else:
+         model = task.build_model(cfg.model)
+     criterion = task.build_criterion(cfg.criterion)
+     logger.info(model)
+     logger.info("task: {}".format(task.__class__.__name__))
+     logger.info("model: {}".format(model.__class__.__name__))
+     logger.info("criterion: {}".format(criterion.__class__.__name__))
+     logger.info(
+         "num. shared model params: {:,} (num. trained: {:,})".format(
+             sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
+             sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad)
+         )
+     )
+
+     logger.info(
+         "num. expert model params: {} (num. trained: {})".format(
+             sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
+             sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad),
+         )
+     )
+
+     # Load valid dataset (we load training data below, based on the latest checkpoint)
+     # We load the valid dataset AFTER building the model
+     # data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
+     if cfg.dataset.combine_valid_subsets:
+         task.load_dataset("valid", combine=True, epoch=1)
+     else:
+         for valid_sub_split in cfg.dataset.valid_subset.split(","):
+             task.load_dataset(valid_sub_split, combine=False, epoch=1)
+
+     # (optionally) Configure quantization
+     if cfg.common.quantization_config_path is not None:
+         quantizer = quantization_utils.Quantizer(
+             config_path=cfg.common.quantization_config_path,
+             max_epoch=cfg.optimization.max_epoch,
+             max_update=cfg.optimization.max_update,
+         )
+     else:
+         quantizer = None
+
+     # Build trainer
+     if cfg.common.model_parallel_size == 1:
+         trainer = Trainer(cfg, task, model, criterion, quantizer)
+     else:
+         trainer = MegatronTrainer(cfg, task, model, criterion)
+     logger.info(
+         "training on {} devices (GPUs/TPUs)".format(
+             cfg.distributed_training.distributed_world_size
+         )
+     )
+     logger.info(
+         "max tokens per device = {} and max sentences per device = {}".format(
+             cfg.dataset.max_tokens,
+             cfg.dataset.batch_size,
+         )
+     )
+
+     # Load the latest checkpoint if one is available and restore the
+     # corresponding train iterator
+     extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
+         cfg.checkpoint,
+         trainer,
+         # don't cache epoch iterators for sharded datasets
+         disable_iterator_cache=task.has_sharded_data("train"),
+     )
+     if cfg.common.tpu:
+         import torch_xla.core.xla_model as xm
+         xm.rendezvous("load_checkpoint")  # wait for all workers
+
+     max_epoch = cfg.optimization.max_epoch or math.inf
+     if max_epoch > 0:
+         num_iter_per_epoch = (len(epoch_itr) + cfg.distributed_training.distributed_world_size - 1) \
+             // cfg.distributed_training.distributed_world_size
+         trainer.lr_reinit(num_iter_per_epoch * max_epoch, trainer.get_num_updates())
+     lr = trainer.get_lr()
+
+     train_meter = meters.StopwatchMeter()
+     train_meter.start()
+     while epoch_itr.next_epoch_idx <= max_epoch:
+         if lr <= cfg.optimization.stop_min_lr:
+             logger.info(
+                 f"stopping training because current learning rate ({lr}) is smaller "
+                 "than or equal to minimum learning rate "
+                 f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
+             )
+             break
+
+         # train for one epoch
+         valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
+         if should_stop:
+             break
+
+         # only use first validation loss to update the learning rate
+         lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
+
+         epoch_itr = trainer.get_train_iterator(
+             epoch_itr.next_epoch_idx,
+             # sharded data: get train iterator for next epoch
+             load_dataset=True,
+             # don't cache epoch iterators for sharded datasets
+             disable_iterator_cache=task.has_sharded_data("train"),
+         )
+     train_meter.stop()
+     logger.info("done training in {:.1f} seconds".format(train_meter.sum))
+
+     # ioPath implementation to wait for all asynchronous file writes to complete.
+     if cfg.checkpoint.write_checkpoints_asynchronously:
+         logger.info(
+             "ioPath PathManager waiting for all asynchronous checkpoint "
+             "writes to finish."
+         )
+         PathManager.async_close()
+         logger.info("ioPath PathManager finished waiting.")
+
+
+ def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
+     # skip check if no validation was done in the current epoch
+     if valid_loss is None:
+         return False
+     if cfg.checkpoint.patience <= 0:
+         return False
+
+     def is_better(a, b):
+         return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
+
+     prev_best = getattr(should_stop_early, "best", None)
+     if prev_best is None or is_better(valid_loss, prev_best):
+         should_stop_early.best = valid_loss
+         should_stop_early.num_runs = 0
+         return False
+     else:
+         should_stop_early.num_runs += 1
+         if should_stop_early.num_runs >= cfg.checkpoint.patience:
+             logger.info(
+                 "early stop since valid performance hasn't improved for last {} runs".format(
+                     cfg.checkpoint.patience
+                 )
+             )
+             return True
+         else:
+             return False
+
+
+ @metrics.aggregate("train")
+ def train(
+     cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
+ ) -> Tuple[List[Optional[float]], bool]:
+     """Train the model for one epoch and return validation losses."""
+     # Initialize data iterator
+     itr = epoch_itr.next_epoch_itr(
+         fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
+         shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
+     )
+     update_freq = (
+         cfg.optimization.update_freq[epoch_itr.epoch - 1]
+         if epoch_itr.epoch <= len(cfg.optimization.update_freq)
+         else cfg.optimization.update_freq[-1]
+     )
+     itr = iterators.GroupedIterator(itr, update_freq)
+     if cfg.common.tpu:
+         itr = utils.tpu_data_loader(itr)
+     progress = progress_bar.progress_bar(
+         itr,
+         log_format=cfg.common.log_format,
+         log_file=cfg.common.log_file,
+         log_interval=cfg.common.log_interval,
+         epoch=epoch_itr.epoch,
+         tensorboard_logdir=(
+             cfg.common.tensorboard_logdir
+             if distributed_utils.is_master(cfg.distributed_training)
+             else None
+         ),
+         default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
+         wandb_project=(
+             cfg.common.wandb_project
+             if distributed_utils.is_master(cfg.distributed_training)
+             else None
+         ),
+         wandb_run_name=os.environ.get(
+             "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
+         ),
+         azureml_logging=(
+             cfg.common.azureml_logging
+             if distributed_utils.is_master(cfg.distributed_training)
+             else False
+         ),
+     )
+     progress.update_config(_flatten_config(cfg))
+
+     trainer.begin_epoch(epoch_itr.epoch)
+
+     valid_subsets = cfg.dataset.valid_subset.split(",")
+     should_stop = False
+     num_updates = trainer.get_num_updates()
+     logger.info("Start iterating over samples")
+     for i, samples in enumerate(progress):
+         with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
+             "train_step-%d" % i
+         ):
+             log_output = trainer.train_step(samples)
+
+         if log_output is not None:  # not OOM, overflow, ...
+             # log mid-epoch stats
+             num_updates = trainer.get_num_updates()
+             if num_updates % cfg.common.log_interval == 0:
+                 stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
+                 progress.log(stats, tag="train_inner", step=num_updates)
+
+                 # reset mid-epoch stats after each log interval
+                 # the end-of-epoch stats will still be preserved
+                 metrics.reset_meters("train_inner")
+
+         end_of_epoch = not itr.has_next()
+         valid_losses, should_stop = validate_and_save(
+             cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
+         )
+
+         if should_stop:
+             break
+
+     # log end-of-epoch stats
+     logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
+     stats = get_training_stats(metrics.get_smoothed_values("train"))
+     progress.print(stats, tag="train", step=num_updates)
+
+     # reset epoch-level meters
+     metrics.reset_meters("train")
+     return valid_losses, should_stop
+
+
+ def _flatten_config(cfg: DictConfig):
+     config = OmegaConf.to_container(cfg)
+     # remove any legacy Namespaces and replace with a single "args"
+     namespace = None
+     for k, v in list(config.items()):
+         if isinstance(v, argparse.Namespace):
+             namespace = v
+             del config[k]
+     if namespace is not None:
+         config["args"] = vars(namespace)
+     return config
+
+
+ def validate_and_save(
+     cfg: DictConfig,
+     trainer: Trainer,
+     task: tasks.FairseqTask,
+     epoch_itr,
+     valid_subsets: List[str],
+     end_of_epoch: bool,
+ ) -> Tuple[List[Optional[float]], bool]:
+     num_updates = trainer.get_num_updates()
+     max_update = cfg.optimization.max_update or math.inf
+
+     # Stopping conditions (and an additional one based on validation loss later
+     # on)
+     should_stop = False
+     if num_updates >= max_update:
+         should_stop = True
+         logger.info(
+             f"Stopping training due to "
+             f"num_updates: {num_updates} >= max_update: {max_update}"
+         )
+
+     training_time_hours = trainer.cumulative_training_time() / (60 * 60)
+     if (
+         cfg.optimization.stop_time_hours > 0
+         and training_time_hours > cfg.optimization.stop_time_hours
+     ):
+         should_stop = True
+         logger.info(
+             f"Stopping training due to "
+             f"cumulative_training_time: {training_time_hours} > "
+             f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
+         )
+
+     do_save = (
+         (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
+         or should_stop
+         or (
+             cfg.checkpoint.save_interval_updates > 0
+             and num_updates > 0
+             and num_updates % cfg.checkpoint.save_interval_updates == 0
+             and num_updates >= cfg.dataset.validate_after_updates
+         )
+     )
+     do_validate = (
+         (not end_of_epoch and do_save)  # validate during mid-epoch saves
+         or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
+         or should_stop
+         or (
+             cfg.dataset.validate_interval_updates > 0
+             and num_updates > 0
+             and num_updates % cfg.dataset.validate_interval_updates == 0
+         )
+     ) and not cfg.dataset.disable_validation and num_updates >= cfg.dataset.validate_after_updates
+
+     # Validate
+     valid_losses = [None]
+     if do_validate:
+         valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
+
+     should_stop |= should_stop_early(cfg, valid_losses[0])
+
+     # Save checkpoint
+     if do_save or should_stop:
+         checkpoint_utils.save_checkpoint(
+             cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
+         )
+
+     return valid_losses, should_stop
+
+
+ def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
+     stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
+     return stats
+
+
+ def validate(
+     cfg: DictConfig,
+     trainer: Trainer,
+     task: tasks.FairseqTask,
+     epoch_itr,
+     subsets: List[str],
+ ) -> List[Optional[float]]:
+     """Evaluate the model on the validation set(s) and return the losses."""
+
+     if cfg.dataset.fixed_validation_seed is not None:
+         # set fixed seed for every validation
+         utils.set_torch_seed(cfg.dataset.fixed_validation_seed)
+
+     trainer.begin_valid_epoch(epoch_itr.epoch)
+     valid_losses = []
+     for subset in subsets:
+         logger.info('begin validation on "{}" subset'.format(subset))
+
+         # Initialize data iterator
+         itr = trainer.get_valid_iterator(subset).next_epoch_itr(
+             shuffle=False, set_dataset_epoch=False  # use a fixed valid set
+         )
+         if cfg.common.tpu:
+             itr = utils.tpu_data_loader(itr)
+         progress = progress_bar.progress_bar(
+             itr,
+             log_format=cfg.common.log_format,
+             log_interval=cfg.common.log_interval,
+             epoch=epoch_itr.epoch,
+             prefix=f"valid on '{subset}' subset",
+             tensorboard_logdir=(
+                 cfg.common.tensorboard_logdir
+                 if distributed_utils.is_master(cfg.distributed_training)
+                 else None
+             ),
+             default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
+             wandb_project=(
+                 cfg.common.wandb_project
+                 if distributed_utils.is_master(cfg.distributed_training)
+                 else None
+             ),
+             wandb_run_name=os.environ.get(
+                 "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
+             ),
+         )
+
+         # create a new root metrics aggregator so validation metrics
+         # don't pollute other aggregators (e.g., train meters)
+         with metrics.aggregate(new_root=True) as agg:
+             for i, sample in enumerate(progress):
+                 if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps:
+                     break
+                 trainer.valid_step(sample)
+
+         # log validation stats
+         if hasattr(task, 'get_valid_stats'):
+             stats = task.get_valid_stats(cfg, trainer, agg.get_smoothed_values())
+         else:
+             stats = agg.get_smoothed_values()
+         stats = get_valid_stats(cfg, trainer, stats)
+
+         if hasattr(task, "post_validate"):
+             task.post_validate(trainer.get_model(), stats, agg)
+
+         progress.print(stats, tag=subset, step=trainer.get_num_updates())
+
+         valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
+     return valid_losses
+
+
+ def get_valid_stats(
+     cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]
+ ) -> Dict[str, Any]:
+     stats["num_updates"] = trainer.get_num_updates()
+     if hasattr(checkpoint_utils.save_checkpoint, "best"):
+         key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
+         best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
+         stats[key] = best_function(
+             checkpoint_utils.save_checkpoint.best,
+             stats[cfg.checkpoint.best_checkpoint_metric],
+         )
+     return stats
+
+
+ def cli_main(
+     modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
+ ) -> None:
+     parser = options.get_training_parser()
+     args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
+
+     cfg = convert_namespace_to_omegaconf(args)
+
+     if cfg.common.use_plasma_view:
+         server = PlasmaStore(path=cfg.common.plasma_path)
+         logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}")
+
+     if args.profile:
+         with torch.cuda.profiler.profile():
+             with torch.autograd.profiler.emit_nvtx():
+                 distributed_utils.call_main(cfg, main)
+     else:
+         distributed_utils.call_main(cfg, main)
+
+     # if cfg.common.use_plasma_view:
+     #     server.server.kill()
+
+
+ if __name__ == "__main__":
+     cli_main()