import torch
import torch.distributed as dist

from checkpoint import load_weights_from_hf
from model import DeepseekForCausalLM
from model_config import deepseek_config_registry
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.pipelining import PipelineStage, Schedule1F1B

model_id = "deepseek-ai/DeepSeek-V2-Lite"


def run_full_model(
    mesh: DeviceMesh,
):
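    """Run a few training steps of a truncated DeepSeek-V2-Lite model under
    combined pipeline parallelism (pp), expert parallelism (ep), and FSDP.
    """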
    # Map this rank to a local CUDA device.
    rank = dist.get_rank()
    device_count = torch.cuda.device_count()
    device = torch.device("cuda", rank % device_count)

    # Sub-meshes and this rank's coordinates along the pp and ep dimensions.
    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    pp_rank = pp_mesh.get_local_rank()
    ep_rank = ep_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    ep_size = ep_mesh.size()
    # Load the registered config and truncate the model to 16 layers to keep
    # the example small.
    model_args = deepseek_config_registry[model_id]
    model_args.num_hidden_layers = 16

    # Record the parallelism layout in the config: expert-parallel degree,
    # number of pipeline stages, and which stage this rank builds.
    model_args.ep_size = ep_size
    model_args.num_stages = pp_size
    model_args.stage_idx = pp_rank
    print(model_args)
    # Build this rank's pipeline stage on the local device; entering the mesh
    # makes it the current mesh for the model's parallelized submodules.
    with device, mesh:
        model = DeepseekForCausalLM(model_args)

    # Load the pretrained weights from Hugging Face and switch to train mode.
    load_weights_from_hf(model, model_id, device)
    model.train()
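    # Sharding plan (FSDP2): expert parameters differ across the "ep"
    # dimension, so each expert is sharded only over "fsdp"; the dense
    # parameters are sharded over the 2-D ("ep", "fsdp") mesh, which
    # fully_shard treats as HSDP (replicate over "ep", shard over "fsdp").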
    fsdp_mesh = mesh["fsdp"]
    hsdp_mesh = mesh["ep", "fsdp"]

    # Apply fully_shard bottom-up: each expert, then each transformer layer,
    # then the root model.
    for layer in model.model.layers.values():
        if hasattr(layer.mlp, "experts"):
            for expert in layer.mlp.experts.values():
                fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
        fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)
    fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)

    # Number of pipeline microbatches per schedule step.
    microbatches = pp_size * 2
    # Synthetic token inputs and random soft targets, just to exercise the
    # forward/backward path; seed per EP rank so expert groups see different
    # data.
    torch.manual_seed(ep_rank)
    bs = 4
    seqlen = 128
    x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
    label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)

    loss_fn = torch.nn.functional.cross_entropy
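    # With pipeline parallelism, the 1F1B schedule drives both forward and
    # backward across the microbatches; without it, backward() is called
    # directly.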
    steps = 2
    for _ in range(steps):
        if pp_size > 1:
            # Wrap this rank's model chunk in a pipeline stage.
            stage = PipelineStage(
                model,
                pp_rank,
                pp_size,
                device,
                group=pp_mesh.get_group(),
            )

            losses = []
            pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)

            # Only the first stage feeds inputs; only the last stage receives
            # targets and collects the per-microbatch losses.
            if pp_rank == 0:
                y = pp_schedule.step(x)
            elif pp_rank == pp_size - 1:
                y = pp_schedule.step(target=label, losses=losses)
                loss = torch.mean(torch.stack(losses))
            else:
                pp_schedule.step()
        else:
            # Single-stage fallback: plain forward, loss, backward.
            y = model(x)
            loss = loss_fn(y, label)
            loss.backward()
        # The last stage holds the logits and the loss; the first stage checks
        # that gradients reached its parameters.
        if pp_rank == pp_size - 1:
            print(f"logits: {y.shape}")
            print(f"{loss=}")

        if pp_rank == 0:
            param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
            print(f"{torch.linalg.norm(param.grad)=}")

        model.zero_grad()

    print("Backward done")

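# Requires 8 ranks for the (2, 2, 2) mesh below; launch with torchrun, e.g.
# `torchrun --nproc-per-node 8 <this_script>.py` (the script name here is
# illustrative).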
if __name__ == "__main__":
    # 3-D mesh: 2-way pipeline parallel x 2-way expert parallel x 2-way FSDP.
    mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))

    run_full_model(mesh)

    dist.destroy_process_group()