| import tensorflow as tf |
| import numpy as np |
|
|
class LocationSensitiveAttention(tf.keras.layers.Layer):
    """Additive attention whose energies also see the previous alignment.

    Energies are ``v^T tanh(W_q q + W_m m + W_l conv(prev_w) + b)`` computed for
    every memory frame, then normalised with a softmax over the memory time
    axis. An optional ``memory_mask`` removes padded frames from the softmax.

    Args:
        attn_dim: Size of the shared attention projection space.
        attn_filters: Number of filters in the location convolution.
        attn_kernel: Kernel width of the location convolution ('same' padding).
    """

    def __init__(self, attn_dim=128, attn_filters=32, attn_kernel=31, **kw):
        super().__init__(**kw)
        self.attn_dim = attn_dim
        self.attn_filters = attn_filters
        self.attn_kernel = attn_kernel

    def build(self, input_shape):
        # Sub-layers are created here (as in the original) so their creation
        # order, and therefore Keras' auto-generated names, stay unchanged.
        d = self.attn_dim
        self.W_query = tf.keras.layers.Dense(d, use_bias=False)
        self.W_memory = tf.keras.layers.Dense(d, use_bias=False)
        self.loc_conv = tf.keras.layers.Conv1D(
            self.attn_filters, self.attn_kernel, padding='same', use_bias=False)
        self.W_loc = tf.keras.layers.Dense(d, use_bias=False)
        self.v = tf.keras.layers.Dense(1, use_bias=False)
        self.b = self.add_weight(name='attn_bias', shape=[d],
                                 initializer='zeros', trainable=True)
        super().build(input_shape)

    def call(self, query, memory, prev_weights, memory_mask=None):
        """Compute one attention step.

        Args:
            query: Decoder output for this step. It is expanded on axis 1 so it
                broadcasts over memory frames (presumably (batch, query_dim) —
                confirm against caller).
            memory: Encoder outputs; the context sums over axis 1, so this is
                (batch, time, mem_dim).
            prev_weights: Alignment from the previous step, (batch, time).
            memory_mask: Optional (batch, time) boolean or 0/1 tensor. Frames
                where it is False/0 get zero attention weight. ``None`` (the
                default) attends to every frame, exactly as before.

        Returns:
            Tuple ``(context, weights)``: the weighted sum of ``memory`` over
            time and the softmax weights of shape (batch, time).
        """
        q = self.W_query(tf.expand_dims(query, 1))
        m = self.W_memory(memory)
        loc = self.W_loc(self.loc_conv(tf.expand_dims(prev_weights, -1)))
        e = tf.squeeze(self.v(tf.nn.tanh(q + m + loc + self.b)), -1)
        if memory_mask is not None:
            # Use the dtype's most negative finite value rather than -inf, so a
            # fully masked row still yields finite (uniform) weights, not NaN.
            fill = tf.fill(tf.shape(e), tf.constant(e.dtype.min, dtype=e.dtype))
            e = tf.where(tf.cast(memory_mask, tf.bool), e, fill)
        w = tf.nn.softmax(e, axis=-1)
        ctx = tf.reduce_sum(tf.expand_dims(w, -1) * memory, 1)
        return ctx, w
|
|
class Prenet(tf.keras.layers.Layer):
    """Two ReLU dense layers, each followed by 50% dropout that is always on.

    Dropout is invoked with ``training=True`` unconditionally, so it remains
    active during inference too. Any extra keyword the caller passes
    (e.g. ``training``) is accepted and ignored.

    Args:
        units: Width of both dense layers.
    """

    def __init__(self, units=256, **kw):
        super().__init__(**kw)
        self.fc1 = tf.keras.layers.Dense(units, activation='relu')
        self.fc2 = tf.keras.layers.Dense(units, activation='relu')
        self.drop1 = tf.keras.layers.Dropout(0.5)
        self.drop2 = tf.keras.layers.Dropout(0.5)

    def call(self, x, **_):
        stages = ((self.fc1, self.drop1), (self.fc2, self.drop2))
        for dense, dropout in stages:
            x = dropout(dense(x), training=True)
        return x
|
|
class ConvBN(tf.keras.layers.Layer):
    """Conv1D ('same' padding) -> BatchNorm -> ReLU -> Dropout.

    Args:
        filters: Output channels of the convolution.
        kernel: Convolution kernel width.
        drop: Dropout rate.
    """

    def __init__(self, filters, kernel=5, drop=0.5, **kw):
        super().__init__(**kw)
        self.conv = tf.keras.layers.Conv1D(filters, kernel, padding='same')
        self.bn = tf.keras.layers.BatchNormalization()
        self.dropout = tf.keras.layers.Dropout(drop)

    def call(self, x, training=False):
        features = self.conv(x)
        normed = self.bn(features, training=training)
        activated = tf.nn.relu(normed)
        return self.dropout(activated, training=training)
