from pathlib import Path
import os
import json

# Keys that a user-supplied JSON config file must define; checked by get_config().
# NOTE(review): "lang_src", "lang_tgt" and "datasource" are present in the
# built-in defaults but are not required here, so a loaded file may omit them.
_REQUIRED_KEYS = (
    "batch_size",
    "num_epochs",
    "lr",
    "seq_len",
    "d_model",
    "d_ff",
    "N",
    "h",
    "model_folder",
    "model_basename",
    "preload",
    "tokenizer_file",
    "experiment_name",
)


def get_config(path=None) -> dict:
    """Return the training configuration.

    If *path* is truthy and points to an existing file, the configuration is
    loaded from that JSON file and validated against ``_REQUIRED_KEYS``.
    Otherwise the built-in default configuration is returned. The defaults are
    a freshly built dict on every call, so callers may mutate the result.

    Args:
        path: Optional path (str or path-like) to a JSON config file. A falsy
            value, or a path that does not exist, selects the defaults.

    Returns:
        The configuration dictionary.

    Raises:
        ValueError: If the file is not valid JSON (``json.JSONDecodeError`` is
            a ``ValueError``), does not contain a JSON object, or lacks any of
            the required keys.
    """
    if path and Path(path).exists():
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        if not isinstance(config, dict):
            raise ValueError("Config file must contain a JSON object")
        missing = [key for key in _REQUIRED_KEYS if key not in config]
        if missing:
            # Separate the names so multiple missing keys stay readable.
            raise ValueError(
                f"Field(s) missing in config file : {', '.join(missing)}"
            )
        return config

    return {
        "batch_size": 4,
        "num_epochs": 30,
        # NOTE(review): 3**-4 == 1/81 ~= 0.0123. If 1e-4 (10**-4) was the
        # intended learning rate, change it here; the value is kept as-is.
        "lr": 3**-4,
        "seq_len": 360,
        "d_model": 512,
        "N": 6,
        "h": 8,
        "d_ff": 2048,
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "datasource": "opus_books",
        "model_basename": "tmodel_",
        "preload": 25,
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
    }


def get_weights_file_path(config, epoch: str) -> str:
    """Return the path of the weights file for *epoch*, as a string.

    The path is ``<cwd>/<config['model_folder']>/<config['model_basename']><epoch>.pt``.
    It is resolved against the current working directory at call time.

    Args:
        config: Mapping providing ``"model_folder"`` and ``"model_basename"``.
        epoch: Epoch label inserted verbatim into the file name.

    Returns:
        The absolute weights-file path as a ``str``.
    """
    filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path(os.getcwd()) / config["model_folder"] / filename)