| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | import os |
| | import tempfile |
| | import warnings |
| |
|
| | from safetensors import safe_open |
| |
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_available |
| | from transformers.integrations.tensor_parallel import get_packed_weights, get_tensor_shard, repack_weights |
| | from transformers.testing_utils import ( |
| | TestCasePlus, |
| | backend_device_count, |
| | get_torch_dist_unique_port, |
| | require_huggingface_hub_greater_or_equal, |
| | require_torch_multi_accelerator, |
| | torch_device, |
| | ) |
| |
|
| |
|
| | if is_torch_available(): |
| | import torch |
| | import torch.distributed as dist |
| | import torch.multiprocessing as mp |
| |
|
| |
|
def global_wrapper(rank, func, tp, port, func_args, func_kwargs):
    """Per-process entry point launched by ``mp.spawn``.

    The function runs these steps in order:

    1. Export the rendezvous environment variables for this rank.
    2. Initialize the default process group: NCCL when CUDA is available, otherwise Gloo.
    3. Run ``func(rank, *func_args, **func_kwargs)``.
    4. Synchronize all ranks.
    5. Always destroy the process group, even if ``func`` raised.

    Args:
        rank: Index of this process. Also exported as ``LOCAL_RANK`` and, on CUDA,
            used as the device index.
        func: Callable executed on every rank; receives ``rank`` as its first argument.
        tp: Tensor-parallel degree, used as the world size.
        port: TCP port used for the rendezvous on ``localhost``.
        func_args: Extra positional arguments forwarded to ``func`` after ``rank``.
        func_kwargs: Keyword arguments forwarded to ``func``.
    """

    def setup_dist_env(rank, world_size, port):
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(rank)
        os.environ["LOCAL_RANK"] = str(rank)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(port)

    world_size = tp
    setup_dist_env(rank, world_size, port)

    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    else:
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    try:
        func(rank, *func_args, **func_kwargs)
        # Only synchronize on success: if this rank failed, waiting on peers
        # that may already be gone would just hang.
        dist.barrier()
    finally:
        # Release the process group even when ``func`` raised (e.g. a failed
        # assertion) so the worker does not leak communicator resources.
        dist.destroy_process_group()
