# Revision-Script / re-vision.py
# Delta-Vector: "Upload re-vision.py with huggingface_hub" (commit 8c785f0, verified)
# Requirements: pip install safetensors torch tqdm  (pathlib ships with the standard library)
import json
import os
from pathlib import Path
from safetensors.torch import load_file, save_file, safe_open
from collections import defaultdict
import torch # Needed for tensor manipulation if any dtype/device casting were required (not expected here)
import shutil
from tqdm import tqdm # Optional: for progress bar
# --- Configuration ---
# Three directories drive the script: the base checkpoint, the fine-tuned
# (merged) checkpoint, and the directory the combined checkpoint is written to.
BASE_MODEL_DIR = Path("/home/dgxuser/workspace/Mango/models/Mistral-Small-3.2-24B-Instruct-2506")
TRAINED_MODEL_DIR = Path("/home/dgxuser/workspace/Mango/axolotl/24B-Retrain/merged")
OUTPUT_MODEL_DIR = Path("/home/dgxuser/workspace/docshotgun/models/MS3.2-Venice-SFT-KTO-0.35-beta-re-vision")

# Prefix carried by language-model tensor names in the base model.
BASE_LM_PREFIX = "language_model."
# Prefix carried by the same tensors in the trained model; empty because the
# trained checkpoint is assumed to store them with the base prefix stripped
# (e.g. 'model.layers...').
TRAINED_LM_PREFIX = ""

# --- Safety Check ---
# A non-empty output directory only triggers a warning; the run continues and
# may overwrite files. Raise FileExistsError here instead if that is unwanted.
_output_dir_has_content = OUTPUT_MODEL_DIR.exists() and any(OUTPUT_MODEL_DIR.iterdir())
if _output_dir_has_content:
    print(f"Warning: Output directory {OUTPUT_MODEL_DIR} already exists and is not empty.")

# --- Create Output Directory ---
OUTPUT_MODEL_DIR.mkdir(parents=True, exist_ok=True)
# --- Load Index Files ---
# Each model directory must contain a "*.safetensors.index.json" file; the
# first match returned by glob() is used. A missing index is reported as
# FileNotFoundError rather than the bare StopIteration raised by next().

# Base model index.
try:
    base_index_path = next(BASE_MODEL_DIR.glob("*.safetensors.index.json"))
except StopIteration:
    raise FileNotFoundError(f"Could not find *.safetensors.index.json in {BASE_MODEL_DIR}")
base_index = json.loads(base_index_path.read_text())
print(f"Loaded base model index from: {base_index_path}")

# Trained model index.
try:
    trained_index_path = next(TRAINED_MODEL_DIR.glob("*.safetensors.index.json"))
except StopIteration:
    raise FileNotFoundError(f"Could not find *.safetensors.index.json in {TRAINED_MODEL_DIR}")
trained_index = json.loads(trained_index_path.read_text())
print(f"Loaded trained model index from: {trained_index_path}")
# --- Prepare Trained Tensor Lookup ---
# The index's "weight_map" already maps tensor name -> shard file name, so it
# is used directly as the lookup table. A missing or empty map is fatal.
trained_tensor_to_shard = trained_index.get("weight_map", {})
if not trained_tensor_to_shard:
    raise ValueError("Could not find 'weight_map' in trained model index.")
print(f"Built lookup map for {len(trained_tensor_to_shard)} trained tensors.")

# --- Process Shards ---
# Same validation for the base model's weight map.
base_weight_map = base_index.get("weight_map", {})
if not base_weight_map:
    raise ValueError("Could not find 'weight_map' in base model index.")

# Invert the base weight map: shard file name -> list of tensor names stored
# in that shard, so each shard can be processed in a single pass.
base_shards_content = defaultdict(list)
for name_in_base, owning_shard in base_weight_map.items():
    base_shards_content[owning_shard].append(name_in_base)
