# merge_fp8_shards.py
import torch
import os
import json
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from collections import OrderedDict
from tqdm import tqdm
import glob # For finding shard files
# --- Configuration ---
# Should match OUTPUT_SHARD_DIR from the previous (conversion) script.
# CONVERTED_SHARDS_DIR = "F:/Models/SkyReels-V2-DF-14B-540P/converted_fp8_shards"  # DF path
CONVERTED_SHARDS_DIR = "F:/Models/SkyReels-V2-T2V-14B-540P/converted_fp8_shards"  # T2V path

# The final single output file; it is saved in the parent of the shards directory.
FINAL_OUTPUT_MODEL_NAME = "SkyReels-V2-T2V-14B-540P-fp8_e5m2.safetensors"
FINAL_OUTPUT_MODEL_PATH = os.path.join(os.path.dirname(CONVERTED_SHARDS_DIR), FINAL_OUTPUT_MODEL_NAME)
# An index file is only needed to know the *intended order* of the tensors,
# or to map tensor names to the new shard files if the merge logic required
# that. For a simple merge we can just load every tensor from every converted
# shard and save them in whatever order they arrive. If a specific order were
# critical, the original index.json from the FP32 model would be needed to
# guide the loading order.
# ORIGINAL_FP32_INDEX_JSON = "F:/Models/SkyReels-V2-DF-14B-540P/model.safetensors.index.json"
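
# A minimal sketch of order-aware loading, under the assumption that
# ORIGINAL_FP32_INDEX_JSON above points at a standard HF-style index file
# ({"weight_map": {tensor_name: shard_file, ...}}). This helper is
# hypothetical and is not called by the merge below; uncomment the path
# above and wire it in only if ordering turns out to matter.
def load_tensor_order_from_index(index_path):
    """Return tensor names in the order they appear in the index's weight_map."""
    with open(index_path, "r") as f:
        index = json.load(f)
    return list(index["weight_map"].keys())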
print(f"--- SCRIPT START (Merge Converted Shards) ---")
print(f"Converted shards directory: {CONVERTED_SHARDS_DIR}")
print(f"Final output model path: {FINAL_OUTPUT_MODEL_PATH}")
def merge_converted_shards():
    if not os.path.exists(CONVERTED_SHARDS_DIR):
        print(f"Error: Directory with converted shards not found: {CONVERTED_SHARDS_DIR}")
        return

    # Find all .safetensors files in the converted shards directory, sorted so
    # they are processed in a consistent order (e.g., 00001, 00002, ...).
    shard_files = sorted(glob.glob(os.path.join(CONVERTED_SHARDS_DIR, "fp8_converted_model-*-of-*.safetensors")))
    # Or a more generic pattern if your naming was different:
    # shard_files = sorted(glob.glob(os.path.join(CONVERTED_SHARDS_DIR, "*.safetensors")))

    if not shard_files:
        print(f"Error: No converted shard files found in {CONVERTED_SHARDS_DIR}")
        return

    print(f"Found {len(shard_files)} converted shards to merge.")

    merged_state_dict = OrderedDict()
    for shard_path in tqdm(shard_files, desc="Merging shards"):
        print(f"Loading tensors from: {shard_path}")
        try:
            # Load all tensors from the current converted shard. load_file is
            # fine here (no safe_open with per-tensor get_tensor needed), as
            # these shards are smaller.
            current_shard_state_dict = load_file(shard_path, device="cpu")
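            # Hypothetical guard, not in the original script: dict.update()
            # would silently overwrite a tensor that appears in two shards,
            # so warn if any names collide before merging this shard in.
            overlap = merged_state_dict.keys() & current_shard_state_dict.keys()
            if overlap:
                print(f"  Warning: {len(overlap)} duplicate tensor name(s), e.g. {next(iter(overlap))}")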
            merged_state_dict.update(current_shard_state_dict)
            print(f"  Added {len(current_shard_state_dict)} tensors from {os.path.basename(shard_path)}")
        except Exception as e:
            print(f"Error loading shard {shard_path}: {e}")
            # Decide if you want to stop or continue; here we stop, since the
            # merge would be incomplete if a shard can't be loaded.
            return

    if not merged_state_dict:
        print("No tensors were loaded from shards. Final model file will not be created.")
        return

    print(f"\nMerge complete. Total tensors in merged model: {len(merged_state_dict)}")
    print(f"Saving merged model to {FINAL_OUTPUT_MODEL_PATH}...")
    try:
        os.makedirs(os.path.dirname(FINAL_OUTPUT_MODEL_PATH), exist_ok=True)
        save_file(merged_state_dict, FINAL_OUTPUT_MODEL_PATH)
        print(f"Successfully saved final merged model to {FINAL_OUTPUT_MODEL_PATH}")
    except Exception as e:
        print(f"Error saving the final merged model: {e}")

if __name__ == "__main__":
    print("--- __main__ block start ---")
    if not os.path.exists(CONVERTED_SHARDS_DIR):
        print(f"Error: Converted shards directory not found: {CONVERTED_SHARDS_DIR}")
    else:
        merge_converted_shards()
    print("--- __main__ block end (Merge Converted Shards) ---")