| |
|
| |
|
def init_distributed(tp: int):
    """Return a decorator that runs a function on ``tp`` spawned processes.

    The decorated callable ``func(rank, *args, **kwargs)`` is executed through
    ``global_wrapper`` in each of ``tp`` processes started with ``mp.spawn``.
    Calling the returned wrapper with ``*args, **kwargs`` reserves a fresh port
    for the rendezvous and blocks until all processes exit (``mp.spawn``
    defaults to ``join=True``).

    Args:
        tp: Number of processes to spawn; also the world size.

    Returns:
        A decorator. The wrapper it produces keeps the decorated function's
        metadata (name, docstring) thanks to ``functools.wraps``.
    """
    import functools

    def _init_distributed(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            port = get_torch_dist_unique_port()
            spawn_args = (func, tp, port, args, kwargs)
            mp.spawn(global_wrapper, args=spawn_args, nprocs=tp)

        return wrapper

    return _init_distributed
| |
|
| |
|
class TestTensorParallelUtils(TestCasePlus):
    """Unit tests for the packed-weight sharding helpers used by tensor parallelism."""

    def test_packed_unpacked_conversion(self):
        """Sharding a packed tensor rank by rank, then repacking, must give back the original."""
        world_size = 2
        packed_block_size = 800
        shard_dim = 2
        num_blocks = 2

        full_weights = torch.randn(4, 512, 2 * packed_block_size)
        # NOTE(review): presumably mimics the safetensors slice API, which
        # get_packed_weights queries for the dtype — TODO confirm.
        full_weights.get_dtype = lambda: "F32"
        placeholder = torch.empty(4, 512, 2 * packed_block_size)

        class MockDeviceMesh:
            # Minimal stand-in for a device mesh: only ``size()`` is provided.
            def size(self):
                return world_size

        mesh = MockDeviceMesh()

        # One packed shard per rank, in rank order (0, 1, ...).
        shards = [
            get_packed_weights(full_weights, placeholder, mesh, shard_rank, shard_dim)
            for shard_rank in range(world_size)
        ]

        merged = torch.cat(shards, dim=shard_dim)
        repacked = repack_weights(merged, shard_dim, world_size, num_blocks)

        assert torch.allclose(repacked, full_weights)
| |
|
| |
|
class TestTensorParallelProperties(TestCasePlus):
    """Tests for the ``tp_plan`` property on models loaded with ``from_pretrained``."""

    # Tiny Llama checkpoint shared by every test in this class.
    MODEL_ID = "JackFram/llama-68m"

    def _load_model(self):
        """Load a fresh model so each test starts from an untouched ``tp_plan``."""
        return AutoModelForCausalLM.from_pretrained(self.MODEL_ID, dtype="auto")

    @staticmethod
    def _unmatched_pattern_warnings(records):
        """Return messages of recorded warnings saying a layer pattern matches no parameters."""
        messages = [str(record.message) for record in records]
        return [msg for msg in messages if "Layer pattern" in msg and "does not match any parameters" in msg]

    def test_tp_plan_property_setter_getter(self):
        """Test that tp_plan property can be set and retrieved correctly."""
        model = self._load_model()

        # An empty plan round-trips unchanged.
        model.tp_plan = {}
        self.assertEqual(model.tp_plan, {})

        # A valid plan is returned exactly as assigned.
        valid_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
        model.tp_plan = valid_plan
        self.assertEqual(model.tp_plan, valid_plan)

        # In-place updates on the returned dict are visible through the getter.
        model.tp_plan.update({"model.layers.*.self_attn.k_proj": "colwise"})
        expected_plan = {"model.layers.*.self_attn.q_proj": "colwise", "model.layers.*.self_attn.k_proj": "colwise"}
        self.assertEqual(model.tp_plan, expected_plan)

        # Updating an existing key overrides its style.
        model.tp_plan.update({"model.layers.*.self_attn.q_proj": "colwise_rep"})
        expected_plan = {
            "model.layers.*.self_attn.q_proj": "colwise_rep",
            "model.layers.*.self_attn.k_proj": "colwise",
        }
        self.assertEqual(model.tp_plan, expected_plan)

    def test_tp_plan_validation_invalid_style(self):
        """Test that invalid parallel styles are rejected."""
        model = self._load_model()

        with self.assertRaises(ValueError) as context:
            model.tp_plan = {"layers.*.self_attn.q_proj": "invalid_style"}

        self.assertIn("Unsupported tensor parallel style 'invalid_style'", str(context.exception))
        self.assertIn("Supported styles are", str(context.exception))

    def test_tp_plan_validation_nonexistent_layer_warning(self):
        """Test that warnings are issued for non-existent layer patterns."""
        model = self._load_model()

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            model.tp_plan = {"nonexistent.*.layer": "colwise"}

        # Search every recorded warning instead of assuming the relevant one is
        # first: unrelated warnings may also be captured under "always".
        messages = [str(record.message) for record in caught]
        self.assertGreater(len(messages), 0)
        self.assertTrue(
            any("Layer pattern 'nonexistent.*.layer' does not match any parameters" in msg for msg in messages),
            f"Expected a layer-pattern warning, got: {messages}",
        )

    def test_tp_plan_valid_layer_patterns(self):
        """Test that valid layer patterns are accepted without warnings."""
        model = self._load_model()

        valid_plans = [
            {"model.layers.*.self_attn.q_proj": "colwise"},
            {"model.layers.*.self_attn.k_proj": "rowwise"},
            {"model.layers.*.mlp.gate_proj": "colwise_rep"},
        ]

        for plan in valid_plans:
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                model.tp_plan = plan

            unmatched = self._unmatched_pattern_warnings(caught)
            self.assertEqual(
                len(unmatched),
                0,
                f"Unexpected warning for valid pattern {plan}: {unmatched}",
            )

        # Each assignment replaces the plan, so only the last one remains.
        self.assertEqual(model.tp_plan, valid_plans[-1])

    def test_tp_plan_none_handling(self):
        """Test that None values are handled correctly."""
        model = self._load_model()

        # Assigning None yields an empty plan.
        model.tp_plan = None
        self.assertEqual(model.tp_plan, {})

        # A real plan can still be assigned afterwards.
        model.tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
        self.assertEqual(model.tp_plan, {"model.layers.*.self_attn.q_proj": "colwise"})
| |
|
| |
|
| | |
def _test_model_dense_forward_impl(rank, mode):
    """Implementation for comparing TP and non-TP model outputs.

    Runs on every rank. It loads the checkpoint twice, once sharded with
    ``tp_plan="auto"`` and once dense, placing the dense copy on the TP model's
    device. It then feeds both the same tokenized prompt under ``torch.no_grad()``
    and asserts that their logits are close.

    Args:
        rank: Process rank supplied by ``global_wrapper``; not used in the body.
        mode: ``"eval"`` puts both models in eval mode; any other value selects train mode.
    """
    checkpoint = "JackFram/llama-68m"
    training = mode != "eval"

    torch.manual_seed(0)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
    encoded = tokenizer("Can I help", return_tensors="pt")

    sharded_model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto", tp_plan="auto")
    dist.barrier()
    # ``train(False)`` is what ``eval()`` does.
    sharded_model.train(training)

    device = sharded_model.device
    dense_model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto").to(device)
    dense_model.train(training)

    input_ids = encoded.input_ids.to(device)

    with torch.no_grad():
        reference_logits = dense_model(input_ids).logits
        sharded_logits = sharded_model(input_ids).logits

    assert torch.allclose(reference_logits, sharded_logits, atol=1e-5, rtol=1e-5), (
        f"TP and non-TP model outputs differ. Max diff: {(reference_logits - sharded_logits).abs().max().item()} | Min diff: {(reference_logits - sharded_logits).abs().min().item()}"
    )

    dist.barrier()
| |
|
| |
|
def _test_model_dense_backward_pass_impl(rank):
    """Implementation for comparing TP and non-TP model backward passes.

    Runs on every rank. Both a TP-sharded and a dense copy of the checkpoint
    are loaded in float32 and run forward/backward on the same random batch.
    The losses must match. Then, for every parameter with a gradient in both
    models:

    * the dense gradient is sharded like the TP parameter, when the TP
      parameter is a DTensor whose first placement exposes a ``dim``;
    * it is compared with the TP model's local gradient.

    Args:
        rank: Process rank supplied by ``global_wrapper``; selects this rank's gradient shard.
    """
    model_id = "JackFram/llama-68m"

    torch.manual_seed(0)

    model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32, tp_plan="auto")
    dist.barrier()
    model_tp.train()

    device = model_tp.device
    model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
    model = model.to(device)
    model.train()

    batch_size, seq_length = 2, 10
    # Re-seed right before sampling so every rank draws the same batch.
    torch.manual_seed(42)
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)
    labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device)

    loss = model(input_ids, labels=labels).loss
    loss.backward()

    loss_tp = model_tp(input_ids, labels=labels).loss
    loss_tp.backward()

    assert torch.allclose(loss, loss_tp, atol=1e-5, rtol=1e-5), (
        f"TP and non-TP model losses differ. Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}"
    )

    # Pair parameters by name, not by position: zipping two named_parameters()
    # iterators would silently compare the wrong tensors if their orders differed.
    params_tp = dict(model_tp.named_parameters())
    for name, param in model.named_parameters():
        assert name in params_tp, f"Parameter {name} is missing from the TP model"
        param_tp = params_tp[name]
        if param.grad is None or param_tp.grad is None:
            continue
        grad = param.grad
        grad_tp = param_tp.grad

        # Cut the dense gradient the same way the TP parameter is sharded.
        # Placements without a ``dim`` (presumably Replicate) are compared whole.
        grad_shard = grad
        if isinstance(param_tp.data, dist.tensor.DTensor):
            placement = param_tp.data.placements[0]
            if getattr(placement, "dim", None) is not None:
                grad_shard = get_tensor_shard(grad, grad, param_tp.data.device_mesh, rank, placement.dim)

        grad_tp_local = grad_tp.to_local() if isinstance(grad_tp, dist.tensor.DTensor) else grad_tp

        assert torch.allclose(grad_shard.cpu(), grad_tp_local.cpu(), atol=1e-5, rtol=1e-5), (
            f"Gradients differ for parameter {name}. Max diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().max().item()} | Min diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().min().item()}"
        )

    dist.barrier()
