10M-LLM / config.py
abancp's picture
openwebbook2
82f9e44
raw
history blame contribute delete
1.4 kB
from pathlib import Path
import os
import json
def get_config(path=None):
    """Return the training configuration as a dictionary.

    When *path* is truthy and refers to an existing filesystem entry, the
    file is parsed as JSON and returned after checking that every required
    field is present. In every other case (no path given, or the path does
    not exist) the built-in default configuration is returned instead.

    Args:
        path: Optional location of a JSON configuration file.

    Returns:
        The parsed JSON content, or a new dict holding the default settings.

    Raises:
        ValueError: If the loaded config lacks one or more required fields.
    """
    if path and Path(path).exists():
        # JSON text is UTF-8 by specification; do not rely on the locale.
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)

        required = (
            "batch_size",
            "num_epochs",
            "lr",
            "seq_len",
            "d_model",
            "d_ff",
            "N",
            "h",
            "model_folder",
            "model_basename",
            "preload",
            "tokenizer_file",
            "experiment_name",
        )
        missing = [field for field in required if field not in config]
        if missing:
            # Separate the names so the message stays readable when several
            # fields are absent at once.
            raise ValueError(
                f"Field(s) missing in config file : {', '.join(missing)}"
            )
        return config

    # Fallback used when no readable config file was supplied.
    return {
        "batch_size": 4,
        "num_epochs": 30,
        # NOTE(review): 3**-4 evaluates to 1/81 (~0.0123), not 3e-4 or
        # 10**-4 -- confirm this is the intended learning rate.
        "lr": 3**-4,
        "seq_len": 360,
        "d_model": 512,
        "N": 6,
        "h": 8,
        "d_ff": 2048,
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "datasource": "opus_books",
        "model_basename": "tmodel_",
        "preload": 25,
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
    }
def get_weights_file_path(config, epoch: str):
    """Build the path of the weights file for a given epoch.

    The file name is ``config['model_basename']`` followed by *epoch* and
    the ``.pt`` suffix; it is placed under ``config['model_folder']``
    relative to the current working directory.

    Args:
        config: Mapping providing ``model_folder`` and ``model_basename``.
        epoch: Epoch identifier interpolated into the file name.

    Returns:
        The resulting path as a string.
    """
    filename = f"{config['model_basename']}{epoch}.pt"
    weights_path = Path(os.getcwd()).joinpath(config['model_folder'], filename)
    return str(weights_path)