| | import os |
| | import json |
| | import re |
| | from pathlib import Path |
| | from safetensors import safe_open |
| | from safetensors.torch import save_file |
| | import torch |
| |
|
def main(src_dir="../GLM-4.7-Flash", dst_path="model.safetensors", num_experts_to_keep=2):
    """Trim a MoE safetensors checkpoint down to its first N routed experts.

    Reads every ``*.safetensors`` shard in *src_dir*, drops per-expert tensors
    whose expert index is >= *num_experts_to_keep*, slices the router gate
    weight and its score-correction bias to match, and writes one consolidated
    safetensors file plus a single-shard index JSON.

    Args:
        src_dir: Directory holding the source shards and ``config.json``.
        dst_path: Output path for the merged, trimmed safetensors file.
        num_experts_to_keep: Number of routed experts retained per MoE layer.
    """
    src_dir = Path(src_dir)
    dst_path = Path(dst_path)

    safetensor_files = sorted(src_dir.glob("*.safetensors"))
    print(f"Found {len(safetensor_files)} safetensors files")

    # Keys follow the HF GLM/DeepSeek-style MoE naming scheme.
    expert_pattern = re.compile(r"model\.layers\.(\d+)\.mlp\.experts\.(\d+)\..+")
    gate_pattern = re.compile(r"model\.layers\.(\d+)\.mlp\.gate\.weight")
    bias_pattern = re.compile(r"model\.layers\.(\d+)\.mlp\.gate\.e_score_correction_bias")

    new_tensors = {}

    for sf_path in safetensor_files:
        print(f"Processing {sf_path.name}...")

        with safe_open(sf_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)

                # Per-expert weights: keep only the first N experts, drop the rest.
                expert_match = expert_pattern.search(key)
                if expert_match:
                    expert_idx = int(expert_match.group(2))
                    if expert_idx >= num_experts_to_keep:
                        print(f"  Skipping {key} (expert {expert_idx} >= {num_experts_to_keep})")
                        continue
                    new_tensors[key] = tensor
                    continue

                # Router gate weight: slicing assumes shape (num_experts, hidden)
                # with experts on dim 0 — TODO confirm against the model config.
                if gate_pattern.search(key):
                    original_shape = tensor.shape
                    # .clone() detaches the slice from the full-size backing
                    # storage: safetensors' save_file rejects tensors that alias
                    # a larger buffer, and cloning also frees the dropped rows.
                    new_tensor = tensor[:num_experts_to_keep, :].clone()
                    print(f"  Resizing {key}: {original_shape} -> {new_tensor.shape}")
                    new_tensors[key] = new_tensor
                    continue

                # Per-expert score-correction bias: 1-D over experts; same slice + clone.
                if bias_pattern.search(key):
                    original_shape = tensor.shape
                    new_tensor = tensor[:num_experts_to_keep].clone()
                    print(f"  Resizing {key}: {original_shape} -> {new_tensor.shape}")
                    new_tensors[key] = new_tensor
                    continue

                # Everything else (attention, norms, embeddings, ...) passes through.
                new_tensors[key] = tensor

    print(f"\nTotal tensors to save: {len(new_tensors)}")
    print(f"Saving to {dst_path}...")

    save_file(new_tensors, str(dst_path))
    print("Done!")

    # Keep config.json consistent with the trimmed weights. The original code
    # loaded the config and never used it — presumably this update/save step
    # was the missing intent. NOTE(review): key name "n_routed_experts" is the
    # GLM/DeepSeek MoE convention; guarded so this is a no-op if absent.
    config_src = src_dir / "config.json"
    if config_src.exists():
        with open(config_src, "r") as f:
            config = json.load(f)
        if "n_routed_experts" in config:
            config["n_routed_experts"] = num_experts_to_keep
            with open("config.json", "w") as f:
                json.dump(config, f, indent=2)
            print("Saved updated config.json (n_routed_experts adjusted)")

    # Minimal single-shard index so HF-style loaders can locate every tensor.
    index_data = {
        "metadata": {
            "total_size": sum(t.numel() * t.element_size() for t in new_tensors.values())
        },
        "weight_map": {key: str(dst_path) for key in new_tensors},
    }

    index_dst = Path("model.safetensors.index.json")
    with open(index_dst, "w") as f:
        json.dump(index_data, f, indent=2)
    print(f"Saved safetensors index to {index_dst}")
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|