tu committed on
Commit
6f6918a
1 Parent(s): 2ac38e0
Files changed (10)
  1. app.py +71 -0
  2. commons.py +163 -0
  3. config.json +62 -0
  4. model.onnx +3 -0
  5. preprocess.py +15 -0
  6. requirements.txt +12 -0
  7. text/__init__.py +72 -0
  8. text/cleaners.py +114 -0
  9. text/symbols.py +16 -0
  10. utils.py +307 -0
app.py ADDED
@@ -0,0 +1,71 @@
+ import torch
+ import numpy as np
+ import onnxruntime
+ import gradio as gr
+ from gradio import components
+ from scipy.io.wavfile import write
+
+ import commons
+ import utils
+ from text import text_to_sequence
+ from preprocess import preprocess
+
+
+ def get_text(texts, hps):
+     """Preprocess each comma-separated piece of text and split it into
+     30-word chunks so long inputs stay within a manageable length."""
+     text_norm_list = []
+     for text in texts.split(","):
+         text = preprocess(text)
+         words = text.split()
+         chunk_len = 30
+         chunk_strings = []
+         for i in range(0, len(words), chunk_len):
+             chunk_strings.append(" ".join(words[i:i + chunk_len]))
+         for chunk_string in chunk_strings:
+             text_norm = text_to_sequence(chunk_string, hps.data.text_cleaners)
+             if hps.data.add_blank:
+                 text_norm = commons.intersperse(text_norm, 0)
+             text_norm_list.append(torch.LongTensor(text_norm))
+     return text_norm_list
+
+
+ def tts(text):
+     model_path = "model.onnx"
+     config_path = "config.json"
+     output_wav_path = "output.wav"
+     sid = np.array([6], dtype=np.int64)  # speaker id, fixed to speaker 6
+
+     sess_options = onnxruntime.SessionOptions()
+     model = onnxruntime.InferenceSession(
+         model_path, sess_options=sess_options, providers=["CPUExecutionProvider"]
+     )
+     hps = utils.get_hparams_from_file(config_path)
+
+     audios = []
+     for stn_tst in get_text(text, hps):
+         input_ids = np.expand_dims(np.array(stn_tst, dtype=np.int64), 0)
+         input_lengths = np.array([input_ids.shape[1]], dtype=np.int64)
+         scales = np.array([0.667, 1.0, 0.8], dtype=np.float32)
+         audio = model.run(
+             None,
+             {
+                 "input": input_ids,
+                 "input_lengths": input_lengths,
+                 "scales": scales,
+                 "sid": sid,
+             },
+         )[0].squeeze((0, 1))
+         audios.append(audio)
+     audios = np.concatenate(audios, axis=0)
+
+     write(filename=output_wav_path, rate=hps.data.sampling_rate, data=audios)
+     return output_wav_path
+
+
+ if __name__ == "__main__":
+     gr.Interface(
+         fn=tts,
+         inputs=[components.Textbox(label="Text Input")],
+         outputs=components.Audio(type="filepath", label="Generated Speech"),
+         live=False,
+     ).launch(server_name="0.0.0.0", server_port=7860)
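
A minimal smoke test for this file, assuming `model.onnx` and `config.json` sit in the working directory as the hard-coded paths above require; this is an illustrative sketch, not part of the commit:

```python
# Hypothetical smoke test: call tts() directly, bypassing the Gradio UI.
# The __main__ guard in app.py keeps the server from launching on import.
from app import tts

wav_path = tts("xin chào, đây là bản thử nghiệm")
print(wav_path)  # -> "output.wav", written at hps.data.sampling_rate (22050 Hz)
```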
commons.py ADDED
@@ -0,0 +1,163 @@
+ import math
+ import torch
+ from torch.nn import functional as F
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def convert_pad_shape(pad_shape):
+     # F.pad expects pad widths for the last dimension first.
+     reversed_shape = pad_shape[::-1]
+     return [item for sublist in reversed_shape for item in sublist]
+
+
+ def intersperse(lst, item):
+     # [a, b] -> [item, a, item, b, item]
+     result = [item] * (len(lst) * 2 + 1)
+     result[1::2] = lst
+     return result
+
+
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
+     """KL(P||Q) between two diagonal Gaussians given means and log-stds."""
+     kl = (logs_q - logs_p) - 0.5
+     kl += (
+         0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+     )
+     return kl
+
+
+ def rand_gumbel(shape):
+     """Sample from the Gumbel distribution, protected from overflow."""
+     uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+     return -torch.log(-torch.log(uniform_samples))
+
+
+ def rand_gumbel_like(x):
+     return rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+
+
+ def slice_segments(x, ids_str, segment_size=4):
+     ret = torch.zeros_like(x[:, :, :segment_size])
+     for i in range(x.size(0)):
+         idx_str = ids_str[i]
+         idx_end = idx_str + segment_size
+         ret[i] = x[i, :, idx_str:idx_end]
+     return ret
+
+
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
+     b, d, t = x.size()
+     if x_lengths is None:
+         x_lengths = t
+     ids_str_max = x_lengths - segment_size + 1
+     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+     ret = slice_segments(x, ids_str, segment_size)
+     return ret, ids_str
+
+
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+     position = torch.arange(length, dtype=torch.float)
+     num_timescales = channels // 2
+     log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+         num_timescales - 1
+     )
+     inv_timescales = min_timescale * torch.exp(
+         torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+     )
+     scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+     signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+     signal = F.pad(signal, [0, 0, 0, channels % 2])
+     signal = signal.view(1, channels, length)
+     return signal
+
+
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+     b, channels, length = x.size()
+     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+     return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+     b, channels, length = x.size()
+     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+     return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+ def subsequent_mask(length):
+     mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+     return mask
+
+
+ @torch.jit.script
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+     n_channels_int = n_channels[0]
+     in_act = input_a + input_b
+     t_act = torch.tanh(in_act[:, :n_channels_int, :])
+     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+     acts = t_act * s_act
+     return acts
+
+
+ def shift_1d(x):
+     x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+     return x
+
+
+ def sequence_mask(length, max_length=None):
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+ def generate_path(duration, mask):
+     """
+     duration: [b, 1, t_x]
+     mask: [b, 1, t_y, t_x]
+     """
+     b, _, t_y, t_x = mask.shape
+     cum_duration = torch.cumsum(duration, -1)
+
+     cum_duration_flat = cum_duration.view(b * t_x)
+     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+     path = path.view(b, t_x, t_y)
+     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+     path = path.unsqueeze(1).transpose(2, 3) * mask
+     return path
+
+
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     parameters = list(filter(lambda p: p.grad is not None, parameters))
+     norm_type = float(norm_type)
+     if clip_value is not None:
+         clip_value = float(clip_value)
+
+     total_norm = 0
+     for p in parameters:
+         param_norm = p.grad.data.norm(norm_type)
+         total_norm += param_norm.item() ** norm_type
+         if clip_value is not None:
+             p.grad.data.clamp_(min=-clip_value, max=clip_value)
+     total_norm = total_norm ** (1.0 / norm_type)
+     return total_norm
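
Two of these helpers carry the inference-side logic used by app.py; a quick illustrative check of their behavior (the values follow directly from the definitions above):

```python
import torch
import commons

# intersperse() inserts a blank token around every symbol
# (used when hps.data.add_blank is true):
print(commons.intersperse([1, 2, 3], 0))
# [0, 1, 0, 2, 0, 3, 0]

# sequence_mask() builds a boolean mask from per-item lengths:
print(commons.sequence_mask(torch.tensor([2, 4]), max_length=5))
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True, False]])
```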
config.json ADDED
@@ -0,0 +1,62 @@
+ {
+   "train": {
+     "log_interval": 200,
+     "eval_interval": 6000,
+     "seed": 1234,
+     "epochs": 10000,
+     "learning_rate": 2e-4,
+     "betas": [0.8, 0.99],
+     "eps": 1e-9,
+     "batch_size": 14,
+     "fp16_run": false,
+     "lr_decay": 0.999875,
+     "segment_size": 8192,
+     "init_lr_ratio": 1,
+     "warmup_epochs": 0,
+     "c_mel": 45,
+     "c_kl": 1.0
+   },
+   "data": {
+     "use_mel_posterior_encoder": true,
+     "training_files": "/home/minhtu/projects/ViSV2TTS/DATA_thu_hue/train.txt",
+     "validation_files": "/home/minhtu/projects/ViSV2TTS/DATA_thu_hue/val.txt",
+     "text_cleaners": ["english_cleaners2"],
+     "max_wav_value": 32768.0,
+     "sampling_rate": 22050,
+     "filter_length": 1024,
+     "hop_length": 256,
+     "win_length": 1024,
+     "n_mel_channels": 80,
+     "mel_fmin": 0.0,
+     "mel_fmax": null,
+     "add_blank": false,
+     "n_speakers": 10,
+     "cleaned_text": true
+   },
+   "model": {
+     "use_mel_posterior_encoder": true,
+     "use_transformer_flows": true,
+     "transformer_flow_type": "pre_conv",
+     "use_spk_conditioned_encoder": true,
+     "use_noise_scaled_mas": true,
+     "use_duration_discriminator": true,
+     "inter_channels": 192,
+     "hidden_channels": 192,
+     "filter_channels": 768,
+     "n_heads": 2,
+     "n_layers": 6,
+     "kernel_size": 3,
+     "p_dropout": 0.1,
+     "resblock": "1",
+     "resblock_kernel_sizes": [3, 7, 11],
+     "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+     "upsample_rates": [8, 8, 2, 2],
+     "upsample_initial_channel": 512,
+     "upsample_kernel_sizes": [16, 16, 4, 4],
+     "n_layers_q": 3,
+     "use_spectral_norm": false,
+     "use_sdp": false,
+     "gin_channels": 256
+   }
+ }
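
app.py reads this file through `utils.get_hparams_from_file`, which wraps nested dicts in `HParams` objects for attribute access; a small sketch:

```python
import utils

hps = utils.get_hparams_from_file("config.json")
print(hps.data.sampling_rate)  # 22050
print(hps.data.add_blank)      # False -> get_text() skips intersperse()
print(hps.data.text_cleaners)  # ['english_cleaners2'] (unused by text_to_sequence)
```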
model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2337aaa17027eba030c8d37df85da9bf4bd9accf7efc61d006efc774e73086f3
+ size 123389245
preprocess.py ADDED
@@ -0,0 +1,15 @@
+ # import VietnameseTextNormalizer
+
+ def preprocess(text):
+     # text = VietnameseTextNormalizer.Normalize(text.lower()).split("\n")
+     lines = []
+     for t in text.lower().split("\n"):
+         # Map sentence-ending punctuation to " . " and commas to " , "
+         # so the phonemizer sees them as separate tokens.
+         t = t.replace(".", " . ")
+         t = t.replace("!", " . ")
+         t = t.replace("?", " . ")
+         t = t.replace(",", " , ")
+         t = t.strip()
+         lines.append(t)
+     return " . ".join(lines)
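
An illustrative run of `preprocess()`: sentence-ending punctuation collapses to " . ", and the doubled spaces this leaves behind are removed later by the whitespace collapse in text/__init__.py:

```python
from preprocess import preprocess

print(preprocess("Xin chào! Bạn khỏe không?"))
# 'xin chào .  bạn khỏe không .'
```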
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ eng_to_ipa==0.0.2
+ gradio==4.26.0
+ matplotlib==3.8.4
+ numpy==1.26.4
+ onnxruntime==1.17.3
+ scipy==1.13.0
+ torch==2.2.2
+ Unidecode==1.3.8
+ vinorm==2.0.7
+ underthesea==6.8.0
+ viphoneme==3.0.0
text/__init__.py ADDED
@@ -0,0 +1,72 @@
+ """ from https://github.com/keithito/tacotron """
+ import re
+
+ from text import cleaners
+ # from text.symbols import symbols
+ from viphoneme import syms, vi2IPA_split
+
+ symbols = syms
+
+ # Mappings from symbol to numeric ID and vice versa:
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+
+ def text_to_sequence(text, cleaner_names):
+     """Convert a string to a sequence of symbol IDs. Phonemization is done
+     by viphoneme, so cleaner_names is kept only for API compatibility."""
+     sequence = []
+     text = re.sub(r"\s+", " ", text).lower()
+     phon = vi2IPA_split(text, "/")
+     phon = phon.split("/")[1:]
+
+     # Find the last meaningful phoneme, then close the sequence with " ."
+     eol = -1
+     for i, p in reversed(list(enumerate(phon))):
+         if p not in ["..", "", " ", ".", " "]:
+             eol = i
+             break
+     phones = phon[:eol + 1] + [" ", "."]
+     phones_id = [_symbol_to_id[p] for p in phones if p in _symbol_to_id]
+     sequence.extend(phones_id)
+
+     return sequence
+
+
+ def cleaned_text_to_sequence(cleaned_text):
+     """Like text_to_sequence, but for text already phonemized into
+     "/"-separated viphoneme symbols."""
+     sequence = []
+     phon = cleaned_text.split("/")[1:]
+
+     eol = -1
+     for i, p in reversed(list(enumerate(phon))):
+         if p not in ["..", "", " ", ".", " "]:
+             eol = i
+             break
+     phones = phon[:eol + 1] + [" ", "."]
+     phones_id = [_symbol_to_id[p] for p in phones if p in _symbol_to_id]
+     sequence.extend(phones_id)
+
+     return sequence
+
+
+ def sequence_to_text(sequence):
+     """Convert a sequence of symbol IDs back to a string."""
+     result = ""
+     for symbol_id in sequence:
+         if symbol_id in _id_to_symbol:
+             result += _id_to_symbol[symbol_id]
+     return result
+
+
+ def _clean_text(text, cleaner_names):
+     for name in cleaner_names:
+         cleaner = getattr(cleaners, name, None)
+         if not cleaner:
+             raise Exception("Unknown cleaner: %s" % name)
+         text = cleaner(text)
+     return text
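
A hedged round-trip sketch: it requires the viphoneme package, and the exact IDs depend on viphoneme's symbol table (`syms`), so the printed values are illustrative only:

```python
from text import text_to_sequence, sequence_to_text

seq = text_to_sequence("xin chào", cleaner_names=[])  # cleaner_names is ignored
print(seq)                    # a list of ints indexing into syms
print(sequence_to_text(seq))  # IPA-like phoneme string ending in " ."
```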
text/cleaners.py ADDED
@@ -0,0 +1,114 @@
+ """ from https://github.com/keithito/tacotron """
+
+ """
+ Cleaners are transformations that run over the input text at both training and eval time.
+
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+   1. "english_cleaners" for English text
+   2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
+      the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+   3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
+      the symbols in symbols.py to match your data).
+ """
+
+ import re
+ from unidecode import unidecode
+ # The english_cleaners* pipelines below need phonemizer; the imports stay
+ # commented out because this app phonemizes via viphoneme instead (see
+ # text/__init__.py) and phonemizer is not listed in requirements.txt.
+ # from phonemizer import phonemize
+ # from phonemizer.backend import EspeakBackend
+ # backend = EspeakBackend("vi", preserve_punctuation=True, with_stress=True)
+
+
+ # Regular expression matching whitespace:
+ _whitespace_re = re.compile(r"\s+")
+
+ # List of (regular expression, replacement) pairs that read "digit + period"
+ # as the corresponding Vietnamese number word:
+ _abbreviations = [
+     (re.compile(r"\b%s\." % x[0], re.IGNORECASE), x[1])
+     for x in [
+         ("1", "một"),
+         ("2", "hai"),
+         ("3", "ba"),
+         ("4", "bốn"),
+         ("5", "năm"),
+         ("6", "sáu"),
+         ("7", "bảy"),
+         ("8", "tám"),
+         ("9", "chín"),
+         ("10", "mười"),
+     ]
+ ]
+
+
+ def expand_abbreviations(text):
+     for regex, replacement in _abbreviations:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def expand_numbers(text):
+     # NOTE: depends on a normalize_numbers helper from the original Tacotron
+     # repo that is not included here; unused by the cleaners below.
+     return normalize_numbers(text)
+
+
+ def lowercase(text):
+     return text.lower()
+
+
+ def collapse_whitespace(text):
+     return re.sub(_whitespace_re, " ", text)
+
+
+ def convert_to_ascii(text):
+     return unidecode(text)
+
+
+ def basic_cleaners(text):
+     """Basic pipeline that lowercases and collapses whitespace without transliteration."""
+     text = lowercase(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def transliteration_cleaners(text):
+     """Pipeline for non-English text that transliterates to ASCII."""
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def english_cleaners(text):
+     """Pipeline that expands abbreviations and phonemizes via espeak (here
+     configured for Vietnamese despite the inherited name). Requires the
+     commented-out phonemize import above."""
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = expand_abbreviations(text)
+     phonemes = phonemize(text, language="vi", backend="espeak", strip=True)
+     phonemes = collapse_whitespace(phonemes)
+     return phonemes
+
+
+ def english_cleaners2(text):
+     """Like english_cleaners, but also preserves punctuation and stress."""
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = expand_abbreviations(text)
+     phonemes = phonemize(
+         text,
+         language="vi",
+         backend="espeak",
+         strip=True,
+         preserve_punctuation=True,
+         with_stress=True,
+     )
+     phonemes = collapse_whitespace(phonemes)
+     return phonemes
+
+
+ def english_cleaners3(text):
+     """Like english_cleaners2, but uses the module-level EspeakBackend,
+     which must be uncommented above before this cleaner can run."""
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = expand_abbreviations(text)
+     phonemes = backend.phonemize([text], strip=True)[0]
+     phonemes = collapse_whitespace(phonemes)
+     return phonemes
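
The abbreviation table is the only cleaner logic specific to this repo; a quick illustrative check of the two helpers that need no external backend:

```python
from text.cleaners import expand_abbreviations, collapse_whitespace

print(expand_abbreviations("chương 1. và chương 2."))
# 'chương một và chương hai'
print(collapse_whitespace("xin   chào \t bạn"))
# 'xin chào bạn'
```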
text/symbols.py ADDED
@@ -0,0 +1,16 @@
+ """ from https://github.com/keithito/tacotron """
+
+ """
+ Defines the set of symbols used in text input to the model.
+ """
+ _pad = "_"
+ _punctuation = ';:,.!?¡¿—…"«»“” '
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+
+
+ # Export all symbols:
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+
+ # Special symbol ids
+ SPACE_ID = symbols.index(" ")
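
Note that this table is not what the app actually indexes against: text/__init__.py rebinds the package-level `symbols` to viphoneme's `syms`. A quick check of what this module itself exports (assumes viphoneme is installed, since importing the `text` package pulls it in):

```python
from text.symbols import symbols, SPACE_ID

print(len(symbols))             # pad + punctuation + Latin letters + IPA letters
print(repr(symbols[SPACE_ID]))  # ' '
```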
utils.py ADDED
@@ -0,0 +1,307 @@
+ import os
+ import glob
+ import sys
+ import argparse
+ import logging
+ import json
+ import subprocess
+ import numpy as np
+ from scipy.io.wavfile import read
+ import torch
+
+ MATPLOTLIB_FLAG = False
+
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+ logger = logging
+
+
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
+     assert os.path.isfile(checkpoint_path)
+     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+     iteration = checkpoint_dict["iteration"]
+     learning_rate = checkpoint_dict["learning_rate"]
+     if optimizer is not None:
+         optimizer.load_state_dict(checkpoint_dict["optimizer"])
+     saved_state_dict = checkpoint_dict["model"]
+     if hasattr(model, "module"):
+         state_dict = model.module.state_dict()
+     else:
+         state_dict = model.state_dict()
+     new_state_dict = {}
+     for k, v in state_dict.items():
+         try:
+             new_state_dict[k] = saved_state_dict[k]
+         except KeyError:
+             logger.info("%s is not in the checkpoint" % k)
+             new_state_dict[k] = v
+     if hasattr(model, "module"):
+         model.module.load_state_dict(new_state_dict)
+     else:
+         model.load_state_dict(new_state_dict)
+     logger.info(
+         "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
+     )
+     return model, optimizer, learning_rate, iteration
+
+
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
+     logger.info(
+         "Saving model and optimizer state at iteration {} to {}".format(
+             iteration, checkpoint_path
+         )
+     )
+     if hasattr(model, "module"):
+         state_dict = model.module.state_dict()
+     else:
+         state_dict = model.state_dict()
+     torch.save(
+         {
+             "model": state_dict,
+             "iteration": iteration,
+             "optimizer": optimizer.state_dict(),
+             "learning_rate": learning_rate,
+         },
+         checkpoint_path,
+     )
+
+
+ def summarize(
+     writer,
+     global_step,
+     scalars={},
+     histograms={},
+     images={},
+     audios={},
+     audio_sampling_rate=22050,
+ ):
+     for k, v in scalars.items():
+         writer.add_scalar(k, v, global_step)
+     for k, v in histograms.items():
+         writer.add_histogram(k, v, global_step)
+     for k, v in images.items():
+         writer.add_image(k, v, global_step, dataformats="HWC")
+     for k, v in audios.items():
+         writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+ def scan_checkpoint(dir_path, regex):
+     f_list = glob.glob(os.path.join(dir_path, regex))
+     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+     if len(f_list) == 0:
+         return None
+     return f_list
+
+
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
+     f_list = scan_checkpoint(dir_path, regex)
+     if not f_list:
+         return None
+     x = f_list[-1]
+     print(x)
+     return x
+
+
+ def remove_old_checkpoints(cp_dir, prefixes=["G_*.pth", "D_*.pth", "DUR_*.pth"]):
+     # Keep only the three most recent checkpoints for each prefix.
+     for prefix in prefixes:
+         sorted_ckpts = scan_checkpoint(cp_dir, prefix)
+         if sorted_ckpts and len(sorted_ckpts) > 3:
+             for ckpt_path in sorted_ckpts[:-3]:
+                 os.remove(ckpt_path)
+                 print("removed {}".format(ckpt_path))
+
+
+ def plot_spectrogram_to_numpy(spectrogram):
+     global MATPLOTLIB_FLAG
+     if not MATPLOTLIB_FLAG:
+         import matplotlib
+
+         matplotlib.use("Agg")
+         MATPLOTLIB_FLAG = True
+         mpl_logger = logging.getLogger("matplotlib")
+         mpl_logger.setLevel(logging.WARNING)
+     import matplotlib.pylab as plt
+
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+     plt.colorbar(im, ax=ax)
+     plt.xlabel("Frames")
+     plt.ylabel("Channels")
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     plt.close()
+     return data
+
+
+ def plot_alignment_to_numpy(alignment, info=None):
+     global MATPLOTLIB_FLAG
+     if not MATPLOTLIB_FLAG:
+         import matplotlib
+
+         matplotlib.use("Agg")
+         MATPLOTLIB_FLAG = True
+         mpl_logger = logging.getLogger("matplotlib")
+         mpl_logger.setLevel(logging.WARNING)
+     import matplotlib.pylab as plt
+
+     fig, ax = plt.subplots(figsize=(6, 4))
+     im = ax.imshow(
+         alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
+     )
+     fig.colorbar(im, ax=ax)
+     xlabel = "Decoder timestep"
+     if info is not None:
+         xlabel += "\n\n" + info
+     plt.xlabel(xlabel)
+     plt.ylabel("Encoder timestep")
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     plt.close()
+     return data
+
+
+ def load_wav_to_torch(full_path):
+     sampling_rate, data = read(full_path)
+     return torch.FloatTensor(data.astype(np.float32)), sampling_rate
+
+
+ def load_filepaths_and_text(filename, split="|"):
+     filepaths_and_text = []
+     with open(filename, encoding="utf-8") as f:
+         for line in f:
+             filepaths_and_text.append(line.strip().split(split))
+     return filepaths_and_text
+
+
+ def get_hparams(init=True):
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-c",
+         "--config",
+         type=str,
+         default="./configs/base.json",
+         help="JSON file for configuration",
+     )
+     parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
+
+     args = parser.parse_args()
+     model_dir = os.path.join("./logs", args.model)
+
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+
+     config_path = args.config
+     config_save_path = os.path.join(model_dir, "config.json")
+     if init:
+         with open(config_path, "r") as f:
+             data = f.read()
+         with open(config_save_path, "w") as f:
+             f.write(data)
+     else:
+         with open(config_save_path, "r") as f:
+             data = f.read()
+     config = json.loads(data)
+
+     hparams = HParams(**config)
+     hparams.model_dir = model_dir
+     return hparams
+
+
+ def get_hparams_from_dir(model_dir):
+     config_save_path = os.path.join(model_dir, "config.json")
+     with open(config_save_path, "r") as f:
+         data = f.read()
+     config = json.loads(data)
+
+     hparams = HParams(**config)
+     hparams.model_dir = model_dir
+     return hparams
+
+
+ def get_hparams_from_file(config_path):
+     with open(config_path, "r") as f:
+         data = f.read()
+     config = json.loads(data)
+
+     hparams = HParams(**config)
+     return hparams
+
+
+ def check_git_hash(model_dir):
+     source_dir = os.path.dirname(os.path.realpath(__file__))
+     if not os.path.exists(os.path.join(source_dir, ".git")):
+         logger.warning(
+             "{} is not a git repository, therefore hash value comparison will be ignored.".format(
+                 source_dir
+             )
+         )
+         return
+
+     cur_hash = subprocess.getoutput("git rev-parse HEAD")
+
+     path = os.path.join(model_dir, "githash")
+     if os.path.exists(path):
+         with open(path) as f:
+             saved_hash = f.read()
+         if saved_hash != cur_hash:
+             logger.warning(
+                 "git hash values are different. {}(saved) != {}(current)".format(
+                     saved_hash[:8], cur_hash[:8]
+                 )
+             )
+     else:
+         with open(path, "w") as f:
+             f.write(cur_hash)
+
+
+ def get_logger(model_dir, filename="train.log"):
+     global logger
+     logger = logging.getLogger(os.path.basename(model_dir))
+     logger.setLevel(logging.DEBUG)
+
+     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir)
+     h = logging.FileHandler(os.path.join(model_dir, filename))
+     h.setLevel(logging.DEBUG)
+     h.setFormatter(formatter)
+     logger.addHandler(h)
+     return logger
+
+
+ class HParams:
+     def __init__(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, dict):
+                 v = HParams(**v)
+             self[k] = v
+
+     def keys(self):
+         return self.__dict__.keys()
+
+     def items(self):
+         return self.__dict__.items()
+
+     def values(self):
+         return self.__dict__.values()
+
+     def __len__(self):
+         return len(self.__dict__)
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+     def __setitem__(self, key, value):
+         return setattr(self, key, value)
+
+     def __contains__(self, key):
+         return key in self.__dict__
+
+     def __repr__(self):
+         return self.__dict__.__repr__()
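
HParams supports both attribute- and dict-style access over nested configs; a small usage sketch of the class defined above:

```python
hps = HParams(**{"data": {"sampling_rate": 22050}, "train": {"epochs": 10000}})

print(hps.data.sampling_rate)  # 22050 -- nested dicts become nested HParams
print(hps["train"].epochs)     # 10000 -- __getitem__ delegates to getattr
print("data" in hps)           # True  -- __contains__ checks __dict__
```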