import time

out_dir = 'out-tinystory4'
eval_interval = 20
eval_iters = 40
wandb_log = True # feel free to turn off
wandb_project = 'tinystory-4'
wandb_run_name = 'ft-' + str(time.time())

dataset = 'tinystory4'
init_from = 'resume' # resume training from the checkpoint saved in out_dir

# only save checkpoints if the validation loss improves
always_save_checkpoint = False

# tokens processed per iteration (see the sanity check below):
# 8 batch_size * 16 grad_accum * 256 block_size = 32,768 tokens/iter
# Tinystory has 473,992,236 tokens, so 1 epoch ~= 14,465 iters
batch_size = 8
gradient_accumulation_steps = 16
max_iters = 7200 # ~half an epoch
block_size = 256
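
# a minimal sanity check for the arithmetic above (assumes a nanoGPT-style
# training loop that consumes batch_size * gradient_accumulation_steps *
# block_size tokens per iteration; 473,992,236 is the token count quoted above)
tokens_per_iter = batch_size * gradient_accumulation_steps * block_size
assert tokens_per_iter == 32_768
iters_per_epoch = 473_992_236 // tokens_per_iter # ~14,465
assert max_iters <= iters_per_epoch # 7200 iters ~= half an epoch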

# finetune at constant LR
learning_rate = 3e-4
decay_lr = False
dropout = 0.1
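
# to launch, assuming a nanoGPT-style train.py and a hypothetical config path:
# $ python train.py config/finetune_tinystory4.py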