import importlib.util

import torch
import torch.distributed as dist

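# Prefer PAI's "paifuser" build (which vendors xfuser under paifuser.xfuser)
# when it is installed; otherwise fall back to the stock xfuser package. If
# neither can be imported, the parallel helpers below are stubbed to None and
# single-GPU inference still works.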
try:
    if importlib.util.find_spec("paifuser") is not None:
        import paifuser
        from paifuser.xfuser.core.distributed import (
            get_sequence_parallel_rank, get_sequence_parallel_world_size,
            get_sp_group, get_world_group, init_distributed_environment,
            initialize_model_parallel)
        from paifuser.xfuser.core.long_ctx_attention import \
            xFuserLongContextAttention
        print("Import PAI DiT Turbo")
    else:
        import xfuser
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group, get_world_group,
                                             init_distributed_environment,
                                             initialize_model_parallel)
        from xfuser.core.long_ctx_attention import xFuserLongContextAttention
        print("Xfuser import successful")
except Exception:
    # Neither backend is importable; stub the helpers out so that
    # set_multi_gpus_devices can detect the missing dependency.
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None
    get_sp_group = None
    get_world_group = None
    init_distributed_environment = None
    initialize_model_parallel = None


def set_multi_gpus_devices(ulysses_degree, ring_degree, classifier_free_guidance_degree=1):
    """Initialize distributed parallelism and return the device for this rank.

    Falls back to plain "cuda" when no parallelism is requested.
    """
    if ulysses_degree > 1 or ring_degree > 1 or classifier_free_guidance_degree > 1:
        if get_sp_group is None:
            raise RuntimeError("xfuser is not installed.")
        dist.init_process_group("nccl")
        print('parallel inference enabled: ulysses_degree=%d ring_degree=%d classifier_free_guidance_degree=%d rank=%d world_size=%d' % (
            ulysses_degree, ring_degree, classifier_free_guidance_degree, dist.get_rank(),
            dist.get_world_size()))
        assert dist.get_world_size() == ring_degree * ulysses_degree * classifier_free_guidance_degree, \
            "number of GPUs (%d) must equal ring_degree * ulysses_degree * classifier_free_guidance_degree." % dist.get_world_size()
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        # The sequence-parallel group size is the product of the ring and Ulysses degrees.
        initialize_model_parallel(sequence_parallel_degree=ring_degree * ulysses_degree,
                                  classifier_free_guidance_degree=classifier_free_guidance_degree,
                                  ring_degree=ring_degree,
                                  ulysses_degree=ulysses_degree)

        # Bind this process to its local GPU.
        device = torch.device(f"cuda:{get_world_group().local_rank}")
        print('rank=%d device=%s' % (get_world_group().rank, str(device)))
    else:
        device = "cuda"
    return device
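

# Usage sketch (hypothetical: the script name "infer.py" and the pipeline
# object are assumptions, not part of this module). The product of the degrees
# must equal the number of ranks launched, matching the assert above, e.g.:
#
#   torchrun --nproc_per_node=8 infer.py
#
#   # In infer.py: 8 GPUs split as 4-way Ulysses x 2-way ring attention.
#   device = set_multi_gpus_devices(ulysses_degree=4, ring_degree=2)
#   pipeline.to(device)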