# GeneMamba / scripts / convert_checkpoint.py
# Hugging Face upload metadata: mineself2016 — "Upload GeneMamba model"
# (commit 54cd552, verified)
#!/usr/bin/env python3
"""
Convert GeneMamba checkpoint to HuggingFace compatible format.
This script converts an existing GeneMamba checkpoint (from the original training)
to be compatible with the HuggingFace Transformers library.
Usage:
python scripts/convert_checkpoint.py \
--input_checkpoint /path/to/original/checkpoint \
--output_dir /path/to/output
"""
import os
import json
import shutil
import argparse
from pathlib import Path
def convert_checkpoint(input_checkpoint_path, output_dir):
    """
    Convert a GeneMamba checkpoint to HuggingFace format.

    Reads config.json from the original training checkpoint, rewrites it
    with the keys the HuggingFace Transformers library expects
    (``model_type``, ``architectures``, standard dimension names), then
    copies the model weights and any tokenizer files into ``output_dir``.

    Args:
        input_checkpoint_path: Path to the original checkpoint directory;
            must contain config.json and model.safetensors.
        output_dir: Output directory for the converted checkpoint
            (created if it does not exist).

    Raises:
        FileNotFoundError: If the input directory, config.json, or
            model.safetensors is missing.
    """
    input_path = Path(input_checkpoint_path)
    output_path = Path(output_dir)

    # Verify input checkpoint exists
    if not input_path.exists():
        raise FileNotFoundError(f"Input checkpoint not found: {input_path}")

    # Check for required files
    config_file = input_path / "config.json"
    model_file = input_path / "model.safetensors"
    tokenizer_file = input_path / "tokenizer.json"
    tokenizer_config_file = input_path / "tokenizer_config.json"

    if not config_file.exists():
        raise FileNotFoundError(f"config.json not found in {input_path}")
    if not model_file.exists():
        raise FileNotFoundError(f"model.safetensors not found in {input_path}")

    print(f"[Step 1] Reading original checkpoint from: {input_path}")

    # Create output directory
    output_path.mkdir(parents=True, exist_ok=True)

    # Read original config; explicit encoding avoids platform-dependent
    # decoding (the original relied on the locale default).
    with open(config_file, 'r', encoding='utf-8') as f:
        original_config = json.load(f)

    print("[Step 2] Converting config.json...")
    # Create new HuggingFace-compatible config
    hf_config = {
        # Model type (CRITICAL for HuggingFace to recognize the model)
        "model_type": "genemamba",
        # Architecture info
        "architectures": ["GeneMambaModel"],
        # Vocabulary and sequence
        "vocab_size": original_config.get("vocab_size", 25426),
        "max_position_embeddings": original_config.get("max_position_embeddings", 2048),
        # Model dimensions — the original training config names these
        # d_model / mamba_layer; HF expects hidden_size / num_hidden_layers.
        "hidden_size": original_config.get("d_model", 512),
        "num_hidden_layers": original_config.get("mamba_layer", 24),
        "intermediate_size": 2048,
        # Regularization
        "hidden_dropout_prob": 0.1,
        "initializer_range": 0.02,
        # Mamba-specific
        "mamba_mode": original_config.get("mamba_mode", "gate"),
        "embedding_pooling": original_config.get("embedding_pooling", "mean"),
        # Task-specific
        "num_labels": 2,
        "pad_token_id": 1,
        "eos_token_id": 2,
        "bos_token_id": 0,
        "use_cache": True,
        # Metadata
        "torch_dtype": original_config.get("torch_dtype", "float32"),
        "transformers_version": "4.40.2",
    }

    # Save new config
    new_config_path = output_path / "config.json"
    with open(new_config_path, 'w', encoding='utf-8') as f:
        json.dump(hf_config, f, indent=2)
    print(f"✓ Saved config.json to {new_config_path}")

    # Copy model weights
    print("[Step 3] Copying model weights...")
    output_model_file = output_path / "model.safetensors"
    shutil.copy2(model_file, output_model_file)
    # BUGFIX: the checkmark here was mojibake ("✓", a UTF-8/Latin-1
    # round trip) in the original; emit the intended character. Size via
    # pathlib for consistency with the rest of the function.
    print(f"✓ Copied model.safetensors ({model_file.stat().st_size / 1e9:.2f} GB)")

    # Copy tokenizer files if they exist
    print("[Step 4] Copying tokenizer files...")
    if tokenizer_file.exists():
        shutil.copy2(tokenizer_file, output_path / "tokenizer.json")
        print("✓ Copied tokenizer.json")
    else:
        # BUGFIX: warning glyph was mojibake ("âš ") in the original.
        print("⚠ tokenizer.json not found (optional)")

    if tokenizer_config_file.exists():
        shutil.copy2(tokenizer_config_file, output_path / "tokenizer_config.json")
        print("✓ Copied tokenizer_config.json")
    else:
        print("⚠ tokenizer_config.json not found (will be created)")
        # Create a basic tokenizer config so the converted checkpoint
        # still loads with AutoTokenizer.
        basic_tokenizer_config = {
            "add_bos_token": True,
            "add_eos_token": False,
            "add_prefix_space": False,
            "bos_token": "<|begin_of_sequence|>",
            "eos_token": "<|end_of_sequence|>",
            "model_max_length": 2048,
            "pad_token": "<|pad|>",
            "tokenizer_class": "PreTrainedTokenizerFast",
            "unk_token": "<|unk|>",
        }
        with open(output_path / "tokenizer_config.json", 'w', encoding='utf-8') as f:
            json.dump(basic_tokenizer_config, f, indent=2)
        print("✓ Created tokenizer_config.json")

    # Copy special tokens map (optional)
    special_tokens_map = input_path / "special_tokens_map.json"
    if special_tokens_map.exists():
        shutil.copy2(special_tokens_map, output_path / "special_tokens_map.json")
        print("✓ Copied special_tokens_map.json")

    print("\n" + "=" * 70)
    print("✓ CONVERSION COMPLETE!")
    print("=" * 70)
    print("\nModel info:")
    print("  Architecture: GeneMamba")
    print(f"  Model Type: {hf_config['model_type']}")
    print(f"  Hidden Size: {hf_config['hidden_size']}")
    print(f"  Num Layers: {hf_config['num_hidden_layers']}")
    print(f"  Vocab Size: {hf_config['vocab_size']}")
    print(f"\nConverted checkpoint saved to: {output_path}")
    print("\nNext step - Upload to HuggingFace Hub:")
    print("  python scripts/push_to_hub.py \\")
    print(f"    --model_path {output_path} \\")
    print("    --repo_name <your_username>/<repo_name>")
