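"""Reconstruct Hugging Face metadata files for a sharded safetensors model.

Reads the header of a GGUF file and writes config.json, tokenizer_config.json,
tokenizer.json, special_tokens_map.json, generation_config.json, and
model.safetensors.index.json next to the shards.

Example invocation (script and file names are illustrative):
    python gguf_to_hf_metadata.py --gguf-file model.gguf --shards-dir ./shards
"""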
import os
import argparse
import json
import glob
from typing import Any, Dict

from safetensors import safe_open
from gguf import GGUFReader
from gguf.constants import Keys
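

# The gguf Python package exposes metadata as raw ReaderField objects rather
# than plain values. This small helper (our addition, not part of the gguf API)
# decodes a field via ReaderField.contents(), which recent gguf releases
# provide for scalars, strings, and arrays alike.
def gguf_value(reader: GGUFReader, key: str) -> Any:
    """Return the decoded value of a GGUF metadata field; raise KeyError if absent."""
    field = reader.get_field(key)
    if field is None:
        raise KeyError(key)
    return field.contents()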


def create_safetensors_index(shards_dir: str, output_dir: str) -> None:
    """Creates the model.safetensors.index.json file by scanning shard files."""
    shard_pattern = os.path.join(shards_dir, '*.safetensors')
    shard_files = sorted(glob.glob(shard_pattern))
    if not shard_files:
        print(f"Error: No .safetensors files found in directory: {shards_dir}")
        return
    print(f"Found {len(shard_files)} shard files to index.")
    index_data: Dict[str, Any] = {"metadata": {}, "weight_map": {}}
    total_size = 0
    for shard_file in shard_files:
        shard_basename = os.path.basename(shard_file)
        try:
            with safe_open(shard_file, framework="pt", device="cpu") as f:
                for tensor_name in f.keys():
                    index_data["weight_map"][tensor_name] = shard_basename
            shard_size = os.path.getsize(shard_file)
            total_size += shard_size
        except Exception as e:
            print(f"Warning: Could not process shard {shard_basename}. Error: {e}")
            continue
    index_data["metadata"]["total_size"] = total_size
    index_filepath = os.path.join(output_dir, "model.safetensors.index.json")
    try:
        with open(index_filepath, 'w', encoding='utf-8') as f:
            json.dump(index_data, f, indent=2)
        print(f"Successfully created safetensors index file: {index_filepath}")
    except Exception as e:
        print(f"Error: Failed to write index file. Error: {e}")


def extract_and_save_gguf_configs(reader: GGUFReader, output_dir: str) -> None:
    """Extracts metadata from GGUF and saves config, tokenizer, and generation files."""
    config: Dict[str, Any] = {}
    # Load the vocabulary once up front; config.json, tokenizer.json, and
    # special_tokens_map.json below all reuse it.
    try:
        vocab = list(gguf_value(reader, Keys.Tokenizer.LIST))  # "tokenizer.ggml.tokens"
    except (KeyError, ValueError):
        vocab = []
        print("Warning: No tokenizer vocabulary found in GGUF metadata.")
    # --- config.json ---
    try:
        arch = gguf_value(reader, Keys.General.ARCHITECTURE).lower()
        model_type_map = {"llama": "llama", "mistral": "mistral", "gemma": "gemma"}
        model_type = model_type_map.get(arch, arch)
        # Per-model GGUF keys are namespaced by the architecture name,
        # e.g. "llama.embedding_length".
        config = {
            # Hugging Face expects a class name such as "LlamaForCausalLM";
            # this is derived heuristically from the architecture string.
            "architectures": [f"{arch.capitalize()}ForCausalLM"],
            "model_type": model_type,
            "hidden_size": gguf_value(reader, f"{arch}.embedding_length"),
            "intermediate_size": gguf_value(reader, f"{arch}.feed_forward_length"),
            "num_attention_heads": gguf_value(reader, f"{arch}.attention.head_count"),
            "num_hidden_layers": gguf_value(reader, f"{arch}.block_count"),
            "num_key_value_heads": gguf_value(reader, f"{arch}.attention.head_count_kv"),
            "rms_norm_eps": gguf_value(reader, f"{arch}.attention.layer_norm_rms_epsilon"),
            "vocab_size": len(vocab),
            "rope_theta": gguf_value(reader, f"{arch}.rope.freq_base"),
            "max_position_embeddings": gguf_value(reader, f"{arch}.context_length"),
        }
        with open(os.path.join(output_dir, "config.json"), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)
        print("Created config.json")
    except Exception as e:
        print(f"Warning: Could not create config.json. Some values may be missing. Error: {e}")
    # --- tokenizer_config.json ---
    try:
        tokenizer_config = {
            "model_max_length": config.get("max_position_embeddings", 4096),
            "padding_side": "left",
            # Assumes a llama-family tokenizer; adjust for other architectures.
            "tokenizer_class": "LlamaTokenizer",
        }
        # Add the chat template if the GGUF file carries one.
        try:
            tokenizer_config["chat_template"] = gguf_value(reader, "tokenizer.chat_template")
        except (KeyError, ValueError):
            pass  # Field does not exist
        with open(os.path.join(output_dir, "tokenizer_config.json"), 'w', encoding='utf-8') as f:
            json.dump(tokenizer_config, f, indent=2)
        print("Created tokenizer_config.json")
    except Exception as e:
        print(f"Warning: Could not create tokenizer_config.json. Error: {e}")
    # --- tokenizer.json ---
    try:
        merges = list(gguf_value(reader, Keys.Tokenizer.MERGES))
        tokenizer_data = {
            "version": "1.0",
            "model": {
                "type": "BPE",
                "vocab": {token: i for i, token in enumerate(vocab)},
                "merges": merges,
            },
            "added_tokens": [],
        }
        # tokenizer.json files are large, so write compact JSON without indentation.
        with open(os.path.join(output_dir, "tokenizer.json"), 'w', encoding='utf-8') as f:
            json.dump(tokenizer_data, f, indent=None, separators=(',', ':'))
        print("Created tokenizer.json")
    except Exception as e:
        print(f"Warning: Could not create tokenizer.json. Error: {e}")
    # --- special_tokens_map.json ---
    try:
        special_map = {}

        # Resolve a token id to its string form without crashing on missing keys.
        def add_special_token(key_name, gguf_id_key):
            try:
                token_id = gguf_value(reader, gguf_id_key)
                special_map[key_name] = vocab[token_id]
            except (KeyError, ValueError, IndexError):
                pass

        add_special_token("bos_token", "tokenizer.ggml.bos_token_id")
        add_special_token("eos_token", "tokenizer.ggml.eos_token_id")
        add_special_token("unk_token", "tokenizer.ggml.unknown_token_id")
        with open(os.path.join(output_dir, "special_tokens_map.json"), 'w', encoding='utf-8') as f:
            json.dump(special_map, f, indent=2)
        print("Created special_tokens_map.json")
    except Exception as e:
        print(f"Warning: Could not create special_tokens_map.json. Error: {e}")
    # --- generation_config.json ---
    try:
        gen_config = {"_from_model_config": True}
        try:
            gen_config["bos_token_id"] = gguf_value(reader, "tokenizer.ggml.bos_token_id")
            gen_config["eos_token_id"] = gguf_value(reader, "tokenizer.ggml.eos_token_id")
        except (KeyError, ValueError):
            pass
        with open(os.path.join(output_dir, "generation_config.json"), 'w', encoding='utf-8') as f:
            json.dump(gen_config, f, indent=2)
        print("Created generation_config.json")
    except Exception as e:
        print(f"Warning: Could not create generation_config.json. Error: {e}")


def main():
    parser = argparse.ArgumentParser(
        description="Generate safetensors index and config files for a sharded model directory."
    )
    parser.add_argument(
        "--gguf-file",
        required=True,
        help="Path to the original GGUF file to read metadata from."
    )
    parser.add_argument(
        "--shards-dir",
        required=True,
        help="Path to the directory containing the sharded .safetensors files."
    )
    args = parser.parse_args()
    if not os.path.isfile(args.gguf_file):
        print(f"Error: GGUF file not found at {args.gguf_file}")
        return
    if not os.path.isdir(args.shards_dir):
        print(f"Error: Shards directory not found at {args.shards_dir}")
        return
    print(f"Loading GGUF metadata from: {args.gguf_file}")
    reader = GGUFReader(args.gguf_file, 'r')
    # Generate config files from the GGUF header and save them to the shards directory.
    extract_and_save_gguf_configs(reader, args.shards_dir)
    # Generate the safetensors index from the actual shard files.
    create_safetensors_index(args.shards_dir, args.shards_dir)
    print("\nMetadata extraction complete.")


if __name__ == "__main__":
    main()
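
# Expected files in --shards-dir after a successful run (alongside the shards):
#   config.json
#   generation_config.json
#   model.safetensors.index.json
#   special_tokens_map.json
#   tokenizer.json
#   tokenizer_config.json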