File size: 1,404 Bytes
82f9e44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from pathlib import Path
import os
import json
def get_config(path=None):
    """Return the training configuration as a dict.

    If ``path`` is truthy and names an existing filesystem entry, it is parsed
    as JSON and checked for the required keys listed below. Otherwise (``path``
    is ``None``/empty or does not exist) a built-in default configuration is
    returned.

    Args:
        path: Optional path to a JSON configuration file.

    Returns:
        dict: The loaded configuration, or a freshly built default one.

    Raises:
        ValueError: If the file's top-level JSON value is not an object, or if
            one or more required keys are missing.
    """
    if path and Path(path).exists():
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # Key checks below only make sense for a mapping; a JSON list of
        # strings would otherwise pass the membership tests.
        if not isinstance(config, dict):
            raise ValueError(
                "Config file must contain a JSON object, "
                f"got {type(config).__name__}"
            )
        # NOTE(review): "lang_src", "lang_tgt" and "datasource" exist in the
        # defaults below but are not required here — confirm whether loaded
        # configs are expected to provide them.
        requires = (
            "batch_size",
            "num_epochs",
            "lr",
            "seq_len",
            "d_model",
            "d_ff",
            "N",
            "h",
            "model_folder",
            "model_basename",
            "preload",
            "tokenizer_file",
            "experiment_name",
        )
        not_includes = [r for r in requires if r not in config]
        if not_includes:
            # Separate the names so the message stays readable.
            raise ValueError(
                f"Field(s) missing in config file : {', '.join(not_includes)}"
            )
        return config
    # A new dict literal is built on every call, so callers may mutate it.
    return {
        "batch_size": 4,
        "num_epochs": 30,
        # NOTE(review): 3**-4 is about 0.0123; if 1e-4 (10**-4) was intended,
        # this is a typo — confirm before changing.
        "lr": 3**-4,
        "seq_len": 360,
        "d_model": 512,
        "N": 6,
        "h": 8,
        "d_ff": 2048,
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "datasource": "opus_books",
        "model_basename": "tmodel_",
        "preload": 25,
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
    }
def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path for ``epoch`` as a string.

    The path is rooted at the current working directory, inside the folder
    named by ``config['model_folder']``, with the file name
    ``<config['model_basename']><epoch>.pt``.

    Args:
        config: Mapping that provides 'model_folder' and 'model_basename'.
        epoch: Epoch label interpolated into the file name.

    Returns:
        str: The resulting filesystem path.

    Raises:
        KeyError: If 'model_folder' or 'model_basename' is absent
            ('model_folder' is looked up first).
    """
    folder = config['model_folder']
    basename = config['model_basename']
    return str(Path.cwd().joinpath(folder, f"{basename}{epoch}.pt"))
|