Delta-Vector committed on
Commit
8c785f0
·
verified ·
1 Parent(s): 921350d

Upload re-vision.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. re-vision.py +209 -0
re-vision.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# pip install pathlib safetensors tqdm

import json
import os
import shutil
from collections import defaultdict
from pathlib import Path

import torch  # Needed for tensor manipulation if any dtype/device casting were required (not expected here)
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm  # Optional: for progress bar

# --- Configuration ---
# Base (multimodal) checkpoint whose non-LM parts are kept as-is.
BASE_MODEL_DIR = Path("/home/dgxuser/workspace/Mango/models/Mistral-Small-3.2-24B-Instruct-2506")
# Fine-tuned (merged) checkpoint supplying replacement language-model weights.
TRAINED_MODEL_DIR = Path("/home/dgxuser/workspace/Mango/axolotl/24B-Retrain/merged")
# Destination for the stitched-together model.
OUTPUT_MODEL_DIR = Path("/home/dgxuser/workspace/docshotgun/models/MS3.2-Venice-SFT-KTO-0.35-beta-re-vision")

# Prefix the base checkpoint puts on its language-model tensor names.
BASE_LM_PREFIX = "language_model."
# Prefix used by the trained checkpoint for the same tensors
# (assumed stripped, i.e. trained keys look like 'model.layers...').
TRAINED_LM_PREFIX = ""  # If trained keys are 'model.layers...', this is effectively empty relative to the base
# --- Safety Check ---
# Warn (but do not abort) when the destination already holds files.
output_is_populated = OUTPUT_MODEL_DIR.exists() and any(OUTPUT_MODEL_DIR.iterdir())
if output_is_populated:
    print(f"Warning: Output directory {OUTPUT_MODEL_DIR} already exists and is not empty.")
    # Decide if you want to overwrite or stop
    # input("Press Enter to continue and potentially overwrite files, or Ctrl+C to abort.")
    pass  # Or raise an error: raise FileExistsError(f"Output directory {OUTPUT_MODEL_DIR} is not empty.")

# --- Create Output Directory ---
OUTPUT_MODEL_DIR.mkdir(parents=True, exist_ok=True)
# --- Load Index Files ---
def _load_safetensors_index(model_dir):
    """Locate and parse the single ``*.safetensors.index.json`` in *model_dir*.

    Returns a ``(index_path, parsed_index)`` tuple.
    Raises ``FileNotFoundError`` when no index file is present — same
    exception type the original duplicated try/except blocks raised.
    """
    try:
        index_path = next(model_dir.glob("*.safetensors.index.json"))
    except StopIteration:
        raise FileNotFoundError(f"Could not find *.safetensors.index.json in {model_dir}")
    with open(index_path, 'r') as f:
        return index_path, json.load(f)


base_index_path, base_index = _load_safetensors_index(BASE_MODEL_DIR)
print(f"Loaded base model index from: {base_index_path}")

trained_index_path, trained_index = _load_safetensors_index(TRAINED_MODEL_DIR)
print(f"Loaded trained model index from: {trained_index_path}")


# --- Prepare Trained Tensor Lookup ---
# Map each trained tensor name to the shard file that stores it.
trained_tensor_to_shard = trained_index.get("weight_map", {})
if not trained_tensor_to_shard:
    raise ValueError("Could not find 'weight_map' in trained model index.")
print(f"Built lookup map for {len(trained_tensor_to_shard)} trained tensors.")
# --- Process Shards ---
base_weight_map = base_index.get("weight_map", {})
if not base_weight_map:
    raise ValueError("Could not find 'weight_map' in base model index.")

# Invert the weight map: shard file -> names of the tensors it stores.
base_shards_content = defaultdict(list)
for name, shard in base_weight_map.items():
    base_shards_content[shard].append(name)
print(f"Processing {len(base_shards_content)} shards from the base model...")

