# SocialAISchool / utils/storage.py
# grg's picture
# Cleaned old git history
# be5548b
import csv
import os
import torch
import logging
import sys
from pathlib import Path
import utils
def create_folders_if_necessary(path):
    """Create the parent directory of ``path`` if it does not exist yet.

    Args:
        path: File path whose containing directory should exist.

    A path without a directory component (e.g. ``"log.txt"``) has an empty
    dirname, so nothing is created in that case.
    """
    dirname = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip bare file names.
    if dirname:
        # exist_ok=True avoids failing when the directory appears between an
        # existence check and the creation (e.g. parallel runs).
        os.makedirs(dirname, exist_ok=True)
def get_storage_dir():
    """Return the storage root directory.

    The ``RL_STORAGE`` environment variable is used when it is set;
    otherwise the relative directory ``"storage"`` is returned.
    """
    return os.environ.get("RL_STORAGE", "storage")
def get_model_dir(model_name):
    """Return the directory for ``model_name`` inside the storage root."""
    storage_root = get_storage_dir()
    return os.path.join(storage_root, model_name)
def get_status_path(model_dir, num_frames=None):
    """Return the path of the status file inside ``model_dir``.

    Args:
        model_dir: Directory that holds the status file.
        num_frames: Optional frame count embedded in the file name.

    Returns:
        ``<model_dir>/status_<num_frames>.pt`` when ``num_frames`` is truthy,
        otherwise ``<model_dir>/status.pt``.
    """
    # A falsy num_frames (None or 0) selects the unsuffixed file.
    filename = "status_{}.pt".format(num_frames) if num_frames else "status.pt"
    return os.path.join(model_dir, filename)
def get_model_path(model_dir, num_frames=None):
    """Return the path of the model file inside ``model_dir``.

    Args:
        model_dir: Directory that holds the model file.
        num_frames: Optional frame count embedded in the file name.

    Returns:
        ``<model_dir>/model_<num_frames>.pt`` when ``num_frames`` is truthy,
        otherwise ``<model_dir>/model.pt``.
    """
    # A falsy num_frames (None or 0) selects the unsuffixed file.
    filename = "model_{}.pt".format(num_frames) if num_frames else "model.pt"
    return os.path.join(model_dir, filename)
def load_status(status_path):
    """Load the object stored at ``status_path`` with ``torch.load``.

    The ``map_location`` passed to ``torch.load`` is the CUDA device when
    CUDA is available and the CPU device otherwise.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    target_device = torch.device(device_name)
    return torch.load(status_path, map_location=target_device)
def get_status(model_dir, num_frames=None):
    """Load the status stored in ``model_dir``.

    Args:
        model_dir: Directory that holds the status file.
        num_frames: Optional frame count forwarded to ``get_status_path``.

    Returns:
        Whatever ``load_status`` returns for the resolved status path.
    """
    return load_status(get_status_path(model_dir, num_frames))
def save_status(status, model_dir, num_frames=None):
    """Write ``status`` to the status file of ``model_dir`` with ``torch.save``.

    Args:
        status: Object to serialize.
        model_dir: Directory that receives the status file.
        num_frames: Optional frame count forwarded to ``get_status_path``.
    """
    status_path = get_status_path(model_dir, num_frames)
    # Make sure the parent directory exists before writing.
    utils.create_folders_if_necessary(status_path)
    torch.save(status, status_path)
def save_model(model, model_dir, num_frames=None):
    """Write ``model`` to the model file of ``model_dir`` with ``torch.save``.

    Args:
        model: Object to serialize.
        model_dir: Directory that receives the model file.
        num_frames: Optional frame count forwarded to ``get_model_path``.
    """
    model_path = get_model_path(model_dir, num_frames)
    # Make sure the parent directory exists before writing.
    utils.create_folders_if_necessary(model_path)
    torch.save(model, model_path)
def load_model(model_name, raise_not_found=True):
    """Load a saved model with ``torch.load`` and put it in evaluation mode.

    Args:
        model_name: Passed to ``get_model_path`` to resolve the ``model.pt``
            file to load.
        raise_not_found: When true (the default), a missing file raises;
            otherwise ``None`` is returned.

    Returns:
        The loaded model after ``eval()`` has been called on it, or ``None``
        when the file is missing and ``raise_not_found`` is false.

    Raises:
        FileNotFoundError: If the model file does not exist and
            ``raise_not_found`` is true.
    """
    path = get_model_path(model_name)
    # Without CUDA, map all storages onto the CPU; with CUDA, let torch.load
    # use its default placement.
    load_kwargs = {} if torch.cuda.is_available() else {"map_location": torch.device("cpu")}
    try:
        model = torch.load(path, **load_kwargs)
    except FileNotFoundError as err:
        if raise_not_found:
            raise FileNotFoundError("No model found at {}".format(path)) from err
        return None
    # Kept outside the try block so only the load itself is guarded.
    model.eval()
    return model
def get_vocab(model_dir):
    """Return the ``"vocab"`` entry of the status stored in ``model_dir``."""
    status = get_status(model_dir)
    return status["vocab"]
def get_model_state(model_dir):
    """Return the ``"model_state"`` entry of the status stored in ``model_dir``."""
    status = get_status(model_dir)
    return status["model_state"]
def get_txt_logger(model_dir):
    """Configure logging to ``<model_dir>/log.txt`` and stdout; return the root logger.

    Messages are formatted as the bare message text at level INFO. Note that
    ``logging.basicConfig`` only takes effect when the root logger has no
    handlers configured yet.
    """
    log_path = os.path.join(model_dir, "log.txt")
    utils.create_folders_if_necessary(log_path)

    # Build the handlers in the same order they are registered: file, then stdout.
    file_handler = logging.FileHandler(filename=log_path)
    stdout_handler = logging.StreamHandler(sys.stdout)

    logging.basicConfig(
        level=logging.INFO,
        format="%(message)s",
        handlers=[file_handler, stdout_handler],
    )
    return logging.getLogger()
def get_csv_logger(model_dir):
    """Open ``<model_dir>/log.csv`` for appending and wrap it in a CSV writer.

    Args:
        model_dir: Directory that receives ``log.csv``; it is created if needed.

    Returns:
        A ``(csv_file, writer)`` tuple. The file object is returned still
        open, so closing it is left to the caller.
    """
    csv_path = os.path.join(model_dir, "log.csv")
    utils.create_folders_if_necessary(csv_path)
    # newline="" lets the csv module control line endings itself, as its
    # documentation requires for file objects handed to csv.writer.
    csv_file = open(csv_path, "a", newline="")
    return csv_file, csv.writer(csv_file)