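# Tensor-parallel sharding experiment for diffusers' LTXVideoTransformer3DModel, using
# torch.distributed DTensor (parallelize_module) and CommDebugMode to inspect the collectives.
# Launch with torchrun, e.g. (script name assumed): torchrun --nproc_per_node=<num_gpus> tp_ltx_video.py
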
import copy

import torch
import torch.distributed as dist
from diffusers import LTXVideoTransformer3DModel
from torch._utils import _get_device_module
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PrepareModuleOutput,
    RowwiseParallel,
    SequenceParallel,
)

# from torch.utils._python_dispatch import TorchDispatchMode

# Shard, SequenceParallel and PrepareModuleOutput are only used by the commented-out
# embedding/attention plans below; they are imported so those plans can be re-enabled.

DEVICE_TYPE = "cuda"
PG_BACKEND = "nccl"
DEVICE_COUNT = _get_device_module(DEVICE_TYPE).device_count()
def main(world_size: int, rank: int):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(rank)

    CHANNELS = 128
    CROSS_ATTENTION_DIM = 2048
    CAPTION_CHANNELS = 4096
    NUM_LAYERS = 28
    NUM_ATTENTION_HEADS = 32
    ATTENTION_HEAD_DIM = 64

    # CHANNELS = 4
    # CROSS_ATTENTION_DIM = 32
    # CAPTION_CHANNELS = 64
    # NUM_LAYERS = 1
    # NUM_ATTENTION_HEADS = 4
    # ATTENTION_HEAD_DIM = 8

    config = {
        "in_channels": CHANNELS,
        "out_channels": CHANNELS,
        "patch_size": 1,
        "patch_size_t": 1,
        "num_attention_heads": NUM_ATTENTION_HEADS,
        "attention_head_dim": ATTENTION_HEAD_DIM,
        "cross_attention_dim": CROSS_ATTENTION_DIM,
        "num_layers": NUM_LAYERS,
        "activation_fn": "gelu-approximate",
        "qk_norm": "rms_norm_across_heads",
        "norm_elementwise_affine": False,
        "norm_eps": 1e-6,
        "caption_channels": CAPTION_CHANNELS,
        "attention_bias": True,
        "attention_out_bias": True,
    }
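    # Inner dim = NUM_ATTENTION_HEADS * ATTENTION_HEAD_DIM = 32 * 64 = 2048, which matches
    # the Linear in/out features in the module printout at the bottom of this file.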
    # Normal model
    torch.manual_seed(0)
    model = LTXVideoTransformer3DModel(**config).to(DEVICE_TYPE)

    # TP model
    model_tp = copy.deepcopy(model)
    device_mesh = DeviceMesh(DEVICE_TYPE, torch.arange(world_size))
    print(f"Device mesh: {device_mesh}")

    transformer_tp_plan = {
        # ===== Condition embeddings =====
        # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(),
        # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)),
        # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()),
        # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())),
        # "caption_projection.linear_1": ColwiseParallel(),
        # "caption_projection.linear_2": RowwiseParallel(),
        # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
        # ===== =====
    }
    for block in model_tp.transformer_blocks:
        block_tp_plan = {}

        # ===== Attention =====
        # 8 all-to-all, 3 all-reduce
        # block_tp_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn1.norm_q"] = SequenceParallel()
        # block_tp_plan["attn1.norm_k"] = SequenceParallel()
        # block_tp_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
        # block_tp_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False)
        # block_tp_plan["attn2.norm_q"] = SequenceParallel()
        # block_tp_plan["attn2.norm_k"] = SequenceParallel()
        # block_tp_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
        # ===== =====
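        # Feed-forward only: ColwiseParallel on the up-projection paired with RowwiseParallel
        # on the down-projection gives one all-reduce per transformer block in the forward pass.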
block_tp_plan["ff.net.0.proj"] = ColwiseParallel() | |
block_tp_plan["ff.net.2"] = RowwiseParallel() | |
parallelize_module(block, device_mesh, block_tp_plan) | |
parallelize_module(model_tp, device_mesh, transformer_tp_plan) | |
comm_mode = CommDebugMode() | |
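    # CommDebugMode counts the collectives (all_reduce, all_gather, ...) issued while the TP
    # model runs, so the observed comm pattern can be checked against the plan above.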
    batch_size = 2
    num_frames, height, width = 49, 512, 512
    temporal_compression_ratio, spatial_compression_ratio = 8, 32
    latent_num_frames, latent_height, latent_width = (
        (num_frames - 1) // temporal_compression_ratio + 1,
        height // spatial_compression_ratio,
        width // spatial_compression_ratio,
    )
    video_sequence_length = latent_num_frames * latent_height * latent_width
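    # (49 - 1) // 8 + 1 = 7 latent frames, 512 // 32 = 16 latent height/width,
    # so video_sequence_length = 7 * 16 * 16 = 1792 tokens.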
    caption_sequence_length = 64

    hidden_states = torch.randn(batch_size, video_sequence_length, CHANNELS, device=DEVICE_TYPE)
    encoder_hidden_states = torch.randn(batch_size, caption_sequence_length, CAPTION_CHANNELS, device=DEVICE_TYPE)
    encoder_attention_mask = None
    timestep = torch.randint(0, 1000, (batch_size, 1), device=DEVICE_TYPE)

    inputs = {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": encoder_attention_mask,
        "timestep": timestep,
        "num_frames": latent_num_frames,
        "height": latent_height,
        "width": latent_width,
        "rope_interpolation_scale": [1 / (8 / 25), 8, 8],
        "return_dict": False,
    }

    output = model(**inputs)[0]
    with comm_mode:
        output_tp = model_tp(**inputs)[0]
    output_tp = (
        output_tp.redistribute(output_tp.device_mesh, [Replicate()]).to_local()
        if isinstance(output_tp, DTensor)
        else output_tp
    )
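    # Redistributing to Replicate() gathers the sharded output so it can be compared
    # elementwise against the single-device reference output.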
print("Output shapes:", output.shape, output_tp.shape) | |
print( | |
"Comparing output:", | |
rank, | |
torch.allclose(output, output_tp, atol=1e-5, rtol=1e-5), | |
(output - output_tp).abs().max(), | |
) | |
print(f"Max memory reserved ({rank=}): {torch.cuda.max_memory_reserved(rank) / 1024**3:.2f} GB") | |
if rank == 0: | |
print() | |
print("get_comm_counts:", comm_mode.get_comm_counts()) | |
# print() | |
# print("get_parameter_info:", comm_mode.get_parameter_info()) # Too much noise | |
print() | |
print("Sharding info:\n" + "".join(f"{k} - {v}\n" for k, v in comm_mode.get_sharding_info().items())) | |
print() | |
print("get_total_counts:", comm_mode.get_total_counts()) | |
comm_mode.generate_json_dump("dump_comm_mode_log.json", noise_level=1) | |
comm_mode.log_comm_debug_tracing_table_to_file("dump_comm_mode_tracing_table.txt", noise_level=1) | |
dist.init_process_group(PG_BACKEND)
WORLD_SIZE = dist.get_world_size()
RANK = dist.get_rank()
torch.cuda.set_device(RANK)

if RANK == 0:
    print(f"World size: {WORLD_SIZE}")
    print(f"Device count: {DEVICE_COUNT}")

try:
    with torch.no_grad():
        main(WORLD_SIZE, RANK)
finally:
    dist.destroy_process_group()
# LTXVideoTransformer3DModel(
#   (proj_in): Linear(in_features=128, out_features=2048, bias=True)
#   (time_embed): AdaLayerNormSingle(
#     (emb): PixArtAlphaCombinedTimestepSizeEmbeddings(
#       (time_proj): Timesteps()
#       (timestep_embedder): TimestepEmbedding(
#         (linear_1): Linear(in_features=256, out_features=2048, bias=True)
#         (act): SiLU()
#         (linear_2): Linear(in_features=2048, out_features=2048, bias=True)
#       )
#     )
#     (silu): SiLU()
#     (linear): Linear(in_features=2048, out_features=12288, bias=True)
#   )
#   (caption_projection): PixArtAlphaTextProjection(
#     (linear_1): Linear(in_features=4096, out_features=2048, bias=True)
#     (act_1): GELU(approximate='tanh')
#     (linear_2): Linear(in_features=2048, out_features=2048, bias=True)
#   )
#   (rope): LTXVideoRotaryPosEmbed()
#   (transformer_blocks): ModuleList(
#     (0-27): 28 x LTXVideoTransformerBlock(
#       (norm1): RMSNorm()
#       (attn1): Attention(
#         (norm_q): RMSNorm()
#         (norm_k): RMSNorm()
#         (to_q): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_k): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_v): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_out): ModuleList(
#           (0): Linear(in_features=2048, out_features=2048, bias=True)
#           (1): Dropout(p=0.0, inplace=False)
#         )
#       )
#       (norm2): RMSNorm()
#       (attn2): Attention(
#         (norm_q): RMSNorm()
#         (norm_k): RMSNorm()
#         (to_q): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_k): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_v): Linear(in_features=2048, out_features=2048, bias=True)
#         (to_out): ModuleList(
#           (0): Linear(in_features=2048, out_features=2048, bias=True)
#           (1): Dropout(p=0.0, inplace=False)
#         )
#       )
#       (ff): FeedForward(
#         (net): ModuleList(
#           (0): GELU(
#             (proj): Linear(in_features=2048, out_features=8192, bias=True)
#           )
#           (1): Dropout(p=0.0, inplace=False)
#           (2): Linear(in_features=8192, out_features=2048, bias=True)
#         )
#       )
#     )
#   )
#   (norm_out): LayerNorm((2048,), eps=1e-06, elementwise_affine=False)
#   (proj_out): Linear(in_features=2048, out_features=128, bias=True)
# )