# Walk every base shard, swap in the matching trained language-model tensors,
# and write the modified shard to the output directory.
for shard_file, tensors_in_shard in tqdm(base_shards_content.items(), desc="Merging Shards"):
    base_shard_path = BASE_MODEL_DIR / shard_file
    output_shard_path = OUTPUT_MODEL_DIR / shard_file

    # Load the whole base shard on CPU to avoid GPU memory pressure.
    current_shard_tensors = load_file(base_shard_path, device="cpu")

    # Identify which tensors in this shard need replacement.
    tensors_to_replace = {}  # {base_tensor_name: trained_tensor_name}
    for base_tensor_name in tensors_in_shard:
        if base_tensor_name.startswith(BASE_LM_PREFIX):
            # Derive the corresponding trained-model name,
            # e.g. language_model.model.layers.0... -> model.layers.0...
            potential_trained_name = base_tensor_name[len(BASE_LM_PREFIX):]
            if potential_trained_name in trained_tensor_to_shard:
                tensors_to_replace[base_tensor_name] = potential_trained_name
            # else: base LM tensor with no trained counterpart — keep the base copy.

        # Many models keep `lm_head.weight` outside the `language_model` block,
        # so check for it explicitly. Adjust if your LM head has a different name.
        elif base_tensor_name == "lm_head.weight":
            if "lm_head.weight" in trained_tensor_to_shard:
                tensors_to_replace[base_tensor_name] = "lm_head.weight"
            # else: trained model has no lm_head — keep the base tensor.

    # Group the needed trained tensors by the shard that stores them.
    needed_trained_shards = defaultdict(list)  # {trained_shard_file: [trained_tensor_names]}
    # BUGFIX: the original deleted entries from tensors_to_replace while
    # iterating tensors_to_replace.items(), which raises
    # "RuntimeError: dictionary changed size during iteration" if the
    # missing-key path ever fires. Iterate a snapshot instead, and use a
    # membership test rather than a try/except around the lookup.
    for base_name, trained_name in list(tensors_to_replace.items()):
        if trained_name in trained_tensor_to_shard:
            needed_trained_shards[trained_tensor_to_shard[trained_name]].append(trained_name)
        else:
            print(f" Warning: Tensor '{trained_name}' (derived from '{base_name}') listed for replacement but not found in trained model's weight map. Skipping.")
            # Remove from replacement list if lookup fails
            del tensors_to_replace[base_name]

    # Load each needed trained shard once and collect the required tensors.
    loaded_trained_tensors = {}
    for trained_shard_file, trained_tensor_names in needed_trained_shards.items():
        trained_shard_path = TRAINED_MODEL_DIR / trained_shard_file
        try:
            # load_file reads the whole shard; per-tensor lazy loading would
            # need safe_open, but whole-shard loading is acceptable here.
            shard_data = load_file(trained_shard_path, device="cpu")
            for name in trained_tensor_names:
                if name in shard_data:
                    loaded_trained_tensors[name] = shard_data[name]
                else:
                    print(f" Warning: Expected tensor '{name}' not found within loaded trained shard '{trained_shard_file}'.")
            del shard_data  # Free memory
        except FileNotFoundError:
            print(f" Error: Could not find required trained shard file: {trained_shard_path}. Cannot perform replacements for tensors in this shard.")
            # Remove base tensors that relied on this missing shard from replacement list.
            base_names_to_remove = [b_name for b_name, t_name in tensors_to_replace.items() if t_name in trained_tensor_names]
            for b_name in base_names_to_remove:
                del tensors_to_replace[b_name]
                print(f" Skipping replacement for base tensor: {b_name}")

    # Perform the replacements in the loaded base shard dictionary.
    replacement_count = 0
    for base_name, trained_name in tensors_to_replace.items():
        if trained_name in loaded_trained_tensors:
            # Sanity check shapes before overwriting.
            if current_shard_tensors[base_name].shape != loaded_trained_tensors[trained_name].shape:
                print(f" Warning: Shape mismatch for {base_name}! Base: {current_shard_tensors[base_name].shape}, Trained: {loaded_trained_tensors[trained_name].shape}. Skipping replacement.")
                continue
            current_shard_tensors[base_name] = loaded_trained_tensors[trained_name]
            replacement_count += 1

    # Save the modified shard to the output directory; keep the shard's
    # relative path intact in case shards are nested (unlikely but possible).
    output_shard_path.parent.mkdir(parents=True, exist_ok=True)
    save_file(current_shard_tensors, output_shard_path)

    # Release this shard's tensors before moving to the next one.
    del current_shard_tensors
    del loaded_trained_tensors

print("Finished processing shards.")
# --- Copy Non-Tensor Files ---
print("Copying non-tensor files (index, config, tokenizer, etc.)...")
copied_files = []
skipped_files = []

for entry in BASE_MODEL_DIR.iterdir():
    destination = OUTPUT_MODEL_DIR / entry.name
    # Shard/index files (".safetensors") and markdown docs are skipped here;
    # the index is copied explicitly below.
    if entry.is_file() and ".safetensors" not in entry.name and ".md" not in entry.name:
        try:
            shutil.copy2(entry, destination)  # copy2 preserves metadata
            copied_files.append(entry.name)
        except Exception as e:
            skipped_files.append(f"{entry.name} (Error: {e})")
    elif entry.is_dir():
        # Subdirectories (e.g. tokenizer assets) are copied wholesale,
        # replacing any stale copy at the destination.
        if destination.exists():
            shutil.rmtree(destination)  # Overwrite directory if exists
        try:
            shutil.copytree(entry, destination)
            copied_files.append(f"{entry.name}/")
        except Exception as e:
            skipped_files.append(f"{entry.name}/ (Error: {e})")

# The index file matched the ".safetensors" filter above, so copy it explicitly.
try:
    shutil.copy2(base_index_path, OUTPUT_MODEL_DIR / base_index_path.name)
    copied_files.append(base_index_path.name)
except Exception as e:
    skipped_files.append(f"{base_index_path.name} (Error: {e})")


print(f"Copied: {', '.join(copied_files)}")
if skipped_files:
    print(f"Skipped/Errors: {', '.join(skipped_files)}")


print(f"\nSuccessfully created merged model in: {OUTPUT_MODEL_DIR}")