| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
|
|
| import os |
| import random |
| import numpy as np |
|
|
| import paddle |
| from paddle.distributed import fleet |
|
|
# Public API of this module: the three environment/seed setup helpers.
__all__ = [
    'init_parallel_env',
    'set_random_seed',
    'init_fleet_env',
]
|
|
|
|
def init_fleet_env(find_unused_parameters=False):
    """Initialise Paddle Fleet in collective mode.

    A fresh ``fleet.DistributedStrategy`` is created, its
    ``find_unused_parameters`` attribute is set from the argument, and
    ``fleet.init`` is called with ``is_collective=True`` and that strategy.

    Args:
        find_unused_parameters: value copied onto the strategy's
            ``find_unused_parameters`` attribute. Defaults to False.
    """
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.find_unused_parameters = find_unused_parameters
    fleet.init(strategy=dist_strategy, is_collective=True)
|
|
|
|
def init_parallel_env():
    """Seed Python/NumPy RNGs per trainer, then initialise Paddle's parallel env.

    If both ``PADDLE_TRAINER_ID`` and ``PADDLE_TRAINERS_NUM`` are set in the
    process environment, ``random`` and ``numpy.random`` are seeded with
    ``99 + int(PADDLE_TRAINER_ID)``. Trainers with different ids therefore
    draw from different random streams. Paddle's own RNG is not seeded here.

    ``paddle.distributed.init_parallel_env()`` is called in every case,
    whether or not the trainer variables are present.
    """
    environ = os.environ
    required_keys = ('PADDLE_TRAINER_ID', 'PADDLE_TRAINERS_NUM')
    # all() short-circuits in tuple order, matching a chained `and` check.
    if all(key in environ for key in required_keys):
        seed = 99 + int(environ['PADDLE_TRAINER_ID'])
        random.seed(seed)
        np.random.seed(seed)

    paddle.distributed.init_parallel_env()
|
|
|
|
def set_random_seed(seed):
    """Seed Paddle, Python's ``random`` module and NumPy with one value.

    The seeders are called in this order: ``paddle.seed``, ``random.seed``,
    ``numpy.random.seed``.

    Args:
        seed: value forwarded unchanged to each of the three seeders.
    """
    paddle.seed(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
|
|