| |
|
| |
|
def _test_model_dense_forward_compile_impl(rank, mode):
    """Implementation for comparing TP and non-TP model outputs with torch.compile.

    Same comparison as the eager forward test, except that each model's
    ``forward`` is replaced by a ``torch.compile``-d version before the
    tokenized prompt is fed through both under ``torch.no_grad()``.

    Args:
        rank: Process rank supplied by ``global_wrapper``; not used in the body.
        mode: ``"eval"`` puts both models in eval mode; any other value selects train mode.
    """
    checkpoint = "JackFram/llama-68m"
    training = mode != "eval"

    torch.manual_seed(0)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
    encoded = tokenizer("Can I help", return_tensors="pt")

    sharded_model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto", tp_plan="auto")
    dist.barrier()
    # ``train(False)`` is what ``eval()`` does.
    sharded_model.train(training)

    device = sharded_model.device
    dense_model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto").to(device)
    dense_model.train(training)

    # Swap each instance's bound ``forward`` for its compiled counterpart.
    dense_model.forward = torch.compile(dense_model.forward)
    sharded_model.forward = torch.compile(sharded_model.forward)

    input_ids = encoded.input_ids.to(device)

    with torch.no_grad():
        reference_logits = dense_model(input_ids).logits
        sharded_logits = sharded_model(input_ids).logits

    assert torch.allclose(reference_logits, sharded_logits, atol=1e-5, rtol=1e-5), (
        f"TP and non-TP model outputs differ. Max diff: {(reference_logits - sharded_logits).abs().max().item()} | Min diff: {(reference_logits - sharded_logits).abs().min().item()}"
    )

    dist.barrier()
