model_tools / safetensors_meta_ripper_v1.py
Naphula's picture
Upload 3 files
22115d9 verified
import os
import argparse
import json
import glob
from safetensors import safe_open
from gguf import GGUFReader
from gguf.constants import Keys
from typing import List, Dict, Any
def create_safetensors_index(shards_dir: str, output_dir: str) -> None:
"""Creates the model.safetensors.index.json file by scanning shard files."""
shard_pattern = os.path.join(shards_dir, '*.safetensors')
shard_files = sorted(glob.glob(shard_pattern))
if not shard_files:
print(f"Error: No .safetensors files found in directory: {shards_dir}")
return
print(f"Found {len(shard_files)} shard files to index.")
index_data: Dict[str, Any] = {"metadata": {}, "weight_map": {}}
total_size = 0
for shard_file in shard_files:
shard_basename = os.path.basename(shard_file)
try:
with safe_open(shard_file, framework="pt", device="cpu") as f:
for tensor_name in f.keys():
index_data["weight_map"][tensor_name] = shard_basename
shard_size = os.path.getsize(shard_file)
total_size += shard_size
except Exception as e:
print(f"Warning: Could not process shard {shard_basename}. Error: {e}")
continue
index_data["metadata"]["total_size"] = total_size
index_filepath = os.path.join(output_dir, "model.safetensors.index.json")
try:
with open(index_filepath, 'w', encoding='utf-8') as f:
json.dump(index_data, f, indent=2)
print(f"Successfully created safetensors index file: {index_filepath}")
except Exception as e:
print(f"Error: Failed to write index file. Error: {e}")
def extract_and_save_gguf_configs(reader: GGUFReader, output_dir: str) -> None:
"""Extracts metadata from GGUF and saves config, tokenizer, and generation files."""
config = {}
# --- config.json ---
try:
arch = reader.get_field(Keys.General.ARCHITECTURE).name.lower()
model_type_map = {"llama": "llama", "mistral": "mistral", "gemma": "gemma"}
model_type = model_type_map.get(arch, arch)
config = {
"architectures": [arch.capitalize()],
"model_type": model_type,
"hidden_size": reader.get_int_value(f"{model_type}.embedding_length"),
"intermediate_size": reader.get_int_value(f"{model_type}.feed_forward_length"),
"num_attention_heads": reader.get_int_value(f"{model_type}.attention.head_count"),
"num_hidden_layers": reader.get_int_value(f"{model_type}.block_count"),
"num_key_value_heads": reader.get_int_value(f"{model_type}.attention.head_count_kv"),
"rms_norm_eps": reader.get_float_value(f"{model_type}.attention.layer_norm_rms_epsilon"),
"vocab_size": len(reader.get_field(Keys.Tokenizer.VOCAB)),
"rope_theta": reader.get_float_value(f"{model_type}.rope.freq_base"),
"max_position_embeddings": reader.get_int_value(f"{model_type}.context_length"),
}
with open(os.path.join(output_dir, "config.json"), 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2)
print("Created config.json")
except Exception as e:
print(f"Warning: Could not create config.json. Some values may be missing. Error: {e}")
# --- tokenizer_config.json ---
try:
tokenizer_config = {
"model_max_length": config.get("max_position_embeddings", 4096),
"padding_side": "left",
"tokenizer_class": "LlamaTokenizer",
}
# Add chat template if it exists
try:
chat_template = reader.get_str_value("tokenizer.chat_template")
tokenizer_config["chat_template"] = chat_template
except (KeyError, ValueError):
pass # Field does not exist
with open(os.path.join(output_dir, "tokenizer_config.json"), 'w', encoding='utf-8') as f:
json.dump(tokenizer_config, f, indent=2)
print("Created tokenizer_config.json")
except Exception as e:
print(f"Warning: Could not create tokenizer_config.json. Error: {e}")
# --- tokenizer.json ---
try:
vocab = [item.piece for item in reader.get_field(Keys.Tokenizer.VOCAB)]
merges = reader.get_field(Keys.Tokenizer.MERGES)
tokenizer_data = {
"version": "1.0",
"model": {
"type": "BPE",
"vocab": {token: i for i, token in enumerate(vocab)},
"merges": merges,
},
"added_tokens": [],
}
with open(os.path.join(output_dir, "tokenizer.json"), 'w', encoding='utf-8') as f:
json.dump(tokenizer_data, f, indent=None, separators=(',', ':'))
print("Created tokenizer.json")
except Exception as e:
print(f"Warning: Could not create tokenizer.json. Error: {e}")
# --- special_tokens_map.json ---
try:
special_map = {}
# Use a helper to avoid crashing on missing keys
def add_special_token(key_name, gguf_id_key):
try:
token_id = reader.get_int_value(gguf_id_key)
token_str = vocab[token_id]
special_map[key_name] = token_str
except (KeyError, ValueError, IndexError):
pass
add_special_token("bos_token", "tokenizer.ggml.bos_token_id")
add_special_token("eos_token", "tokenizer.ggml.eos_token_id")
add_special_token("unk_token", "tokenizer.ggml.unknown_token_id")
with open(os.path.join(output_dir, "special_tokens_map.json"), 'w', encoding='utf-8') as f:
json.dump(special_map, f, indent=2)
print("Created special_tokens_map.json")
except Exception as e:
print(f"Warning: Could not create special_tokens_map.json. Error: {e}")
# --- generation_config.json ---
try:
gen_config = {"_from_model_config": True}
try:
gen_config["bos_token_id"] = reader.get_int_value("tokenizer.ggml.bos_token_id")
gen_config["eos_token_id"] = reader.get_int_value("tokenizer.ggml.eos_token_id")
except (KeyError, ValueError):
pass
with open(os.path.join(output_dir, "generation_config.json"), 'w', encoding='utf-8') as f:
json.dump(gen_config, f, indent=2)
print("Created generation_config.json")
except Exception as e:
print(f"Warning: Could not create generation_config.json. Error: {e}")
def main():
parser = argparse.ArgumentParser(
description="Generate safetensors index and config files for a sharded model directory."
)
parser.add_argument(
"--gguf-file",
required=True,
help="Path to the original GGUF file to read metadata from."
)
parser.add_argument(
"--shards-dir",
required=True,
help="Path to the directory containing the sharded .safetensors files."
)
args = parser.parse_args()
if not os.path.isfile(args.gguf_file):
print(f"Error: GGUF file not found at {args.gguf_file}")
return
if not os.path.isdir(args.shards_dir):
print(f"Error: Shards directory not found at {args.shards_dir}")
return
print(f"Loading GGUF metadata from: {args.gguf_file}")
reader = GGUFReader(args.gguf_file, 'r')
# Generate config files from GGUF header and save them to the shards directory
extract_and_save_gguf_configs(reader, args.shards_dir)
# Generate the safetensors index from the actual shard files
create_safetensors_index(args.shards_dir, args.shards_dir)
print("\nMetadata ripping complete.")
if __name__ == "__main__":
main()