| |
| """ |
| Convert GeneMamba checkpoint to HuggingFace compatible format. |
| |
| This script converts an existing GeneMamba checkpoint (from the original training) |
| to be compatible with the HuggingFace Transformers library. |
| |
| Usage: |
| python scripts/convert_checkpoint.py \ |
| --input_checkpoint /path/to/original/checkpoint \ |
| --output_dir /path/to/output |
| """ |
|
|
| import os |
| import json |
| import shutil |
| import argparse |
| from pathlib import Path |
|
|
|
|
def convert_checkpoint(input_checkpoint_path, output_dir):
    """
    Convert a GeneMamba checkpoint to HuggingFace-compatible format.

    Reads ``config.json`` and ``model.safetensors`` (both required) plus
    optional tokenizer files from the input directory, rewrites the config
    using the field names HuggingFace Transformers expects, and copies the
    weights and tokenizer assets into ``output_dir``.

    Args:
        input_checkpoint_path: Path to the original checkpoint directory.
        output_dir: Output directory for the converted checkpoint
            (created if it does not exist).

    Raises:
        FileNotFoundError: If the input directory, ``config.json``, or
            ``model.safetensors`` is missing.
    """
    input_path = Path(input_checkpoint_path)
    output_path = Path(output_dir)

    if not input_path.exists():
        raise FileNotFoundError(f"Input checkpoint not found: {input_path}")

    config_file = input_path / "config.json"
    model_file = input_path / "model.safetensors"
    tokenizer_file = input_path / "tokenizer.json"
    tokenizer_config_file = input_path / "tokenizer_config.json"

    if not config_file.exists():
        raise FileNotFoundError(f"config.json not found in {input_path}")
    if not model_file.exists():
        raise FileNotFoundError(f"model.safetensors not found in {input_path}")

    print(f"[Step 1] Reading original checkpoint from: {input_path}")

    output_path.mkdir(parents=True, exist_ok=True)

    with open(config_file, 'r') as f:
        original_config = json.load(f)

    print("[Step 2] Converting config.json...")

    # Map the training-time config onto the HF naming scheme; fall back to
    # the GeneMamba defaults when a key is absent from the original config.
    hf_config = {
        "model_type": "genemamba",
        "architectures": ["GeneMambaModel"],
        "vocab_size": original_config.get("vocab_size", 25426),
        "max_position_embeddings": original_config.get("max_position_embeddings", 2048),
        # Original checkpoints name these fields d_model / mamba_layer.
        "hidden_size": original_config.get("d_model", 512),
        "num_hidden_layers": original_config.get("mamba_layer", 24),
        "intermediate_size": 2048,
        "hidden_dropout_prob": 0.1,
        "initializer_range": 0.02,
        "mamba_mode": original_config.get("mamba_mode", "gate"),
        "embedding_pooling": original_config.get("embedding_pooling", "mean"),
        "num_labels": 2,
        "pad_token_id": 1,
        "eos_token_id": 2,
        "bos_token_id": 0,
        "use_cache": True,
        "torch_dtype": original_config.get("torch_dtype", "float32"),
        "transformers_version": "4.40.2",
    }

    new_config_path = output_path / "config.json"
    with open(new_config_path, 'w') as f:
        json.dump(hf_config, f, indent=2)
    print(f"✓ Saved config.json to {new_config_path}")

    print("[Step 3] Copying model weights...")
    output_model_file = output_path / "model.safetensors"
    shutil.copy2(model_file, output_model_file)
    print(f"✓ Copied model.safetensors ({os.path.getsize(model_file) / 1e9:.2f} GB)")

    print("[Step 4] Copying tokenizer files...")
    if tokenizer_file.exists():
        shutil.copy2(tokenizer_file, output_path / "tokenizer.json")
        print("✓ Copied tokenizer.json")
    else:
        print("⚠ tokenizer.json not found (optional)")

    if tokenizer_config_file.exists():
        shutil.copy2(tokenizer_config_file, output_path / "tokenizer_config.json")
        print("✓ Copied tokenizer_config.json")
    else:
        # BUG FIX: this fallback config used to be written unconditionally,
        # clobbering a tokenizer_config.json that had just been copied from
        # the input checkpoint. Only synthesize one when the original is
        # genuinely missing.
        print("⚠ tokenizer_config.json not found (will be created)")
        basic_tokenizer_config = {
            "add_bos_token": True,
            "add_eos_token": False,
            "add_prefix_space": False,
            "bos_token": "<|begin_of_sequence|>",
            "eos_token": "<|end_of_sequence|>",
            "model_max_length": 2048,
            "pad_token": "<|pad|>",
            "tokenizer_class": "PreTrainedTokenizerFast",
            "unk_token": "<|unk|>",
        }
        with open(output_path / "tokenizer_config.json", 'w') as f:
            json.dump(basic_tokenizer_config, f, indent=2)
        print("✓ Created tokenizer_config.json")

    special_tokens_map = input_path / "special_tokens_map.json"
    if special_tokens_map.exists():
        shutil.copy2(special_tokens_map, output_path / "special_tokens_map.json")
        print("✓ Copied special_tokens_map.json")

    print("\n" + "=" * 70)
    print("✓ CONVERSION COMPLETE!")
    print("=" * 70)
    print("\nModel info:")
    print("  Architecture: GeneMamba")
    print(f"  Model Type: {hf_config['model_type']}")
    print(f"  Hidden Size: {hf_config['hidden_size']}")
    print(f"  Num Layers: {hf_config['num_hidden_layers']}")
    print(f"  Vocab Size: {hf_config['vocab_size']}")
    print(f"\nConverted checkpoint saved to: {output_path}")
    print("\nNext step - Upload to HuggingFace Hub:")
    print("  python scripts/push_to_hub.py \\")
    print(f"    --model_path {output_path} \\")
    print("    --repo_name <your_username>/<repo_name>")
|
|
|
|
def main():
    """CLI entry point: parse arguments and run the checkpoint conversion.

    Exits with status 1 (via SystemExit) if the conversion raises.
    """
    parser = argparse.ArgumentParser(
        description="Convert GeneMamba checkpoint to HuggingFace format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert 24L-512D model
  python scripts/convert_checkpoint.py \\
    --input_checkpoint /project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/10m/checkpoint-31250 \\
    --output_dir ./converted_checkpoints/GeneMamba2_24l_512d

  # Convert 48L-768D model
  python scripts/convert_checkpoint.py \\
    --input_checkpoint /project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_48l_768d/1/4m/checkpoint-31250 \\
    --output_dir ./converted_checkpoints/GeneMamba2_48l_768d
""")

    parser.add_argument(
        "--input_checkpoint",
        required=True,
        help="Path to original GeneMamba checkpoint directory"
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Output directory for HuggingFace compatible checkpoint"
    )

    args = parser.parse_args()

    try:
        convert_checkpoint(args.input_checkpoint, args.output_dir)
    except Exception as e:
        # Top-level error boundary: report and exit non-zero.
        # BUG FIX: was `exit(1)` — the site-module helper, not guaranteed to
        # exist in all run modes (e.g. `python -S`); SystemExit is always safe.
        print(f"\n✗ ERROR: {e}")
        raise SystemExit(1)
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|