File size: 1,939 Bytes
df07554
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import numpy as np
import shutil
import resource
import options as opt

from helpers import *
from datetime import datetime as Datetime
from tensorboardX import SummaryWriter

rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
# Raise the soft limit on open file descriptors (presumably so that many
# dataset files / loader workers can be open at once — TODO confirm).
# The soft limit may never exceed the hard limit: asking for more makes
# setrlimit raise ValueError at import time, so clamp the target to the
# hard limit unless the hard limit is unlimited.
_nofile_target = 65536
if rlimit[1] != resource.RLIM_INFINITY:
    _nofile_target = min(_nofile_target, rlimit[1])
resource.setrlimit(
    resource.RLIMIT_NOFILE, (_nofile_target, rlimit[1])
)


class BaseTrainer(object):
    """Shared bookkeeping for trainers.

    Provides run naming (``save_name``), construction of dataset keyword
    arguments from the ``options`` module, TensorBoard writer / output
    directory setup, and scalar logging.

    ``log_dir``, ``weights_dir`` and ``writer`` stay ``None`` until
    ``init_tensorboard`` is called; ``log_scalar`` needs ``writer`` set.
    """

    def __init__(self, name='M', base_dir=''):
        """
        :param name: run/model name, used as the prefix of ``save_name``.
        :param base_dir: stored as ``self.base_dir``; no method of this
            class reads it (presumably used by subclasses — TODO confirm).
        """
        self.name = name
        self.base_dir = base_dir

        self.date_stamp = self.make_date_stamp()
        # '<name>-<yymmdd-HHMM>'; names the run's log and weights dirs
        self.save_name = f'{self.name}-{self.date_stamp}'
        # populated by init_tensorboard()
        self.weights_dir = None
        self.log_dir = None
        self.writer = None

    @staticmethod
    def get_dataset_kwargs(
        shared_dict=None, base_dir='',
        char_map=opt.char_map, **kwargs
    ):
        """Collect dataset constructor arguments from the options module.

        Paths, padding sizes and ``frame_doubling`` come from ``opt``;
        ``shared_dict``, ``base_dir`` and ``char_map`` come from the caller.
        Any extra ``**kwargs`` are forwarded unchanged.

        Note: the ``char_map`` default is ``opt.char_map`` as it was when
        this class was defined. Later reassignment of ``opt.char_map`` is
        not picked up.

        :returns: whatever ``kwargify`` (from ``helpers``) builds from
            these keyword arguments. This is presumably a plain dict — TODO
            confirm.
        """
        return kwargify(
            video_path=opt.video_path,
            shared_dict=shared_dict,
            alignments_dir=opt.alignments_dir,
            vid_pad=opt.vid_padding,
            image_dir=opt.images_dir,
            txt_pad=opt.txt_padding,
            phonemes_dir=opt.phonemes_dir,
            frame_doubling=opt.frame_doubling,
            char_map=char_map,
            base_dir=base_dir,
            **kwargs
        )

    def init_tensorboard(self):
        """Prepare the run's output directories and TensorBoard writer.

        This method:

        - creates ``runs/<save_name>`` and ``weights/<save_name>``,
          relative to the current working directory;
        - opens a ``SummaryWriter`` on the log directory;
        - copies ``options.py`` from the working directory into the log
          directory.

        :raises FileNotFoundError: if ``options.py`` is not in the current
            working directory.
        """
        self.log_dir = f'runs/{self.save_name}'
        self.weights_dir = f'weights/{self.save_name}'

        # makedirs also creates a missing 'runs/' or 'weights/' parent,
        # which a bare os.mkdir fails on in a fresh checkout. exist_ok
        # replaces the racy exists()-then-mkdir() check.
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.weights_dir, exist_ok=True)

        self.writer = SummaryWriter(self.log_dir)
        # save current state of options file so the run's configuration
        # can be recovered from its log directory
        shutil.copyfile(
            'options.py', os.path.join(self.log_dir, 'options.py')
        )

    @staticmethod
    def make_date_stamp():
        """Return the current local time as ``'yymmdd-HHMM'``.

        Resolution is one minute, so two runs with the same ``name``
        started in the same minute share a ``save_name``.
        """
        return Datetime.now().strftime("%y%m%d-%H%M")

    def log_scalar(self, name, value, iterations, label):
        """Log one scalar under series ``label`` of the chart ``name``.

        Requires ``init_tensorboard`` to have been called first. Before
        that, ``self.writer`` is ``None`` and this raises AttributeError.

        :param name: main tag passed to ``SummaryWriter.add_scalars``.
        :param value: scalar value to record.
        :param iterations: global step for the point.
        :param label: key of this series within the ``name`` chart.
        """
        self.writer.add_scalars(name, {label: value}, iterations)