| |
|
| |
|
def _test_model_dense_save_impl(rank, tmp_dir):
    """Implementation of test_model_save for distributed execution.

    Inside an initialized process group, the checkpoint is loaded with
    ``tp_plan="auto"`` and saved to ``<tmp_dir>/tp``. Without a process group,
    it is loaded densely and saved to ``<tmp_dir>/nontp``.

    Args:
        rank: Process rank supplied by ``global_wrapper``; not used in the body.
        tmp_dir: Directory under which the ``tp`` / ``nontp`` output folders are created.
    """
    distributed = dist.is_initialized()
    load_kwargs = {"tp_plan": "auto"} if distributed else {}
    subdir = "tp" if distributed else "nontp"

    model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m", **load_kwargs)
    model.save_pretrained(f"{tmp_dir}/{subdir}")
| |
|
| |
|
class TestTensorParallelBase(TestCasePlus):
    """Base class for tensor parallel tests. Subclasses must set nproc_per_node.

    Each test spawns ``nproc_per_node`` processes through ``init_distributed``
    and runs one of the module-level ``_test_*_impl`` functions on every rank.
    Tests skip when ``nproc_per_node`` is unset or too few devices are visible.
    """

    # Number of processes to spawn (also the TP world size); None makes every test skip.
    nproc_per_node = None

    def _skip_if_cannot_run(self):
        """Skip the current test if no TP degree is configured or too few devices are available."""
        if self.nproc_per_node is None:
            self.skipTest("nproc_per_node not set")
        available = backend_device_count(torch_device)
        if available < self.nproc_per_node:
            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {available}")

    def _run_distributed(self, impl, *args):
        """Run ``impl(rank, *args)`` on ``nproc_per_node`` spawned processes."""
        init_distributed(tp=self.nproc_per_node)(impl)(*args)

    @require_torch_multi_accelerator
    def test_model_dense_forward_eval(self):
        """Test that TP and non-TP models produce the same outputs in eval mode."""
        self._skip_if_cannot_run()
        self._run_distributed(_test_model_dense_forward_impl, "eval")

    @require_torch_multi_accelerator
    def test_model_dense_forward_train(self):
        """Test that TP and non-TP models produce the same outputs in train mode."""
        self._skip_if_cannot_run()
        self._run_distributed(_test_model_dense_forward_impl, "train")

    @require_torch_multi_accelerator
    def test_model_dense_backward_pass(self):
        """Test that TP and non-TP models produce matching losses and gradients."""
        self._skip_if_cannot_run()
        self._run_distributed(_test_model_dense_backward_pass_impl)

    @require_torch_multi_accelerator
    def test_model_dense_forward_compile_eval(self):
        """Test that TP and non-TP models produce the same outputs with torch.compile in eval mode."""
        self._skip_if_cannot_run()
        self._run_distributed(_test_model_dense_forward_compile_impl, "eval")

    @require_torch_multi_accelerator
    def test_model_dense_forward_compile_train(self):
        """Test that TP and non-TP models produce the same outputs with torch.compile in train mode."""
        self._skip_if_cannot_run()
        self._run_distributed(_test_model_dense_forward_compile_impl, "train")

    @require_huggingface_hub_greater_or_equal("0.31.4")
    @require_torch_multi_accelerator
    def test_model_dense_save(self):
        """Test that checkpoints saved with and without TP contain the same tensors."""
        self._skip_if_cannot_run()

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Spawned ranks have a process group, so each saves under ``tmp_dir/tp``.
            self._run_distributed(_test_model_dense_save_impl, tmp_dir)

            # Called directly in the test process (no process group initialized
            # here), the same helper saves a dense copy under ``tmp_dir/nontp``.
            _test_model_dense_save_impl(0, tmp_dir)

            non_tp_model_path = os.path.join(tmp_dir, "nontp")
            tp_model_path = os.path.join(tmp_dir, "tp")

            weight_files = [name for name in sorted(os.listdir(non_tp_model_path)) if name.endswith(".safetensors")]
            # Without this, an empty directory would make the comparison pass vacuously.
            self.assertTrue(weight_files, f"No .safetensors files found in {non_tp_model_path}")

            for filename in weight_files:
                non_tp_file = os.path.join(non_tp_model_path, filename)
                tp_file = os.path.join(tp_model_path, filename)
                self.assertTrue(os.path.isfile(tp_file), f"TP checkpoint is missing file: {filename}")

                # Context managers guarantee the safetensors handles are closed.
                with (
                    safe_open(non_tp_file, device="cpu", framework="pt") as non_tp_model,
                    safe_open(tp_file, device="cpu", framework="pt") as tp_model,
                ):
                    tp_keys = set(tp_model.keys())
                    for non_tp_key in non_tp_model.keys():
                        self.assertIn(non_tp_key, tp_keys, f"Tensor with key: {non_tp_key} is missing from TP checkpoint")
                        non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
                        tp_tensor = tp_model.get_tensor(non_tp_key)
                        self.assertTrue(
                            torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match"
                        )
                        # Free each pair before loading the next to keep peak memory low.
                        del non_tp_tensor, tp_tensor
| |
|
| |
|
class TestTensorParallel2Proc(TestTensorParallelBase):
    """Run every inherited tensor-parallel test with a TP degree of two."""

    nproc_per_node = 2  # two spawned ranks -> TP world size 2
| |
|
| |
|
class TestTensorParallel4Proc(TestTensorParallelBase):
    """Run every inherited tensor-parallel test with a TP degree of four."""

    nproc_per_node = 4  # four spawned ranks -> TP world size 4
| |
|