Spaces:
Runtime error
Runtime error
File size: 457 Bytes
e7d5680 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch.distributed as dist
# Module-level registry of process groups, keyed by a short string name.
# The accessor functions in this module use the keys "data" and "sequence".
_GLOBAL_PARALLEL_GROUPS = {}
def set_data_parallel_group(group: dist.ProcessGroup):
    """Register ``group`` as the module-wide data-parallel process group.

    The group is stored in ``_GLOBAL_PARALLEL_GROUPS`` under the ``"data"``
    key; any group previously stored under that key is replaced.

    Args:
        group: The process group to store.
    """
    _GLOBAL_PARALLEL_GROUPS.update({"data": group})
def get_data_parallel_group():
    """Return the process group stored under the ``"data"`` key.

    Returns:
        The value previously stored in ``_GLOBAL_PARALLEL_GROUPS["data"]``,
        or ``None`` when nothing has been stored under that key.
    """
    # dict.get already falls back to None when the key is absent.
    registered = _GLOBAL_PARALLEL_GROUPS.get("data")
    return registered
def set_sequence_parallel_group(group: dist.ProcessGroup):
    """Register ``group`` as the module-wide sequence-parallel process group.

    The group is stored in ``_GLOBAL_PARALLEL_GROUPS`` under the
    ``"sequence"`` key; any group previously stored under that key is
    replaced.

    Args:
        group: The process group to store.
    """
    _GLOBAL_PARALLEL_GROUPS.update({"sequence": group})
def get_sequence_parallel_group():
    """Return the process group stored under the ``"sequence"`` key.

    Returns:
        The value previously stored in
        ``_GLOBAL_PARALLEL_GROUPS["sequence"]``, or ``None`` when nothing has
        been stored under that key.
    """
    # dict.get already falls back to None when the key is absent.
    registered = _GLOBAL_PARALLEL_GROUPS.get("sequence")
    return registered
|