|
|
class Encoder(tf.keras.layers.Layer):
    """Token encoder: embedding, a stack of ConvBN blocks, then a BiLSTM.

    The forward and backward LSTM outputs are concatenated, so the output
    feature size is ``2 * (enc_dim // 2)``.

    Args:
        vocab_size: Number of token ids accepted by the embedding.
        emb_dim: Embedding size, also used as the ConvBN filter count.
        enc_dim: Total BiLSTM output size (split across both directions).
        n_conv: Number of ConvBN blocks.
        conv_k: ConvBN kernel width.
        drop: ConvBN dropout rate.
    """

    def __init__(self, vocab_size, emb_dim=512, enc_dim=512,
                 n_conv=3, conv_k=5, drop=0.5, **kw):
        super().__init__(**kw)
        self.emb = tf.keras.layers.Embedding(vocab_size, emb_dim)
        self.convs = [ConvBN(emb_dim, conv_k, drop) for _ in range(n_conv)]
        self.bilstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(enc_dim // 2, return_sequences=True),
            merge_mode='concat')

    def call(self, x, training=False):
        hidden = self.emb(x)
        for conv_block in self.convs:
            hidden = conv_block(hidden, training=training)
        return self.bilstm(hidden)
|
|
class PostNet(tf.keras.layers.Layer):
    """Residual convolutional refinement of a mel spectrogram.

    The network has ``n_layers`` Conv1D + BatchNorm + Dropout stages. Every
    stage but the last applies tanh after BatchNorm. The last stage projects
    back to ``n_mels`` channels so its output can be added to the input.

    Args:
        n_mels: Channels of the input/output spectrogram.
        dim: Channels of the intermediate stages.
        n_layers: Number of stages.
        k: Convolution kernel width.
        drop: Dropout rate.
    """

    def __init__(self, n_mels=80, dim=512, n_layers=5, k=5, drop=0.5, **kw):
        super().__init__(**kw)
        # Each entry is (conv, batch_norm, dropout, is_last_stage).
        self.layers_list = []
        final_idx = n_layers - 1
        for idx in range(n_layers):
            is_final = idx == final_idx
            channels = n_mels if is_final else dim
            conv = tf.keras.layers.Conv1D(channels, k, padding='same')
            norm = tf.keras.layers.BatchNormalization()
            dropout = tf.keras.layers.Dropout(drop)
            self.layers_list.append((conv, norm, dropout, is_final))

    def call(self, x, training=False):
        residual = x
        for conv, norm, dropout, is_final in self.layers_list:
            residual = norm(conv(residual), training=training)
            if not is_final:
                residual = tf.nn.tanh(residual)
            residual = dropout(residual, training=training)
        return x + residual
|
|
class Decoder(tf.keras.layers.Layer):
    """Autoregressive mel decoder with location-sensitive attention.

    Each step does the following:
    - feeds the previous mel frame through the prenet;
    - concatenates the previous attention context;
    - runs two stacked LSTM cells;
    - attends over the encoder outputs with the second cell's output;
    - projects ``[h2, context]`` to one mel frame and one stop-token logit.

    Args:
        n_mels: Mel channels per output frame.
        dec_dim: Units of each LSTM cell and of ``attn_proj``.
        attn_dim: Attention projection size.
        prenet_dim: Prenet width.
        max_steps: Maximum number of steps when not teacher forcing.
    """

    def __init__(self, n_mels=80, dec_dim=1024, attn_dim=128,
                 prenet_dim=256, max_steps=1000, **kw):
        super().__init__(**kw)
        self.n_mels = n_mels
        self.dec_dim = dec_dim
        self.max_steps = max_steps
        self.prenet = Prenet(prenet_dim)
        self.attention = LocationSensitiveAttention(attn_dim)
        self.lstm1 = tf.keras.layers.LSTMCell(dec_dim)
        self.lstm2 = tf.keras.layers.LSTMCell(dec_dim)
        self.mel_proj = tf.keras.layers.Dense(n_mels)
        self.stop_proj = tf.keras.layers.Dense(1)
        self.attn_proj = tf.keras.layers.Dense(dec_dim)

    def _init_state(self, enc):
        """Return an all-zero decoder state sized from the encoder output."""
        batch = tf.shape(enc)[0]
        steps = tf.shape(enc)[1]
        # Prefer the static channel size. This gives the context (and so the
        # LSTM input) a known last dimension when traced in tf.function. Fall
        # back to the dynamic size only when it is genuinely unknown.
        enc_dim = enc.shape[-1]
        if enc_dim is None:
            enc_dim = tf.shape(enc)[2]
        return {
            'attn_w': tf.zeros([batch, steps]),
            's1': [tf.zeros([batch, self.dec_dim]), tf.zeros([batch, self.dec_dim])],
            's2': [tf.zeros([batch, self.dec_dim]), tf.zeros([batch, self.dec_dim])],
            'ctx': tf.zeros([batch, enc_dim]),
            'mel': tf.zeros([batch, self.n_mels]),
        }

    def _step(self, inp, enc, st):
        """Run one decoder step.

        Returns ``(mel, stop_logit, attn_weights, new_state)``.
        """
        p = self.prenet(inp)
        x = tf.concat([p, st['ctx']], -1)
        h1, s1 = self.lstm1(x, st['s1'])
        h2, s2 = self.lstm2(h1, st['s2'])
        ctx, w = self.attention(h2, enc, st['attn_w'])
        out = self.attn_proj(tf.concat([h2, ctx], -1))
        mel = self.mel_proj(out)
        stop = self.stop_proj(out)
        return mel, stop, w, {'attn_w': w, 's1': s1, 's2': s2, 'ctx': ctx, 'mel': mel}

    def call(self, enc, mel_tgt=None, training=False):
        """Decode mel frames from encoder outputs.

        Teacher forcing is used only when ``mel_tgt`` is given AND
        ``training`` is truthy. It produces exactly ``mel_tgt.shape[1]``
        frames. Otherwise the decoder feeds back its own predictions for up
        to ``max_steps`` frames. It stops early once every batch item's stop
        probability exceeds 0.5, but only when the step index ``t > 10``.

        Returns:
            ``(mels, stop_logits, attention)`` stacked along axis 1.
        """
        st = self._init_state(enc)
        mels, stops, attns = [], [], []

        use_teacher = (mel_tgt is not None) and training

        if use_teacher:
            num_frames = mel_tgt.shape[1]
            for t in range(num_frames):
                # Step 0 consumes the all-zero "go" frame. Step t > 0 consumes
                # the ground-truth frame t - 1.
                inp = st['mel'] if t == 0 else mel_tgt[:, t - 1, :]
                m, s, w, st = self._step(inp, enc, st)
                mels.append(m)
                stops.append(s)
                attns.append(w)
        else:
            for t in range(self.max_steps):
                m, s, w, st = self._step(st['mel'], enc, st)
                mels.append(m)
                stops.append(s)
                attns.append(w)
                # Stop only when EVERY batch item predicts "stop". Averaging
                # the probabilities (as before) could end decoding while some
                # items were still unfinished. For batch size 1 this is the same.
                if t > 10 and tf.reduce_all(tf.sigmoid(s) > 0.5):
                    break

        return tf.stack(mels, 1), tf.stack(stops, 1), tf.stack(attns, 1)