def main():
    """Parse command-line arguments and run the checkpoint conversion.

    Exits with status 1 (via SystemExit) if the conversion raises.
    """
    parser = argparse.ArgumentParser(
        description="Convert GeneMamba checkpoint to HuggingFace format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert 24L-512D model
  python scripts/convert_checkpoint.py \\
    --input_checkpoint /project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/10m/checkpoint-31250 \\
    --output_dir ./converted_checkpoints/GeneMamba2_24l_512d

  # Convert 48L-768D model
  python scripts/convert_checkpoint.py \\
    --input_checkpoint /project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_48l_768d/1/4m/checkpoint-31250 \\
    --output_dir ./converted_checkpoints/GeneMamba2_48l_768d
""")
    parser.add_argument(
        "--input_checkpoint",
        required=True,
        help="Path to original GeneMamba checkpoint directory"
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Output directory for HuggingFace compatible checkpoint"
    )
    args = parser.parse_args()

    try:
        convert_checkpoint(args.input_checkpoint, args.output_dir)
    except Exception as e:
        # Top-level boundary: report the failure and exit non-zero so
        # shell scripts can detect it.
        print(f"\n✗ ERROR: {e}")
        # BUGFIX: exit() is a site-module helper and is not guaranteed to
        # exist in all interpreters; SystemExit is the portable form, and
        # "from e" preserves the cause chain for debugging.
        raise SystemExit(1) from e


if __name__ == "__main__":
    main()