|
from pathlib import Path |
|
import os |
|
import json |
|
|
|
def get_config(path=None):
    """Return the training configuration as a dict.

    If *path* is truthy and points at an existing filesystem entry, the file
    is parsed as JSON, checked for the required keys, and returned as loaded.
    Otherwise (no path given, or the path does not exist) the built-in
    default configuration is returned.

    Args:
        path: Optional location of a JSON configuration file.

    Returns:
        The loaded configuration, or the default configuration dict.

    Raises:
        ValueError: If the loaded configuration lacks one or more required
            keys; the message lists every missing key.

    Errors raised by ``open`` or ``json.load`` (unreadable file, invalid
    JSON) are not caught here and propagate to the caller.
    """
    if path and Path(path).exists():
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)

        required = (
            "batch_size",
            "num_epochs",
            "lr",
            "seq_len",
            "d_model",
            "d_ff",
            "N",
            "h",
            "model_folder",
            "model_basename",
            "preload",
            "tokenizer_file",
            "experiment_name",
        )
        missing = [key for key in required if key not in config]
        if missing:
            # Separate the names so the message stays readable when several
            # keys are missing at once.
            raise ValueError(
                f"Field(s) missing in config file : {', '.join(missing)}"
            )
        return config

    # No usable config file: fall back to the built-in defaults.
    return {
        "batch_size": 4,
        "num_epochs": 30,
        # NOTE(review): 3**-4 evaluates to 1/81 (~0.0123). Confirm this is
        # intended and not a typo for 10**-4 or 3e-4.
        "lr": 3**-4,
        "seq_len": 360,
        "d_model": 512,
        "N": 6,
        "h": 8,
        "d_ff": 2048,
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "datasource": "opus_books",
        "model_basename": "tmodel_",
        "preload": 25,
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
    }
|
|
|
def get_weights_file_path(config, epoch: str):
    """Build the path of the weights file for a given epoch.

    The file name is ``config['model_basename']`` followed by *epoch* and the
    ``.pt`` suffix; it is placed under ``config['model_folder']``, which is
    joined onto the current working directory.

    Args:
        config: Mapping providing the ``model_folder`` and ``model_basename``
            entries.
        epoch: Epoch identifier interpolated into the file name.

    Returns:
        The resulting path as a string.
    """
    weights_name = f"{config['model_basename']}{epoch}.pt"
    base_dir = Path(os.getcwd())
    return str(base_dir.joinpath(config['model_folder'], weights_name))
|
|
|
|