|
|
class TeraVO(tf.keras.Model):
    """Multi-voice text-to-mel model: encoder + voice embedding + decoder + postnet.

    A learned per-voice embedding, projected to ``enc_dim``, is added to every
    encoder frame before decoding.
    """

    def __init__(self, vocab_size, n_mels=80, emb_dim=256, enc_dim=256,
                 dec_dim=512, attn_dim=128, prenet_dim=128, postnet_dim=256,
                 n_conv=3, num_voices=3, voice_emb_dim=64, max_steps=800, **kw):
        super().__init__(**kw)
        self.voice_emb = tf.keras.layers.Embedding(num_voices, voice_emb_dim)
        self.voice_proj = tf.keras.layers.Dense(enc_dim)
        self.encoder = Encoder(vocab_size, emb_dim, enc_dim, n_conv)
        self.decoder = Decoder(n_mels, dec_dim, attn_dim, prenet_dim, max_steps)
        self.postnet = PostNet(n_mels, postnet_dim)

    def call(self, inputs, training=False):
        """Run the full model.

        Args:
            inputs: Mapping with:
                - 'text': token ids for the encoder;
                - 'voice_id': one voice id per batch item, shape (batch,) or
                  (batch, 1);
                - optional 'mel_target': teacher-forcing targets, used by the
                  decoder only when ``training`` is truthy.
            training: Forwarded to the encoder, decoder and postnet.

        Returns:
            Dict with 'mel_outputs', 'mel_outputs_postnet', 'stop_tokens'
            and 'attention_weights'.
        """
        text = inputs['text']
        # Flatten so (batch,) and (batch, 1) ids both give a (batch, D) voice
        # vector. A (batch, 1) id would otherwise become (batch, 1, 1, D) and
        # silently broadcast the encoder output to rank 4.
        vid = tf.reshape(inputs['voice_id'], [-1])
        mel_tgt = inputs.get('mel_target', None)
        enc = self.encoder(text, training=training)
        ve = self.voice_proj(self.voice_emb(vid))
        enc = enc + tf.expand_dims(ve, 1)
        mel, stop, attn = self.decoder(enc, mel_tgt, training=training)
        mel_post = self.postnet(mel, training=training)
        return {'mel_outputs': mel, 'mel_outputs_postnet': mel_post,
                'stop_tokens': stop, 'attention_weights': attn}
|
|
def create_model(vocab_size, n_mels=80, num_voices=3):
    """Construct a TeraVO model; all other hyper-parameters keep their defaults.

    Args:
        vocab_size: Number of input token ids.
        n_mels: Mel channels per output frame.
        num_voices: Number of distinct voice ids.
    """
    model = TeraVO(vocab_size, n_mels=n_mels, num_voices=num_voices)
    return model
|
|