Spaces:
Running
Running
File size: 2,712 Bytes
be5548b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import csv
import os
import torch
import logging
import sys
from pathlib import Path
import utils
def create_folders_if_necessary(path):
    """Create the parent directory of ``path`` if it does not already exist.

    Args:
        path: a file path whose containing directory should exist afterwards.
            A bare filename (no directory component) needs nothing created.

    Raises:
        FileExistsError: if the parent path exists but is not a directory.
    """
    dirname = os.path.dirname(path)
    # A bare filename has an empty dirname; os.makedirs("") would raise.
    if dirname:
        # exist_ok=True removes the check-then-create race when several
        # processes (or repeated calls) create the same directory.
        os.makedirs(dirname, exist_ok=True)
def get_storage_dir():
    """Return the root directory under which model folders live.

    The ``RL_STORAGE`` environment variable, when set, overrides the
    default relative directory ``"storage"``.
    """
    return os.environ.get("RL_STORAGE", "storage")
def get_model_dir(model_name):
    """Return the directory for ``model_name`` inside the storage root."""
    storage_root = get_storage_dir()
    return os.path.join(storage_root, model_name)
def get_status_path(model_dir, num_frames=None):
    """Return the path of a status checkpoint file inside ``model_dir``.

    Args:
        model_dir: directory holding the model's checkpoint files.
        num_frames: when given, selects the numbered checkpoint
            ``status_<num_frames>.pt``; when ``None``, the unnumbered
            ``status.pt`` is used.

    Returns:
        The joined file path as a string.
    """
    # Compare against None explicitly so num_frames=0 yields "status_0.pt"
    # instead of silently aliasing the unnumbered checkpoint.
    if num_frames is None:
        return os.path.join(model_dir, "status.pt")
    return os.path.join(model_dir, "status_{}.pt".format(num_frames))
def get_model_path(model_dir, num_frames=None):
    """Return the path of a saved model file inside ``model_dir``.

    Args:
        model_dir: directory holding the model's files.
        num_frames: when given, selects the numbered file
            ``model_<num_frames>.pt``; when ``None``, the unnumbered
            ``model.pt`` is used.

    Returns:
        The joined file path as a string.
    """
    # Compare against None explicitly so num_frames=0 yields "model_0.pt"
    # instead of silently aliasing the unnumbered model file.
    if num_frames is None:
        return os.path.join(model_dir, "model.pt")
    return os.path.join(model_dir, "model_{}.pt".format(num_frames))
def load_status(status_path):
    """Load a status checkpoint from ``status_path`` with ``torch.load``.

    Stored tensors are mapped onto CUDA when a GPU is available and onto
    the CPU otherwise.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    target_device = torch.device(device_name)
    return torch.load(status_path, map_location=target_device)
def get_status(model_dir, num_frames=None):
    """Load and return the status checkpoint stored in ``model_dir``.

    ``num_frames`` is forwarded to ``get_status_path`` to choose which
    status file is read; loading is delegated to ``load_status``.
    """
    return load_status(get_status_path(model_dir, num_frames))
def save_status(status, model_dir, num_frames=None):
    """Write ``status`` with ``torch.save`` to its checkpoint path in ``model_dir``.

    The target file is chosen by ``get_status_path(model_dir, num_frames)``;
    its parent directory is created first if it is missing.
    """
    status_file = get_status_path(model_dir, num_frames)
    utils.create_folders_if_necessary(status_file)
    torch.save(status, status_file)
def save_model(model, model_dir, num_frames=None):
    """Write the whole ``model`` object with ``torch.save`` into ``model_dir``.

    The target file is chosen by ``get_model_path(model_dir, num_frames)``;
    its parent directory is created first if it is missing.
    """
    model_file = get_model_path(model_dir, num_frames)
    utils.create_folders_if_necessary(model_file)
    torch.save(model, model_file)
def load_model(model_name, raise_not_found=True):
    """Load a model object saved with ``torch.save`` and put it in eval mode.

    Args:
        model_name: passed to ``get_model_path`` as the model directory, so
            the file read is ``<model_name>/model.pt``.
        raise_not_found: when True, a missing file raises
            ``FileNotFoundError``; when False, ``None`` is returned instead.

    Returns:
        The loaded model after ``model.eval()``, or ``None`` when the file is
        missing and ``raise_not_found`` is False.

    Raises:
        FileNotFoundError: if no file exists at the path and
            ``raise_not_found`` is True.
    """
    path = get_model_path(model_name)
    # With a GPU, tensors keep the device they were saved on (map_location
    # None is torch.load's default); without one, remap everything to CPU.
    map_location = None if torch.cuda.is_available() else torch.device("cpu")
    try:
        # NOTE(review): this unpickles a full model object, so only load
        # trusted files. Newer torch releases default torch.load to
        # weights_only=True, which may reject whole nn.Module pickles —
        # confirm against the torch version in use.
        model = torch.load(path, map_location=map_location)
    except FileNotFoundError as err:
        if raise_not_found:
            # Chain the original error so the underlying cause is kept.
            raise FileNotFoundError("No model found at {}".format(path)) from err
        return None
    model.eval()
    return model
def get_vocab(model_dir):
    """Return the ``"vocab"`` entry of the status checkpoint in ``model_dir``."""
    status = get_status(model_dir)
    return status["vocab"]
def get_model_state(model_dir):
    """Return the ``"model_state"`` entry of the status checkpoint in ``model_dir``."""
    status = get_status(model_dir)
    return status["model_state"]
def get_txt_logger(model_dir):
    """Configure root logging to ``<model_dir>/log.txt`` and stdout; return the root logger.

    Records at INFO level and above are written with the bare
    ``"%(message)s"`` format to both a file handler and a stdout stream
    handler.

    NOTE: ``logging.basicConfig`` does nothing if the root logger already
    has handlers, so in that case the handlers built here are not attached
    (although the file handler has still opened ``log.txt``).
    """
    log_path = os.path.join(model_dir, "log.txt")
    utils.create_folders_if_necessary(log_path)
    file_handler = logging.FileHandler(filename=log_path)
    console_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format="%(message)s",
        handlers=[file_handler, console_handler],
    )
    return logging.getLogger()
def get_csv_logger(model_dir):
    """Open ``<model_dir>/log.csv`` for appending and wrap it in a CSV writer.

    The parent directory is created if needed.

    Returns:
        A ``(csv_file, writer)`` tuple: the open text file and a
        ``csv.writer`` bound to it. The caller owns the file and is
        responsible for flushing and closing it.
    """
    csv_path = os.path.join(model_dir, "log.csv")
    utils.create_folders_if_necessary(csv_path)
    # The csv module requires newline="" so it controls row terminators
    # itself; otherwise "\r\n" is written as "\r\r\n" on Windows.
    csv_file = open(csv_path, "a", newline="")
    return csv_file, csv.writer(csv_file)
|