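# A minimal sketch of tensor-parallel inference for the LTX-Video transformer using
# PyTorch DTensor. It assumes CUDA GPUs with the NCCL backend and is meant to be
# launched with torchrun, e.g. (the script name below is just a placeholder):
#
#   torchrun --nproc-per-node=2 tp_ltx_video.py
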
import copy

import torch
import torch.distributed as dist
from diffusers import LTXVideoTransformer3DModel
from torch._utils import _get_device_module
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    RowwiseParallel,
)

DEVICE_TYPE = "cuda"
PG_BACKEND = "nccl"
DEVICE_COUNT = _get_device_module(DEVICE_TYPE).device_count()


def main(world_size: int, rank: int):
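    """Build a reference model and a tensor-parallel copy, run both on the same
    inputs, and compare their outputs, communication counts, and memory usage."""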
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(rank)

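    # Transformer hyperparameters used to build the model under test
    # (inner dimension = NUM_ATTENTION_HEADS * ATTENTION_HEAD_DIM = 32 * 64 = 2048).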
    CHANNELS = 128
    CROSS_ATTENTION_DIM = 2048
    CAPTION_CHANNELS = 4096
    NUM_LAYERS = 28
    NUM_ATTENTION_HEADS = 32
    ATTENTION_HEAD_DIM = 64

    config = {
        "in_channels": CHANNELS,
        "out_channels": CHANNELS,
        "patch_size": 1,
        "patch_size_t": 1,
        "num_attention_heads": NUM_ATTENTION_HEADS,
        "attention_head_dim": ATTENTION_HEAD_DIM,
        "cross_attention_dim": CROSS_ATTENTION_DIM,
        "num_layers": NUM_LAYERS,
        "activation_fn": "gelu-approximate",
        "qk_norm": "rms_norm_across_heads",
        "norm_elementwise_affine": False,
        "norm_eps": 1e-6,
        "caption_channels": CAPTION_CHANNELS,
        "attention_bias": True,
        "attention_out_bias": True,
    }
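
    # Seed before constructing the model so that, together with the deepcopy below,
    # the reference and tensor-parallel models start from identical weights on every rank.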
    torch.manual_seed(0)
    model = LTXVideoTransformer3DModel(**config).to(DEVICE_TYPE)

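    # The tensor-parallel model is a deep copy of the reference, sharded over a 1-D
    # device mesh that spans all ranks.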
    model_tp = copy.deepcopy(model)
    device_mesh = DeviceMesh(DEVICE_TYPE, torch.arange(world_size))
    print(f"Device mesh: {device_mesh}")

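    # Top-level parallelization plan. It is left empty in this sketch, so only the
    # per-block feed-forward layers sharded in the loop below are tensor-parallel;
    # plans for other submodules could be added here once their module names are known.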
    transformer_tp_plan = {}

    for block in model_tp.transformer_blocks:
        block_tp_plan = {}
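        # Shard the feed-forward network colwise -> rowwise: "ff.net.0.proj" is the
        # up-projection inside the activation module and "ff.net.2" is the
        # down-projection, so each feed-forward needs a single all-reduce in the
        # forward pass. The attention projections are left unsharded in this sketch.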
        block_tp_plan["ff.net.0.proj"] = ColwiseParallel()
        block_tp_plan["ff.net.2"] = RowwiseParallel()
        parallelize_module(block, device_mesh, block_tp_plan)

    parallelize_module(model_tp, device_mesh, transformer_tp_plan)

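    # CommDebugMode records every collective issued during the tensor-parallel
    # forward pass so the communication cost can be inspected afterwards.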
    comm_mode = CommDebugMode()

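    # Dummy video/caption inputs. Latent sizes follow the VAE compression ratios
    # defined below (8x temporal, 32x spatial): 49 frames at 512x512 become a
    # 7 * 16 * 16 = 1792-token video sequence.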
    batch_size = 2
    num_frames, height, width = 49, 512, 512
    temporal_compression_ratio, spatial_compression_ratio = 8, 32
    latent_num_frames, latent_height, latent_width = (
        (num_frames - 1) // temporal_compression_ratio + 1,
        height // spatial_compression_ratio,
        width // spatial_compression_ratio,
    )
    video_sequence_length = latent_num_frames * latent_height * latent_width
    caption_sequence_length = 64

    hidden_states = torch.randn(batch_size, video_sequence_length, CHANNELS, device=DEVICE_TYPE)
    encoder_hidden_states = torch.randn(batch_size, caption_sequence_length, CAPTION_CHANNELS, device=DEVICE_TYPE)
    encoder_attention_mask = None
    timestep = torch.randint(0, 1000, (batch_size, 1), device=DEVICE_TYPE)
    inputs = {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": encoder_attention_mask,
        "timestep": timestep,
        "num_frames": latent_num_frames,
        "height": latent_height,
        "width": latent_width,
        "rope_interpolation_scale": [1 / (8 / 25), 8, 8],
        "return_dict": False,
    }

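    # Reference forward pass on the unparallelized model; every rank computes the
    # full result, which the tensor-parallel output is compared against below.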
    output = model(**inputs)[0]

    with comm_mode:
        output_tp = model_tp(**inputs)[0]

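    # If the tensor-parallel model returned a DTensor, gather it back into a fully
    # replicated local tensor so it can be compared elementwise with the reference.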
    output_tp = (
        output_tp.redistribute(output_tp.device_mesh, [Replicate()]).to_local()
        if isinstance(output_tp, DTensor)
        else output_tp
    )

print("Output shapes:", output.shape, output_tp.shape) |
|
|
print( |
|
|
"Comparing output:", |
|
|
rank, |
|
|
torch.allclose(output, output_tp, atol=1e-5, rtol=1e-5), |
|
|
(output - output_tp).abs().max(), |
|
|
) |
|
|
print(f"Max memory reserved ({rank=}): {torch.cuda.max_memory_reserved(rank) / 1024**3:.2f} GB") |
|
|
|
|
|
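    # Rank 0 prints the aggregated communication statistics and writes the
    # CommDebugMode dumps (JSON log and tracing table) to the current directory.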
    if rank == 0:
        print()
        print("get_comm_counts:", comm_mode.get_comm_counts())

        print()
        print("Sharding info:\n" + "".join(f"{k} - {v}\n" for k, v in comm_mode.get_sharding_info().items()))
        print()
        print("get_total_counts:", comm_mode.get_total_counts())
        comm_mode.generate_json_dump("dump_comm_mode_log.json", noise_level=1)
        comm_mode.log_comm_debug_tracing_table_to_file("dump_comm_mode_tracing_table.txt", noise_level=1)


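# Expected to be launched with torchrun (or a similar launcher), which sets the
# environment variables that init_process_group() reads. Inference only, so main()
# runs under no_grad, and the process group is torn down even if it raises.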
dist.init_process_group(PG_BACKEND)
WORLD_SIZE = dist.get_world_size()
RANK = dist.get_rank()

torch.cuda.set_device(RANK)

if RANK == 0:
    print(f"World size: {WORLD_SIZE}")
    print(f"Device count: {DEVICE_COUNT}")

try:
    with torch.no_grad():
        main(WORLD_SIZE, RANK)
finally:
    dist.destroy_process_group()