import datetime
import importlib
import pickle
import numpy as np
import tensorflow as tf


def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
    state = {
        'model': model.weights,
        'optimizer': optimizer,
        'step': current_step,
        'epoch': epoch,
        'date': datetime.date.today().strftime("%B %d, %Y"),
        'r': r
    }
    state.update(kwargs)
    with open(output_path, 'wb') as f:
        pickle.dump(state, f)


def load_checkpoint(model, checkpoint_path):
    with open(checkpoint_path, 'rb') as f:
        checkpoint = pickle.load(f)
    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
    tf_vars = model.weights
    for tf_var in tf_vars:
        layer_name = tf_var.name
        try:
            chkp_var_value = chkp_var_dict[layer_name]
        except KeyError:
            # checkpoint variable names may carry the model class name as a prefix
            class_name = list(chkp_var_dict.keys())[0].split("/")[0]
            layer_name = f"{class_name}/{layer_name}"
            chkp_var_value = chkp_var_dict[layer_name]

        tf.keras.backend.set_value(tf_var, chkp_var_value)
    if 'r' in checkpoint:
        model.decoder.set_r(checkpoint['r'])
    return model


def sequence_mask(sequence_length, max_len=None):
    """Compute a boolean padding mask of shape [B, T_max] from per-item lengths."""
    if max_len is None:
        max_len = tf.reduce_max(sequence_length)
    # B x T_max
    return tf.sequence_mask(sequence_length, maxlen=max_len, dtype=tf.bool)


# @tf.custom_gradient
def check_gradient(x, grad_clip):
    """Clip `x` to a maximum L2 norm of `grad_clip` and return its pre-clip norm."""
    x_normed = tf.clip_by_norm(x, grad_clip)
    grad_norm = tf.norm(x)  # norm of the unclipped tensor, not of the threshold
    return x_normed, grad_norm


def count_parameters(model, c):
    try:
        return model.count_params()
    except (RuntimeError, ValueError):
        # build the model with dummy inputs; random ints stand in for character ids
        input_dummy = tf.convert_to_tensor(np.random.randint(0, 24, (8, 128)).astype('int32'))
        input_lengths = np.random.randint(100, 129, (8, ))
        input_lengths[-1] = 128
        input_lengths = tf.convert_to_tensor(input_lengths.astype('int32'))
        mel_spec = np.random.rand(8, 2 * c.r,
                                  c.audio['num_mels']).astype('float32')
        mel_spec = tf.convert_to_tensor(mel_spec)
        speaker_ids = np.random.randint(
            0, 5, (8, )) if c.use_speaker_embedding else None
        _ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids)
        return model.count_params()


def setup_model(num_chars, num_speakers, c, enable_tflite=False):
    print(" > Using model: {}".format(c.model))
    MyModel = importlib.import_module('TTS.tts.tf.models.' + c.model.lower())
    MyModel = getattr(MyModel, c.model)
    if c.model.lower() == "tacotron":
        raise NotImplementedError(' [!] Tacotron model is not ready.')
    # tacotron2
    model = MyModel(num_chars=num_chars,
                    num_speakers=num_speakers,
                    r=c.r,
                    postnet_output_dim=c.audio['num_mels'],
                    decoder_output_dim=c.audio['num_mels'],
                    attn_type=c.attention_type,
                    attn_win=c.windowing,
                    attn_norm=c.attention_norm,
                    prenet_type=c.prenet_type,
                    prenet_dropout=c.prenet_dropout,
                    forward_attn=c.use_forward_attn,
                    trans_agent=c.transition_agent,
                    forward_attn_mask=c.forward_attn_mask,
                    location_attn=c.location_attn,
                    attn_K=c.attention_heads,
                    separate_stopnet=c.separate_stopnet,
                    bidirectional_decoder=c.bidirectional_decoder,
                    enable_tflite=enable_tflite)
    return model
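

# Illustrative usage (a sketch, not part of the module's original code): assumes
# a config object `c` exposing the attributes referenced above (c.model, c.r,
# c.audio['num_mels'], c.use_speaker_embedding, ...) and a vocabulary of
# 60 symbols; the checkpoint path is a placeholder.
#
#   model = setup_model(num_chars=60, num_speakers=1, c=c)
#   print(count_parameters(model, c))
#   save_checkpoint(model, None, current_step=0, epoch=0, r=c.r,
#                   output_path='checkpoint.pkl')
#   model = load_checkpoint(model, 'checkpoint.pkl')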