print(f"Processing {len(base_shards_content)} shards from the base model...")
# --- Merge Shards ---
# For every shard of the base model: load it, swap in the matching tensors
# from the trained model, and write the result to OUTPUT_MODEL_DIR under the
# same shard file name. Tensors without a counterpart in the trained model are
# kept from the base model unchanged.
total_replaced_tensors = 0  # Running total across all shards, reported at the end.
for shard_file, tensors_in_shard in tqdm(base_shards_content.items(), desc="Merging Shards"):
    base_shard_path = BASE_MODEL_DIR / shard_file
    output_shard_path = OUTPUT_MODEL_DIR / shard_file

    # Load the current base shard on CPU; no GPU memory is needed for the merge.
    current_shard_tensors = load_file(base_shard_path, device="cpu")

    # Work out which base tensors have a replacement in the trained model.
    tensors_to_replace = {}  # {base_tensor_name: trained_tensor_name}
    for base_tensor_name in tensors_in_shard:
        if base_tensor_name.startswith(BASE_LM_PREFIX):
            # e.g., language_model.model.layers.0... -> model.layers.0...
            potential_trained_name = TRAINED_LM_PREFIX + base_tensor_name[len(BASE_LM_PREFIX):]
        elif base_tensor_name == "lm_head.weight":
            # LM head stored outside the prefixed block; adjust the name here
            # if your model calls it something else.
            potential_trained_name = "lm_head.weight"
        else:
            continue  # Not a language-model tensor: keep the base tensor.
        # Only names present in the trained index are scheduled, so the
        # lookups into trained_tensor_to_shard below cannot fail.
        if potential_trained_name in trained_tensor_to_shard:
            tensors_to_replace[base_tensor_name] = potential_trained_name

    # Group the needed trained tensors by the trained shard that holds them.
    needed_trained_shards = defaultdict(list)  # {trained_shard_file: [trained_tensor_names]}
    for trained_name in tensors_to_replace.values():
        needed_trained_shards[trained_tensor_to_shard[trained_name]].append(trained_name)

    # Read only the required tensors from each trained shard. safe_open()
    # fetches tensors individually, so the rest of the shard is never
    # materialised in memory.
    loaded_trained_tensors = {}
    for trained_shard_file, trained_tensor_names in needed_trained_shards.items():
        trained_shard_path = TRAINED_MODEL_DIR / trained_shard_file
        try:
            with safe_open(trained_shard_path, framework="pt", device="cpu") as trained_shard:
                available_names = set(trained_shard.keys())
                for name in trained_tensor_names:
                    if name in available_names:
                        loaded_trained_tensors[name] = trained_shard.get_tensor(name)
                    else:
                        print(f" Warning: Expected tensor '{name}' not found within loaded trained shard '{trained_shard_file}'.")
        except FileNotFoundError:
            print(f" Error: Could not find required trained shard file: {trained_shard_path}. Cannot perform replacements for tensors in this shard.")
            # Drop every replacement that depended on the missing shard. The
            # names are collected into a list first so the dict is never
            # mutated while it is being iterated.
            names_in_missing_shard = set(trained_tensor_names)
            base_names_to_remove = [
                b_name
                for b_name, t_name in tensors_to_replace.items()
                if t_name in names_in_missing_shard
            ]
            for b_name in base_names_to_remove:
                del tensors_to_replace[b_name]
                print(f" Skipping replacement for base tensor: {b_name}")

    # Perform the replacements in the loaded base shard dictionary.
    replacement_count = 0
    for base_name, trained_name in tensors_to_replace.items():
        trained_tensor = loaded_trained_tensors.get(trained_name)
        if trained_tensor is None:
            continue  # Already reported above (tensor missing from its shard).
        base_shape = current_shard_tensors[base_name].shape
        if base_shape != trained_tensor.shape:
            # A differently shaped tensor would corrupt the checkpoint; keep the base one.
            print(f" Warning: Shape mismatch for {base_name}! Base: {base_shape}, Trained: {trained_tensor.shape}. Skipping replacement.")
            continue
        current_shard_tensors[base_name] = trained_tensor
        replacement_count += 1
    total_replaced_tensors += replacement_count

    # Ensure the directory for the shard exists if shards are nested (unlikely but possible).
    output_shard_path.parent.mkdir(parents=True, exist_ok=True)
    # Write the "format" metadata entry: Hugging Face loaders inspect it when
    # opening a safetensors shard, and a shard saved without it can be rejected.
    save_file(current_shard_tensors, output_shard_path, metadata={"format": "pt"})

    # Release this shard's tensors before loading the next one.
    del current_shard_tensors
    del loaded_trained_tensors
print("Finished processing shards.")
print(f"Replaced {total_replaced_tensors} tensors in total.")
if total_replaced_tensors == 0:
    # Nothing matched: the output holds the base weights only, which usually
    # means the prefix configuration does not fit the two checkpoints.
    print("Warning: No tensors were replaced. Check BASE_LM_PREFIX / TRAINED_LM_PREFIX against the tensor names in both index files.")
# --- Copy Non-Tensor Files ---
# Everything in the base model directory that is not a weight shard is carried
# over so the output directory is a complete checkpoint. Failures are
# collected and reported at the end instead of aborting the run.
print("Copying non-tensor files (index, config, tokenizer, etc.)...")
copied_files = []
skipped_files = []
for entry in BASE_MODEL_DIR.iterdir():
    destination = OUTPUT_MODEL_DIR / entry.name
    if entry.is_file():
        # Files whose name contains ".safetensors" (shards and the index) or
        # ".md" are not copied by this loop.
        if ".safetensors" in entry.name or ".md" in entry.name:
            continue
        try:
            shutil.copy2(entry, destination)  # copy2 preserves metadata
        except Exception as e:
            skipped_files.append(f"{entry.name} (Error: {e})")
        else:
            copied_files.append(entry.name)
    elif entry.is_dir():
        # Subdirectories are copied wholesale; an existing copy in the output
        # directory is removed first so the new one replaces it.
        if destination.exists():
            shutil.rmtree(destination)
        try:
            shutil.copytree(entry, destination)
        except Exception as e:
            skipped_files.append(f"{entry.name}/ (Error: {e})")
        else:
            copied_files.append(f"{entry.name}/")

# The base index is copied explicitly because the loop above skips it; shard
# file names are unchanged in the output, so the index is reused as-is.
index_file_name = base_index_path.name
try:
    shutil.copy2(base_index_path, OUTPUT_MODEL_DIR / index_file_name)
except Exception as e:
    skipped_files.append(f"{base_index_path.name} (Error: {e})")
else:
    copied_files.append(index_file_name)

print(f"Copied: {', '.join(copied_files)}")
if skipped_files:
    print(f"Skipped/Errors: {', '.join(skipped_files)}")
print(f"\nSuccessfully created merged model in: {OUTPUT_MODEL_DIR}")