Supernova25million / supernova /verify_params.py
algorythmtechnologies's picture
Upload folder using huggingface_hub
8174855 verified
raw
history blame contribute delete
921 Bytes
import argparse
import json
from typing import Any
import torch
from .config import ModelConfig
from .model import SupernovaModel
from .tokenizer import load_gpt2_tokenizer
def main(config_path: str, expected_params: int = 25_000_000) -> None:
    """Build the model from a JSON config and print a parameter-count report.

    The config is loaded with ``ModelConfig.from_json_file``, the tokenizer
    with ``load_gpt2_tokenizer``, and a ``SupernovaModel`` is instantiated
    from the config. A JSON object with the key hyperparameters, the total
    number of parameter elements, and whether that total equals
    ``expected_params`` is written to stdout.

    Args:
        config_path: Path to the JSON model configuration file.
        expected_params: Parameter count the model is expected to hit
            exactly; reported under the ``"exact"`` key.

    Raises:
        AssertionError: If the tokenizer's vocabulary size differs from the
            ``vocab_size`` declared in the config.
    """
    cfg = ModelConfig.from_json_file(config_path)
    tok = load_gpt2_tokenizer()

    # Raise explicitly rather than using an `assert` statement: asserts are
    # stripped when Python runs with -O, which would silently skip this
    # check. AssertionError is kept so existing callers see the same type.
    if tok.vocab_size != cfg.vocab_size:
        raise AssertionError(
            f"tokenizer vocab_size ({tok.vocab_size}) does not match "
            f"config vocab_size ({cfg.vocab_size})"
        )

    model = SupernovaModel(cfg)
    # Sum the element counts of everything yielded by model.parameters().
    total_params = sum(p.numel() for p in model.parameters())

    print(json.dumps({
        "vocab_size": tok.vocab_size,
        "n_positions": cfg.n_positions,
        "d_model": cfg.d_model,
        "n_layers": cfg.n_layers,
        "n_heads": cfg.n_heads,
        "total_params": total_params,
        "exact": total_params == expected_params,
    }, indent=2))
if __name__ == "__main__":
    # Command-line entry point: the only option is the mandatory --config
    # path, which is handed straight to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    cli_args = parser.parse_args()
    main(cli_args.config)