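"""Merge the TP-sharded grok-2 safetensors checkpoint into per-layer files.

Reads the pytorch_model-*.safetensors shards from `model_dir`, concatenates
tensor-parallel (TP=8) shards back into full tensors, groups the merged
weights by transformer layer, writes one safetensors file per layer (the
final file also carries lm_head, embed_tokens, and norm), and emits a
matching pytorch_model.bin.index.json.
"""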
import os
import glob
from safetensors import safe_open
from safetensors.torch import save_file
import torch
import json

# Input (sharded) and output (per-layer) model directories
model_dir = "xai-org/grok-2"
output_dir = "huihui-ai/grok-2"
os.makedirs(output_dir, exist_ok=True)

# Collect all safetensors files
print("Collecting safetensors files...", flush=True)
safetensors_files = glob.glob(os.path.join(model_dir, "pytorch_model-*.safetensors"))
if not safetensors_files:
    raise FileNotFoundError(f"No pytorch_model-*.safetensors files found in directory {model_dir}")

# Load all files into cache and build key-to-file mapping
file_cache = {}  # file path -> {key: tensor}
key_to_files = {}  # key -> [file paths]
total_size = 0
print("Loading safetensors files...", flush=True)
for file_path in safetensors_files:
    try:
        with safe_open(file_path, framework="pt", device="cpu") as f:
            file_cache[file_path] = {key: f.get_tensor(key) for key in f.keys()}
            for key, tensor in file_cache[file_path].items():
                if key not in key_to_files:
                    key_to_files[key] = []
                key_to_files[key].append(file_path)
                total_size += tensor.element_size() * tensor.nelement()
    except Exception as e:
        print(f"Warning: Failed to load {file_path}: {e}")
print(f"Found {len(key_to_files)} unique keys, total size {total_size / 1e9:.2f} GB", flush=True)

# Merge TP shards back into full tensors
tp_count = 8  # number of tensor-parallel shards per weight
merged_state_dict = {}
print("Merging TP shards...", flush=True)
for key, file_paths in key_to_files.items():
    if len(file_paths) > 1:  # TP shards
        print(f"Merging {key} shards...", flush=True)
        # Sort shards by the TP rank encoded in the filename (e.g. ...-TP-3.safetensors);
        # use the basename so a "TP-" in a directory component cannot misfire
        sorted_paths = sorted(
            file_paths,
            key=lambda x: int(os.path.basename(x).split("TP-")[1].split(".")[0]) if "TP-" in os.path.basename(x) else -1,
        )
        tensors = []
        for file_path in sorted_paths[:tp_count]:
            if file_path in file_cache and key in file_cache[file_path]:
                tensors.append(file_cache[file_path][key])
            else:
                print(f"Warning: Key {key} missing in {file_path}")
        if len(tensors) == tp_count:
            try:
                # Choose the concatenation dimension: w2 is row-parallel (sharded
                # along its input dim, dim 1); w1/w3 and everything else are
                # treated as column-parallel (sharded along the output dim, dim 0)
                dim = 1 if "w2.weight" in key else 0
                merged_tensor = torch.cat(tensors, dim=dim)
                # Verify merged expert shapes against the expected full sizes
                if "block_sparse_moe.experts" in key:
                    expected_shape = None
                    if "w1.weight" in key or "w3.weight" in key:
                        expected_shape = (16384, 8192)  # (moe_intermediate_size, hidden_size)
                    elif "w2.weight" in key:
                        expected_shape = (8192, 16384)  # (hidden_size, moe_intermediate_size)
                    if expected_shape is not None and merged_tensor.shape != expected_shape:
                        print(f"Warning: {key} merged shape {tuple(merged_tensor.shape)} does not match expected {expected_shape}")
                merged_state_dict[key] = merged_tensor
            except Exception as e:
                print(f"Failed to merge {key}: {e}")
                merged_state_dict[key] = tensors[0] if tensors else None
        else:
            print(f"Warning: Found {len(tensors)} shards for {key}, expected {tp_count}, using first tensor")
            merged_state_dict[key] = tensors[0] if tensors else None
    else:
        print(f"Processing {key} ...", flush=True)
        # Non-TP shard
        file_path = file_paths[0]
        if file_path in file_cache and key in file_cache[file_path]:
            merged_state_dict[key] = file_cache[file_path][key]
        else:
            print(f"Warning: Key {key} missing in {file_path}")
            merged_state_dict[key] = None
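
# The per-file shard cache is no longer needed once merging is complete;
# dropping it releases the duplicated TP shard tensors before the save phase.
# (Optional addition; the rest of the script only reads merged_state_dict.)
del file_cache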

# Group by layer
layer_dicts = {}
special_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]
last_layer_idx = None
print("Grouping weights by layer...", flush=True)
for key in list(merged_state_dict.keys()):
    if merged_state_dict[key] is None:
        continue
    if key in special_weights:
        continue
    if "model.layers." in key:
        layer_num = int(key.split(".")[2])
        if layer_num not in layer_dicts:
            layer_dicts[layer_num] = {}
        layer_dicts[layer_num][key] = merged_state_dict.pop(key)
        last_layer_idx = max(last_layer_idx or 0, layer_num)

# Save weights for each layer
print("Saving weight files...", flush=True)
for layer_num in sorted(layer_dicts.keys()):
    if layer_num == last_layer_idx:
        continue  # the final layer's file is written below, together with the special weights
    output_file = os.path.join(output_dir, f"pytorch_model-{layer_num + 1:05d}.safetensors")
    save_file(layer_dicts[layer_num], output_file)
    print(f"Saved layer {layer_num} to {output_file}")

# Save final layer (including special weights)
if last_layer_idx is None:
    raise RuntimeError("No model.layers.* weights were found; nothing to save")
last_layer_file = os.path.join(output_dir, f"pytorch_model-{last_layer_idx + 1:05d}.safetensors")
last_layer_dict = layer_dicts.get(last_layer_idx, {})
for key in special_weights:
    if key in merged_state_dict and merged_state_dict[key] is not None:
        last_layer_dict[key] = merged_state_dict[key]
save_file(last_layer_dict, last_layer_file)
print(f"Saved final layer (including lm_head, embed_tokens, norm) to {last_layer_file}", flush=True)

# Generate new index
new_index = {"metadata": {"total_size": total_size}, "weight_map": {}}
for layer_num in sorted(layer_dicts.keys()):
    file_name = f"pytorch_model-{layer_num + 1:05d}.safetensors"
    for key in layer_dicts[layer_num]:
        new_index["weight_map"][key] = file_name
for key in special_weights:
    if key in merged_state_dict and merged_state_dict[key] is not None:
        new_index["weight_map"][key] = f"pytorch_model-{last_layer_idx + 1:05d}.safetensors"

with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w") as f:
    json.dump(new_index, f, indent=2)
print(f"Saved new index file to {os.path.join(output_dir, 'pytorch_model.bin.index.json')}", flush=True)