Spaces:
Runtime error
Runtime error
''' | |
author: wayn391@mastertones | |
''' | |
import datetime | |
import os | |
import time | |
import matplotlib.pyplot as plt | |
import torch | |
import yaml | |
from torch.utils.tensorboard import SummaryWriter | |
class Saver(object): | |
def __init__( | |
self, | |
args, | |
initial_global_step=-1): | |
self.expdir = args.env.expdir | |
self.sample_rate = args.data.sampling_rate | |
# cold start | |
self.global_step = initial_global_step | |
self.init_time = time.time() | |
self.last_time = time.time() | |
# makedirs | |
os.makedirs(self.expdir, exist_ok=True) | |
# path | |
self.path_log_info = os.path.join(self.expdir, 'log_info.txt') | |
# ckpt | |
os.makedirs(self.expdir, exist_ok=True) | |
# writer | |
self.writer = SummaryWriter(os.path.join(self.expdir, 'logs')) | |
# save config | |
path_config = os.path.join(self.expdir, 'config.yaml') | |
with open(path_config, "w") as out_config: | |
yaml.dump(dict(args), out_config) | |
def log_info(self, msg): | |
'''log method''' | |
if isinstance(msg, dict): | |
msg_list = [] | |
for k, v in msg.items(): | |
tmp_str = '' | |
if isinstance(v, int): | |
tmp_str = '{}: {:,}'.format(k, v) | |
else: | |
tmp_str = '{}: {}'.format(k, v) | |
msg_list.append(tmp_str) | |
msg_str = '\n'.join(msg_list) | |
else: | |
msg_str = msg | |
# dsplay | |
print(msg_str) | |
# save | |
with open(self.path_log_info, 'a') as fp: | |
fp.write(msg_str+'\n') | |
def log_value(self, dict): | |
for k, v in dict.items(): | |
self.writer.add_scalar(k, v, self.global_step) | |
def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5): | |
spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1) | |
spec = spec_cat[0] | |
if isinstance(spec, torch.Tensor): | |
spec = spec.cpu().numpy() | |
fig = plt.figure(figsize=(12, 9)) | |
plt.pcolor(spec.T, vmin=vmin, vmax=vmax) | |
plt.tight_layout() | |
self.writer.add_figure(name, fig, self.global_step) | |
def log_audio(self, dict): | |
for k, v in dict.items(): | |
self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate) | |
def get_interval_time(self, update=True): | |
cur_time = time.time() | |
time_interval = cur_time - self.last_time | |
if update: | |
self.last_time = cur_time | |
return time_interval | |
def get_total_time(self, to_str=True): | |
total_time = time.time() - self.init_time | |
if to_str: | |
total_time = str(datetime.timedelta( | |
seconds=total_time))[:-5] | |
return total_time | |
def save_model( | |
self, | |
model, | |
optimizer, | |
name='model', | |
postfix='', | |
to_json=False): | |
# path | |
if postfix: | |
postfix = '_' + postfix | |
path_pt = os.path.join( | |
self.expdir , name+postfix+'.pt') | |
# check | |
print(' [*] model checkpoint saved: {}'.format(path_pt)) | |
# save | |
if optimizer is not None: | |
torch.save({ | |
'global_step': self.global_step, | |
'model': model.state_dict(), | |
'optimizer': optimizer.state_dict()}, path_pt) | |
else: | |
torch.save({ | |
'global_step': self.global_step, | |
'model': model.state_dict()}, path_pt) | |
def delete_model(self, name='model', postfix=''): | |
# path | |
if postfix: | |
postfix = '_' + postfix | |
path_pt = os.path.join( | |
self.expdir , name+postfix+'.pt') | |
# delete | |
if os.path.exists(path_pt): | |
os.remove(path_pt) | |
print(' [*] model checkpoint deleted: {}'.format(path_pt)) | |
def global_step_increment(self): | |
self.global_step += 1 | |