SimpleGPT / README.md
metadata
language:
  - en
thumbnail: https://huggingface.co/whut-zhangwx/SimpleGPT
tags:
  - gpt
license: mit
datasets:
  - tinyshakespeare
base_model: whut-zhangwx/SimpleGPT

Introduction

This repository provides a pre-trained checkpoint for SimpleGPT.

It was trained on tinyshakespeare with the following hyperparameters:

n_layer: 12,
n_head: 12,
embed_dim: 768,
time_step: 256,
bias: False,
vocab_size: 65,
dropout: 0.0,
iter_num: 50000
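
To make the numbers above easier to reason about, here is a minimal sketch that holds them in a dataclass and derives two quantities from them. `SketchConfig` and `estimate_parameters` are hypothetical names, not part of SimpleGPT. The parameter estimate assumes a GPT-2-style block (fused q/k/v projection, 4x MLP expansion, learned position embeddings, output head tied to the token embedding), and it reads `time_step` as the context length. None of that is confirmed by this model card.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchConfig:
    """Hypothetical container; field names mirror the hyperparameter list."""
    n_layer: int = 12
    n_head: int = 12
    embed_dim: int = 768
    time_step: int = 256  # ASSUMPTION: interpreted as context length in tokens
    bias: bool = False
    vocab_size: int = 65
    dropout: float = 0.0

    @property
    def head_dim(self) -> int:
        # Each attention head works on an equal slice of the embedding.
        if self.embed_dim % self.n_head != 0:
            raise ValueError("embed_dim must be divisible by n_head")
        return self.embed_dim // self.n_head


def estimate_parameters(cfg: SketchConfig) -> int:
    """Rough weight count for an ASSUMED GPT-2-style layout with bias=False."""
    d = cfg.embed_dim
    attention = 4 * d * d   # q/k/v projections (3*d*d) plus output projection (d*d)
    mlp = 8 * d * d         # ASSUMPTION: 4x expansion, d -> 4d -> d
    layer_norms = 2 * d     # two LayerNorm scale vectors per block, no bias
    blocks = cfg.n_layer * (attention + mlp + layer_norms)
    # Token table plus ASSUMED learned position table.
    embeddings = cfg.vocab_size * d + cfg.time_step * d
    final_norm = d
    # ASSUMPTION: the output head shares weights with the token embedding.
    return blocks + embeddings + final_norm


cfg = SketchConfig()
print("head_dim:", cfg.head_dim)
print("estimated parameters:", estimate_parameters(cfg))
```

Under those assumptions the estimate comes to roughly 85 million weights. The true count depends on the actual SimpleGPT architecture and may differ.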

File Content

ckpt_iter_50000.pt stores a dictionary with 6 entries:

checkpoint = {
  'state_dict': raw_model.state_dict(),
  'optimizer': optimizer.state_dict(),
  'model_args': model_args,
  'iter_num': iter_num,
  'best_val_loss': best_val_loss,
  'config': config,
}
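
Before relying on a downloaded file, it can help to confirm that the six keys listed above are all present. The helper below is a small sketch; `EXPECTED_KEYS` and `missing_checkpoint_keys` are hypothetical names introduced here and do not come from the SimpleGPT repository.

```python
# Keys taken from the checkpoint layout shown above.
EXPECTED_KEYS = (
    "state_dict",
    "optimizer",
    "model_args",
    "iter_num",
    "best_val_loss",
    "config",
)


def missing_checkpoint_keys(checkpoint):
    """Return the expected keys that are absent from `checkpoint`, in order."""
    return [key for key in EXPECTED_KEYS if key not in checkpoint]


# Stand-in dictionary for illustration; a real one would come from torch.load.
partial = {"state_dict": {}, "model_args": {}}
print(missing_checkpoint_keys(partial))
```

An empty list means the dictionary has the expected layout.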

Use the following short script to display the checkpoint contents:

import os
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_path = "path/to/ckpt_iter_50000.pt"
assert os.path.exists(ckpt_path), f"{ckpt_path} doesn't exist."
checkpoint = torch.load(ckpt_path, map_location=device)

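# Hyperparameters that were used to build the model
model_args = checkpoint['model_args']
print(model_args)

# Name and shape of every tensor in the model's state dict
state_dict = checkpoint['state_dict']
for layer_name, weight_matrix in state_dict.items():
  print(f"{layer_name}\t{weight_matrix.shape}")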
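
The shapes printed by the script can also be summed into a total parameter count. The sketch below only relies on each entry exposing a `.shape` attribute, which PyTorch tensors do, so the demo uses stand-in objects and runs without PyTorch. `count_parameters` and the entry names in `demo_state_dict` are hypothetical, not names from the real checkpoint.

```python
import math
from types import SimpleNamespace


def count_parameters(state_dict):
    """Sum the element counts of every entry, using only its .shape."""
    # math.prod of an empty shape is 1, which matches a scalar tensor.
    return sum(math.prod(tensor.shape) for tensor in state_dict.values())


# Stand-in entries with made-up names; pass checkpoint['state_dict'] in practice.
demo_state_dict = {
    "example.embedding.weight": SimpleNamespace(shape=(65, 768)),
    "example.norm.weight": SimpleNamespace(shape=(768,)),
}
total = count_parameters(demo_state_dict)
print(f"{total:,} parameters")
```

Note that this counts every entry in the state dict, so any non-trainable buffers stored there would be included as well.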

Usage

Clone my repository, SimpleGPT | whut-zhangwx, with git clone. Then follow the generate.py script to load the checkpoint into the GPT model and generate text.