# torchnet / BaseTrainer.py
# Captured from a Hugging Face file view: author milselarch, commit df07554
# ("push to main"), 1.94 kB, marked "No virus" by the host's scan.
import os
import numpy as np
import shutil
import resource
import options as opt
from helpers import *
from datetime import datetime as Datetime
from tensorboardX import SummaryWriter
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(
resource.RLIMIT_NOFILE, (65536, rlimit[1])
)
class BaseTrainer:
    """Shared run bookkeeping for trainers.

    Handles run naming, creation of the log/weights directories, and
    TensorBoard scalar logging.

    Args:
        name: Prefix of the run name; ``save_name`` is ``'{name}-{date_stamp}'``.
        base_dir: Stored on the instance as ``self.base_dir``; not read by
            the methods of this class.
    """

    def __init__(self, name='M', base_dir=''):
        self.name = name
        self.base_dir = base_dir
        self.date_stamp = self.make_date_stamp()
        self.save_name = f'{self.name}-{self.date_stamp}'
        # Populated by init_tensorboard().
        self.weights_dir = None
        self.log_dir = None
        self.writer = None

    @staticmethod
    def get_dataset_kwargs(
        shared_dict=None, base_dir='',
        char_map=opt.char_map, **kwargs
    ):
        """Build dataset keyword arguments from the ``options`` module.

        Paths, padding and frame-doubling settings come from ``opt``.
        ``shared_dict``, ``base_dir``, ``char_map`` and any extra ``kwargs``
        are forwarded as given.

        Note: the default ``char_map`` is read from ``opt`` once, when this
        class is defined, not on each call.

        Returns:
            Whatever ``kwargify`` (from ``helpers``) returns for these
            arguments. Presumably this is a dict of the keyword arguments;
            confirm against ``helpers``.
        """
        return kwargify(
            video_path=opt.video_path,
            shared_dict=shared_dict,
            alignments_dir=opt.alignments_dir,
            vid_pad=opt.vid_padding,
            image_dir=opt.images_dir,
            txt_pad=opt.txt_padding,
            phonemes_dir=opt.phonemes_dir,
            frame_doubling=opt.frame_doubling,
            char_map=char_map,
            base_dir=base_dir,
            **kwargs
        )

    def init_tensorboard(self):
        """Prepare this run's output directories and TensorBoard writer.

        Steps:
        - Set ``self.log_dir`` to ``runs/<save_name>`` and
          ``self.weights_dir`` to ``weights/<save_name>``, both relative to
          the current working directory.
        - Create both directories.
        - Open ``self.writer`` on the log directory.
        - Copy the options file into the log directory.
        """
        self.log_dir = f'runs/{self.save_name}'
        self.weights_dir = f'weights/{self.save_name}'
        # makedirs also creates missing 'runs/' or 'weights/' parents, which
        # os.mkdir would fail on. exist_ok replaces the racy
        # exists-then-mkdir check.
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.weights_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        # Save the current state of the options file. Use the imported
        # module's own path so this works regardless of the working
        # directory; fall back to the old cwd-relative name.
        options_path = getattr(opt, '__file__', None) or 'options.py'
        shutil.copyfile(
            options_path, os.path.join(self.log_dir, 'options.py')
        )

    @staticmethod
    def make_date_stamp():
        """Return the local time formatted as ``YYMMDD-HHMM``."""
        return Datetime.now().strftime("%y%m%d-%H%M")

    def log_scalar(self, name, value, iterations, label):
        """Log ``value`` under tag ``name`` and series ``label`` at step ``iterations``.

        Requires ``init_tensorboard()`` to have been called first. Otherwise
        ``self.writer`` is ``None`` and this raises ``AttributeError``.
        """
        self.writer.add_scalars(name, {label: value}, iterations)