```python
# /// script
# dependencies = [
#     "torch",
#     "numpy",
# ]
# ///
"""Generate and save shared weights for consistent comparison."""
import torch
import numpy as np
from pathlib import Path

# Model configuration
NUM_EXPERTS = 128
HIDDEN_SIZE = 1152
INTERMEDIATE_SIZE = 3072
TOP_K = 4

# Input configuration
BATCH_SIZE = 1
SEQ_LEN = 100
DTYPE = "float32"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seeds for reproducibility
WEIGHT_SEED = 999
EXPERT_SEED = 777
INPUT_SEED = 123
GENERAL_SEED = 42


def set_seed(seed: int):
    """Set seeds for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


# Generate shared weights for all implementations
print("Generating shared weights...")

# Router weights
set_seed(WEIGHT_SEED)
router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
torch.nn.init.kaiming_uniform_(router_weight)
router_bias = torch.zeros(NUM_EXPERTS)

# Expert weights - using proper dimensions for gate/up combined projection
set_seed(EXPERT_SEED)
gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)

# Save weights
torch.save(router_weight, 'router_weight.pt')
torch.save(router_bias, 'router_bias.pt')
torch.save(gate_up_proj, 'gate_up_proj.pt')
torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
torch.save(down_proj, 'down_proj.pt')
torch.save(down_proj_bias, 'down_proj_bias.pt')

print("Saved weights:")
print(f"  Router: {tuple(router_weight.shape)}")
print(f"  Gate/Up proj: {tuple(gate_up_proj.shape)}")
print(f"  Down proj: {tuple(down_proj.shape)}")
print(f"  Hidden size: {HIDDEN_SIZE}")
```