|
import torch |
|
import os |
|
import json |
|
from safetensors.torch import load_file, save_file |
|
from safetensors import safe_open |
|
from collections import OrderedDict |
|
from tqdm import tqdm |
|
import glob |
|
|
|
|
|
|
|
|
|
# Directory produced by the shard-conversion step; the merge below looks for
# files named fp8_converted_model-*-of-*.safetensors inside it.
CONVERTED_SHARDS_DIR = "F:/Models/SkyReels-V2-T2V-14B-540P/converted_fp8_shards"

# The merged model is written to the parent of the shards directory.
FINAL_OUTPUT_MODEL_NAME = "SkyReels-V2-T2V-14B-540P-fp8_e5m2.safetensors"

# normpath() strips a trailing separator first: without it, dirname() of
# ".../converted_fp8_shards/" would be the shards directory itself and the
# merged model would land inside it instead of next to it.
FINAL_OUTPUT_MODEL_PATH = os.path.join(
    os.path.dirname(os.path.normpath(CONVERTED_SHARDS_DIR)),
    FINAL_OUTPUT_MODEL_NAME,
)

print("--- SCRIPT START (Merge Converted Shards) ---")
print(f"Converted shards directory: {CONVERTED_SHARDS_DIR}")
print(f"Final output model path: {FINAL_OUTPUT_MODEL_PATH}")
|
|
|
def _find_shard_files(shards_dir):
    """Return the converted shard paths inside *shards_dir*, sorted by name.

    The directory is passed through ``glob.escape`` so characters such as
    ``[``, ``]``, ``*`` or ``?`` in the directory name are matched literally
    instead of being interpreted as glob wildcards.
    """
    pattern = os.path.join(
        glob.escape(shards_dir), "fp8_converted_model-*-of-*.safetensors"
    )
    return sorted(glob.glob(pattern))


def merge_converted_shards(shards_dir=None, output_path=None):
    """Merge every converted FP8 shard into a single safetensors file.

    Shards matching ``fp8_converted_model-*-of-*.safetensors`` in
    *shards_dir* are loaded onto the CPU in sorted filename order. They are
    combined into one ordered state dict and written to *output_path* with
    ``safetensors.torch.save_file``. Every tensor is held in memory at once
    before saving.

    Args:
        shards_dir: Directory holding the converted shards. Defaults to the
            module-level ``CONVERTED_SHARDS_DIR``.
        output_path: Destination file for the merged model. Defaults to the
            module-level ``FINAL_OUTPUT_MODEL_PATH``.

    Returns:
        None. Failures are reported on stdout rather than raised. These are a
        missing directory, no shards found, a shard that fails to load, an
        empty merge, and a failed save.
    """
    if shards_dir is None:
        shards_dir = CONVERTED_SHARDS_DIR
    if output_path is None:
        output_path = FINAL_OUTPUT_MODEL_PATH

    # isdir rather than exists: a regular file at this path cannot hold shards.
    if not os.path.isdir(shards_dir):
        print(f"Error: Directory with converted shards not found: {shards_dir}")
        return

    shard_files = _find_shard_files(shards_dir)
    if not shard_files:
        print(f"Error: No converted shard files found in {shards_dir}")
        return

    print(f"Found {len(shard_files)} converted shards to merge.")

    merged_state_dict = OrderedDict()
    for shard_path in tqdm(shard_files, desc="Merging shards"):
        # tqdm.write keeps per-shard messages from corrupting the progress bar.
        tqdm.write(f"Loading tensors from: {shard_path}")
        try:
            current_shard_state_dict = load_file(shard_path, device="cpu")
        except Exception as e:
            tqdm.write(f"Error loading shard {shard_path}: {e}")
            return

        # dict.update() silently keeps the later shard's tensor for a repeated
        # name. Report this so a duplicated or stray shard is not merged
        # unnoticed.
        overlapping = merged_state_dict.keys() & current_shard_state_dict.keys()
        if overlapping:
            tqdm.write(
                f"  Warning: {len(overlapping)} tensor name(s) in "
                f"{os.path.basename(shard_path)} were already loaded from an "
                f"earlier shard and will be overwritten "
                f"(e.g. {sorted(overlapping)[0]})"
            )

        merged_state_dict.update(current_shard_state_dict)
        tqdm.write(
            f" Added {len(current_shard_state_dict)} tensors from {os.path.basename(shard_path)}"
        )

    if not merged_state_dict:
        print("No tensors were loaded from shards. Final model file will not be created.")
        return

    print(f"\nMerge complete. Total tensors in merged model: {len(merged_state_dict)}")
    print(f"Saving merged model to {output_path}...")

    output_dir = os.path.dirname(output_path)
    try:
        # os.makedirs("") raises, so only create a directory when one is named.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        save_file(merged_state_dict, output_path)
    except Exception as e:
        print(f"Error saving the final merged model: {e}")
        return

    print(f"Successfully saved final merged model to {output_path}")
|
|
|
if __name__ == "__main__":
    # Script entry point: bracket the run with start/end banners and only
    # attempt the merge when the configured shards directory is present.
    print("--- __main__ block start ---")
    shards_dir_present = os.path.exists(CONVERTED_SHARDS_DIR)
    if shards_dir_present:
        merge_converted_shards()
    else:
        print(f"Error: Converted shards directory not found: {CONVERTED_SHARDS_DIR}")
    print("--- __main__ block end (Merge Converted Shards) ---")