import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops import math_ops


class Linear(keras.layers.Layer):
    def __init__(self, units, use_bias, **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
        self.activation = keras.layers.ReLU()

    def call(self, x):
        """
        shapes:
            x: B x T x C
        """
        return self.activation(self.linear_layer(x))


class LinearBN(keras.layers.Layer):
    def __init__(self, units, use_bias, **kwargs):
        super(LinearBN, self).__init__(**kwargs)
        self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
        self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization')
        self.activation = keras.layers.ReLU()

    def call(self, x, training=None):
        """
        shapes:
            x: B x T x C
        """
        out = self.linear_layer(x)
        out = self.batch_normalization(out, training=training)
        return self.activation(out)


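# Illustrative usage sketch (not part of the original module): applying Linear and
# LinearBN to a batch of B x T x C features. The input shape and layer sizes below
# are assumptions chosen for demonstration only.
def _example_linear_layers():
    x = tf.random.normal([8, 50, 80])                    # B=8, T=50, C=80 (assumed)
    y = Linear(256, use_bias=True)(x)                    # -> [8, 50, 256]
    z = LinearBN(256, use_bias=True)(x, training=True)   # -> [8, 50, 256], BN in training mode
    return y, z

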
class Prenet(keras.layers.Layer):
    def __init__(self,
                 prenet_type,
                 prenet_dropout,
                 units,
                 bias,
                 **kwargs):
        super(Prenet, self).__init__(**kwargs)
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        self.linear_layers = []
        if prenet_type == "bn":
            self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
        elif prenet_type == "original":
            self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
        else:
            raise RuntimeError(' [!] Unknown prenet type.')
        if prenet_dropout:
            self.dropout = keras.layers.Dropout(rate=0.5)

    def call(self, x, training=None):
        """
        shapes:
            x: B x T x C
        """
        for linear in self.linear_layers:
            if self.prenet_dropout:
                x = self.dropout(linear(x), training=training)
            else:
                x = linear(x)
        return x


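# Illustrative usage sketch (assumptions only): building a Prenet and running it on
# dummy spectrogram frames. The unit sizes and the B x T x C input shape are
# hypothetical values, not taken from the original code.
def _example_prenet():
    prenet = Prenet(prenet_type='original', prenet_dropout=True, units=[256, 256], bias=False)
    frames = tf.random.normal([8, 50, 80])   # B x T x C (assumed)
    return prenet(frames, training=True)     # -> [8, 50, 256]

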
def _sigmoid_norm(score):
    attn_weights = tf.nn.sigmoid(score)
    attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True)
    return attn_weights


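# Illustrative sketch (assumed shapes): sigmoid-normalized scores sum to one over the
# time axis, like softmax, but each score is squashed independently before renormalization.
def _example_sigmoid_norm():
    scores = tf.random.normal([8, 50])        # B x T_enc (assumed)
    weights = _sigmoid_norm(scores)
    return tf.reduce_sum(weights, axis=1)     # ~1.0 for every batch entry

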
class Attention(keras.layers.Layer):
    """Additive (Bahdanau-style) attention with optional location features and forward attention.

    TODO: implement attention windowing
    TODO: implement the forward attention transition agent and alignment mask
    """
    def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
                 loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
                 use_trans_agent, use_forward_attn_mask, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.use_loc_attn = use_loc_attn
        self.loc_attn_n_filters = loc_attn_n_filters
        self.loc_attn_kernel_size = loc_attn_kernel_size
        self.use_windowing = use_windowing
        self.norm = norm
        self.use_forward_attn = use_forward_attn
        self.use_trans_agent = use_trans_agent
        self.use_forward_attn_mask = use_forward_attn_mask
        self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer')
        self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer')
        self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer')
        if use_loc_attn:
            self.location_conv1d = keras.layers.Conv1D(
                filters=loc_attn_n_filters,
                kernel_size=loc_attn_kernel_size,
                padding='same',
                use_bias=False,
                name='location_layer/location_conv1d')
            self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense')
        if norm == 'softmax':
            self.norm_func = tf.nn.softmax
        elif norm == 'sigmoid':
            self.norm_func = _sigmoid_norm
        else:
            raise ValueError("Unknown value for attention norm type")

    def init_states(self, batch_size, value_length):
        states = []
        if self.use_loc_attn:
            attention_cum = tf.zeros([batch_size, value_length])
            attention_old = tf.zeros([batch_size, value_length])
            states = [attention_cum, attention_old]
        if self.use_forward_attn:
            alpha = tf.concat([
                tf.ones([batch_size, 1]),
                tf.zeros([batch_size, value_length])[:, :-1] + 1e-7
            ], 1)
            states.append(alpha)
        return tuple(states)

    def process_values(self, values):
        """ cache values for decoder iterations """
        self.processed_values = self.inputs_layer(values)
        self.values = values

    def get_loc_attn(self, query, states):
        """ compute the location-sensitive attention score and the processed query,
        using the cumulative and previous attention weights as location features """
        attention_cum, attention_old = states[:2]
        attn_cat = tf.stack([attention_old, attention_cum], axis=2)

        processed_query = self.query_layer(tf.expand_dims(query, 1))
        processed_attn = self.location_dense(self.location_conv1d(attn_cat))
        score = self.v(
            tf.nn.tanh(self.processed_values + processed_query +
                       processed_attn))
        score = tf.squeeze(score, axis=2)
        return score, processed_query

    def get_attn(self, query):
        """ compute the processed query and the unnormalized attention score """
        processed_query = self.query_layer(tf.expand_dims(query, 1))
        score = self.v(tf.nn.tanh(self.processed_values + processed_query))
        score = tf.squeeze(score, axis=2)
        return score, processed_query

    def apply_score_masking(self, score, mask):
        """ ignore sequence paddings """
        padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
        # push padded positions to a very large negative value before normalization
        score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
        return score

    def apply_forward_attention(self, alignment, alpha):
        # shift the previous alpha one step forward along the value axis
        fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
        # mix the previous and shifted alphas equally, then re-weight by the current alignment
        new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
        # renormalize so the forward-attention weights sum to one over the value axis
        new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
        return new_alpha

    def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
        states = []
        if self.use_loc_attn:
            # accumulate normalized scores and keep the latest attention weights
            states = [old_states[0] + scores_norm, attn_weights]
        if self.use_forward_attn:
            states.append(new_alpha)
        return tuple(states)

    def call(self, query, states):
        """
        shapes:
            query: B x D
        """
        # compute the unnormalized attention score
        if self.use_loc_attn:
            score, _ = self.get_loc_attn(query, states)
        else:
            score, _ = self.get_attn(query)

        # normalize the score (softmax or sigmoid) into attention weights
        scores_norm = self.norm_func(score)
        attn_weights = scores_norm

        # optionally apply forward attention on top of the normalized weights
        new_alpha = None
        if self.use_forward_attn:
            new_alpha = self.apply_forward_attention(attn_weights, states[-1])
            attn_weights = new_alpha

        # update recurrent attention states for the next decoder step
        states = self.update_states(states, scores_norm, attn_weights, new_alpha)

        # weighted sum over the cached values -> context vector (B x D_value)
        context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
        context_vector = tf.squeeze(context_vector, axis=1)
        return context_vector, attn_weights, states
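

# Illustrative usage sketch (assumptions throughout): one way the Attention layer
# might be driven from a decoder loop. The encoder output shape, query size, and
# hyper-parameters below are made up for demonstration; the real decoder wiring
# lives elsewhere in the codebase.
def _example_attention_step():
    attn = Attention(attn_dim=128,
                     use_loc_attn=True,
                     loc_attn_n_filters=32,
                     loc_attn_kernel_size=31,
                     use_windowing=False,
                     norm='softmax',
                     use_forward_attn=False,
                     use_trans_agent=False,
                     use_forward_attn_mask=False)
    encoder_outputs = tf.random.normal([8, 50, 512])        # B x T_enc x D_enc (assumed)
    attn.process_values(encoder_outputs)                    # cache value projections once
    states = attn.init_states(batch_size=8, value_length=50)
    query = tf.random.normal([8, 1024])                     # B x D_query (assumed)
    context, weights, states = attn(query, states)          # context: [8, 512], weights: [8, 50]
    return context, weights, states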