|
|
| """Convert a Prisma training checkpoint to HuggingFace format.
|
|
|
| Usage:
|
| python Prisma/convert_checkpoint.py \
|
| --checkpoint circuits/checkpoints/mirrored_300M_mk4_cont/epoch_02.pt \
|
| --output-dir Prisma/ \
|
| --tokenizer facebook/MobileLLM-125M
|
|
|
| This will create:
|
| Prisma/model.safetensors — model weights
|
| Prisma/config.json — model configuration
|
| Prisma/tokenizer.json — tokenizer files
|
| Prisma/tokenizer_config.json
|
| Prisma/special_tokens_map.json
|
| """
|
|
|
| import argparse
|
| import sys
|
| from pathlib import Path
|
|
|
|
|
# Make the directory two levels above this file importable by putting it at
# the front of sys.path, unless it is already listed somewhere on the path.
# NOTE(review): per the usage example this is the repository root — confirm.
_repo_root = Path(__file__).resolve().parent.parent

if sys.path.count(str(_repo_root)) == 0:
    sys.path.insert(0, str(_repo_root))
|
|
|
| import torch
|
| from safetensors.torch import save_file
|
| from transformers import AutoTokenizer
|
|
|
|
|
|
|
# State-dict key suffixes that the converter drops instead of exporting; the
# conversion log reports them as "deterministic buffers".
SKIP_SUFFIXES = tuple(
    f".{buffer_name}"
    for buffer_name in (
        "inv_freq",
        "cos_cached",
        "sin_cached",
        "causal_mask",
        "word_inv_freq",
    )
)
|
|
|
|
|
def convert_checkpoint(
    checkpoint_path: str,
    output_dir: str,
    tokenizer_name: str = "facebook/MobileLLM-125M",
    dtype: str = "float16",
) -> None:
    """Convert a Prisma training checkpoint into a HuggingFace-style directory.

    Writes ``model.safetensors``, the ``PrismaConfig`` files and the tokenizer
    files into ``output_dir`` (created if missing) and prints a progress log.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            contains a ``"config"`` dict and a ``"model"`` state dict
            (``"model_type"`` is optional and only logged).
        output_dir: Directory that receives all output files.
        tokenizer_name: Name or path handed to
            ``AutoTokenizer.from_pretrained``. The tokenizer is saved next to
            the weights and, when ``word_rope_dims > 0``, is also used to
            build the ``word_start_table`` tensor.
        dtype: ``"float16"``, ``"bfloat16"`` or ``"float32"``. Only float32
            tensors are cast to this dtype; tensors already stored in any
            other dtype are written unchanged.

    Raises:
        KeyError: If ``dtype`` is not one of the supported names, or the
            checkpoint/config lacks a required key.
    """
    # Resolve the dtype first so an unsupported name fails before the
    # checkpoint is loaded (still a KeyError, as before).
    target_dtype = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }[dtype]

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # --- Load checkpoint -------------------------------------------------
    print(f"Loading checkpoint: {checkpoint_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only convert checkpoints you trust.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    config_dict = ckpt["config"]
    model_type = ckpt.get("model_type", "mirrored")
    raw_state = ckpt["model"]

    print(f" Model type: {model_type}")
    print(f" Config: {config_dict}")
    print(f" State dict keys: {len(raw_state)}")

    # --- Clean the state dict --------------------------------------------
    cleaned = {}
    skipped_buffers = 0
    for key, tensor in raw_state.items():
        # Drop every "_orig_mod." segment (presumably the prefix added by a
        # torch.compile wrapper — confirm against the training code).
        clean_key = key.replace("_orig_mod.", "")

        # Buffers listed in SKIP_SUFFIXES are not exported.
        if clean_key.endswith(SKIP_SUFFIXES):
            skipped_buffers += 1
            continue

        # Exported weights live under the "transformer." namespace.
        cleaned[f"transformer.{clean_key}"] = tensor

    print(f" Skipped {skipped_buffers} deterministic buffers")

    # --- Tied embeddings -------------------------------------------------
    embed_key = "transformer.embed.weight"
    lm_head_key = "transformer.lm_head.weight"

    # A missing or zero embed_dim/head_dim falls back to hidden_size.
    embed_dim = config_dict.get("embed_dim", 0) or config_dict["hidden_size"]
    head_dim = config_dict.get("head_dim", 0) or config_dict["hidden_size"]
    tie_embeddings = embed_dim == head_dim

    if tie_embeddings and embed_key in cleaned and lm_head_key in cleaned:
        if torch.equal(cleaned[embed_key], cleaned[lm_head_key]):
            # Identical weights: store them once and let the config tie them.
            del cleaned[lm_head_key]
            print(" Removed tied lm_head.weight (same as embed.weight)")
        else:
            tie_embeddings = False
            print(" WARNING: embed and lm_head differ despite matching dims — keeping both")

    # --- Word-start table (only when word RoPE is enabled) ---------------
    # The tokenizer is loaded at most once and reused for saving below.
    tokenizer = None
    word_rope_dims = config_dict.get("word_rope_dims", 0)
    if word_rope_dims > 0:
        print(f" Building word_start_table from tokenizer: {tokenizer_name}")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
        vocab_size = config_dict["vocab_size"]
        # A token starts a word when it begins with 'Ġ', '▁', '<', or a
        # newline / carriage return / tab (presumably BPE and SentencePiece
        # space markers and special tokens — confirm).
        word_start_prefixes = ("Ġ", "▁", "<", "\n", "\r", "\t")
        table = torch.zeros(vocab_size, dtype=torch.bool)
        tokens = tokenizer.convert_ids_to_tokens(list(range(vocab_size)))
        for idx, tok in enumerate(tokens):
            if tok is not None and tok.startswith(word_start_prefixes):
                table[idx] = True
        # Token id 0 is always treated as a word start.
        table[0] = True
        cleaned["word_start_table"] = table
        print(f" Word start table: {table.sum().item()}/{len(table)} tokens marked as word starters")

    # --- Cast and lay out tensors for serialization ----------------------
    # Only float32 tensors are cast; bool and other dtypes pass through.
    # .contiguous() is a no-op for contiguous tensors and avoids save_file
    # rejecting non-contiguous ones (e.g. transposed views).
    cast_needed = target_dtype != torch.float32
    cleaned = {
        name: (
            tensor.to(target_dtype)
            if cast_needed and tensor.dtype == torch.float32
            else tensor
        ).contiguous()
        for name, tensor in cleaned.items()
    }

    total_params = sum(t.numel() for t in cleaned.values() if t.dtype != torch.bool)
    total_bytes = sum(t.numel() * t.element_size() for t in cleaned.values())
    print(f" Total parameters: {total_params:,}")
    print(f" File size: {total_bytes / 1e9:.2f} GB ({dtype})")

    # --- Save weights ----------------------------------------------------
    safetensors_path = output_path / "model.safetensors"
    print(f"\nSaving weights: {safetensors_path}")
    save_file(cleaned, str(safetensors_path))

    # --- Save config -----------------------------------------------------
    # configuration_prisma is imported from this script's own directory.
    # Keep that directory at the front of sys.path without adding a
    # duplicate entry on every call.
    script_dir = str(Path(__file__).resolve().parent)
    if sys.path[:1] != [script_dir]:
        sys.path.insert(0, script_dir)
    from configuration_prisma import PrismaConfig

    hf_config = PrismaConfig(
        vocab_size=config_dict["vocab_size"],
        hidden_size=config_dict["hidden_size"],
        num_heads=config_dict["num_heads"],
        num_kv_heads=config_dict.get("num_kv_heads"),
        num_layers=config_dict["num_layers"],
        n_middle=config_dict.get("n_middle", 1),
        max_seq_len=config_dict.get("max_seq_len", 1024),
        dropout=config_dict.get("dropout", 0.0),
        aux_skip_k=config_dict.get("aux_skip_k", 0),
        aux_skip_weight=config_dict.get("aux_skip_weight", 0.1),
        use_g2lu=config_dict.get("use_g2lu", True),
        word_rope_dims=config_dict.get("word_rope_dims", 0),
        word_rope_base=config_dict.get("word_rope_base", 10.0),
        embed_dim=config_dict.get("embed_dim", 0),
        head_dim=config_dict.get("head_dim", 0),
        tie_word_embeddings=tie_embeddings,
        auto_map={
            "AutoConfig": "configuration_prisma.PrismaConfig",
            "AutoModelForCausalLM": "modeling_prisma.PrismaForCausalLM",
        },
    )
    hf_config.save_pretrained(str(output_path))
    print(f"Saved config: {output_path / 'config.json'}")

    # --- Save tokenizer --------------------------------------------------
    print(f"\nSaving tokenizer from: {tokenizer_name}")
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
    tokenizer.save_pretrained(str(output_path))
    print(f"Saved tokenizer files to: {output_path}")

    # --- Summary ---------------------------------------------------------
    print(f"\n{'='*60}")
    print("Conversion complete!")
    print(f" Output directory: {output_path}")
    print(f" Model size: {total_bytes / 1e9:.2f} GB ({dtype})")
    print(f" Parameters: {total_params:,}")
    print(f" Tied embeddings: {tie_embeddings}")
    print(f" Word RoPE dims: {word_rope_dims}")
    print(f"{'='*60}")
    print("\nUsage:")
    print(" from transformers import AutoModelForCausalLM, AutoTokenizer")
    print(f' model = AutoModelForCausalLM.from_pretrained("{output_path}", trust_remote_code=True)')
    print(f' tokenizer = AutoTokenizer.from_pretrained("{output_path}")')
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: declare the four string options from a
    # table, parse them, and hand them to convert_checkpoint by keyword.
    cli = argparse.ArgumentParser(description="Convert Prisma checkpoint to HuggingFace format")
    option_specs = (
        ("--checkpoint", {"required": True, "help": "Path to .pt checkpoint"}),
        ("--output-dir", {"default": "Prisma/", "help": "Output directory"}),
        ("--tokenizer", {"default": "facebook/MobileLLM-125M", "help": "Tokenizer name"}),
        ("--dtype", {"default": "float16", "choices": ["float16", "bfloat16", "float32"]}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, type=str, **settings)
    options = cli.parse_args()

    convert_checkpoint(
        checkpoint_path=options.checkpoint,
        output_dir=options.output_dir,
        tokenizer_name=options.tokenizer,
        dtype=options.dtype,
    )
|
|
|