import os

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights


def run_distributed_shared_expert_test(rank, world_size):
    """Spawned per-rank worker: build shared expert weights and verify they are registered."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12356"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    model = MegaBlocksMoeMLPWithSharedExpert()

    hidden_size = 128
    shared_expert_hidden_size = 192
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def simple_init(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=torch.device(device),
        dtype=torch.float32,
        init_method=simple_init,
    )

    model.set_shared_expert_weights(
        up_proj_weight=shared_up_proj_weight,
        down_proj_weight=shared_down_proj_weight,
        up_proj_bias=shared_up_proj_bias,
        down_proj_bias=shared_down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu,
    )

    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
    assert model.shared_expert_weighted_sum is True, f"Weighted sum not set correctly on rank {rank}"

    print(f"Rank {rank}: Shared expert setup test passed!")

    dist.destroy_process_group()


def run_distributed_shared_expert_weighted_sum_test(rank, world_size):
    """Spawned per-rank worker: verify the weighted_sum=False / ReLU configuration is stored."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12357"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    model = MegaBlocksMoeMLPWithSharedExpert()

    hidden_size = 64
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def simple_init(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=96,
        device=torch.device(device),
        dtype=torch.float32,
        init_method=simple_init,
    )

    model.set_shared_expert_weights(
        up_proj_weight=shared_up_proj_weight,
        down_proj_weight=shared_down_proj_weight,
        weighted_sum=False,
        activation_fn=torch.nn.functional.relu,
    )

    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
    assert model.shared_expert_weighted_sum is False, f"Weighted sum not set correctly on rank {rank}"
    assert model.shared_activation_fn == torch.nn.functional.relu, f"Activation function not set correctly on rank {rank}"

    print(f"Rank {rank}: Weighted sum setup test passed!")

    dist.destroy_process_group()


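# Hedged sketch (not called by the tests above): if shared expert weights were expected to be
# replicated across ranks, agreement with rank 0 could be checked with a collective like this.
# It assumes an initialized gloo process group, as in the runners above; the helper name is
# illustrative and not part of the megablocks API.
def _assert_weight_matches_rank0(weight, rank):
    local = weight.detach().cpu()
    reference = local.clone()
    # Rank 0's copy overwrites `reference` on every rank; gloo handles CPU tensors.
    dist.broadcast(reference, src=0)
    assert torch.allclose(local, reference), f"Shared expert weight differs from rank 0 on rank {rank}"

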
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
def test_shared_expert_distributed_functionality(world_size):
    if world_size == 1:
        # Single-process path: exercise the setup directly without spawning workers.
        model = MegaBlocksMoeMLPWithSharedExpert()

        hidden_size = 128
        shared_expert_hidden_size = 192
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def simple_init(tensor):
            torch.nn.init.xavier_uniform_(tensor)

        shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
            hidden_size=hidden_size,
            shared_expert_hidden_size=shared_expert_hidden_size,
            device=torch.device(device),
            dtype=torch.float32,
            init_method=simple_init,
        )

        model.set_shared_expert_weights(
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            up_proj_bias=shared_up_proj_bias,
            down_proj_bias=shared_down_proj_bias,
            weighted_sum=True,
            activation_fn=torch.nn.functional.gelu,
        )

        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
        assert model.shared_expert_weighted_sum is True, "Weighted sum not set correctly"

        print("Single process shared expert setup test passed!")
    else:
        # Multi-process path: each spawned rank runs the same setup and assertions.
        mp.spawn(run_distributed_shared_expert_test, args=(world_size,), nprocs=world_size, join=True)
        print("Multi-process shared expert test completed successfully!")


@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
def test_shared_expert_distributed_weighted_sum(world_size):
    if world_size == 1:
        # Single-process path: check that weighted_sum=False and the ReLU activation stick.
        model = MegaBlocksMoeMLPWithSharedExpert()

        hidden_size = 64
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def simple_init(tensor):
            torch.nn.init.xavier_uniform_(tensor)

        shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
            hidden_size=hidden_size,
            shared_expert_hidden_size=96,
            device=torch.device(device),
            dtype=torch.float32,
            init_method=simple_init,
        )

        model.set_shared_expert_weights(
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            weighted_sum=False,
            activation_fn=torch.nn.functional.relu,
        )

        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
        assert model.shared_expert_weighted_sum is False, "Weighted sum not set correctly"
        assert model.shared_activation_fn == torch.nn.functional.relu, "Activation function not set correctly"

        print("Single process weighted sum setup test passed!")
    else:
        # Multi-process path: each spawned rank runs the same setup and assertions.
        mp.spawn(run_distributed_shared_expert_weighted_sum_test, args=(world_size,), nprocs=world_size, join=True)
        print("Multi-process shared expert weighted sum test completed successfully!")


def test_shared_expert_single_process():
    """A freshly constructed model should have no shared expert weights registered yet."""
    model = MegaBlocksMoeMLPWithSharedExpert()

    assert model.shared_up_proj_weight is None
    assert model.shared_down_proj_weight is None
    assert hasattr(model, 'set_shared_expert_weights')

    print("Single process shared expert basic test passed!")


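# Hedged extra check (a sketch, not part of the original suite): it assumes only that
# create_shared_expert_weights honors the device and dtype arguments it is given, and makes
# no assumption about the layout of the returned tensors. Biases may legitimately be None.
def test_shared_expert_weight_metadata():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def simple_init(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    weights = create_shared_expert_weights(
        hidden_size=64,
        shared_expert_hidden_size=96,
        device=device,
        dtype=torch.float32,
        init_method=simple_init,
    )

    for weight in weights:
        if weight is None:
            continue
        assert weight.device.type == device.type, "Weight created on an unexpected device"
        assert weight.dtype == torch.float32, "Weight created with an unexpected dtype"

    print("Shared expert weight metadata test passed!")

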
if __name__ == "__main__":
    test_shared_expert_single_process()
    print("Single process test passed!")

    # The parametrized tests take world_size explicitly when run outside pytest.
    test_shared_expert_distributed_functionality(world_size=2)
    print("Distributed functionality test passed!")

    test_shared_expert_distributed_weighted_sum(world_size=2)
    print("Distributed weighted sum test passed!")