# musika_api/layers.py

import tensorflow as tf
import tensorflow.python.keras.backend as K
from tensorflow.python.eager import context
from tensorflow.python.ops import (
gen_math_ops,
math_ops,
sparse_ops,
standard_ops,
)
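

# Normalize a vector to unit L2 norm; eps guards against division by zero.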
def l2normalize(v, eps=1e-12):
return v / (tf.norm(v) + eps)
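

# Conv2D with spectral normalization (Miyato et al., 2018): the kernel is divided by an
# estimate of its largest singular value before each forward pass, with the estimate
# maintained by power iteration on the persistent vector `u`.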
class ConvSN2D(tf.keras.layers.Conv2D):
def __init__(self, filters, kernel_size, power_iterations=1, datatype=tf.float32, **kwargs):
super(ConvSN2D, self).__init__(filters, kernel_size, **kwargs)
self.power_iterations = power_iterations
self.datatype = datatype

    def build(self, input_shape):
super(ConvSN2D, self).build(input_shape)
if self.data_format == "channels_first":
channel_axis = 1
else:
channel_axis = -1
self.u = self.add_weight(
self.name + "_u",
shape=tuple([1, self.kernel.shape.as_list()[-1]]),
initializer=tf.initializers.RandomNormal(0, 1),
trainable=False,
dtype=self.dtype,
)
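
    # Run `power_iterations` rounds of power iteration to estimate the spectral norm
    # (largest singular value) sigma of the flattened kernel W, update the persistent
    # vector u, and return W / sigma reshaped back to the original kernel shape.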
def compute_spectral_norm(self, W, new_u, W_shape):
for _ in range(self.power_iterations):
new_v = l2normalize(tf.matmul(new_u, tf.transpose(W)))
new_u = l2normalize(tf.matmul(new_v, W))
sigma = tf.matmul(tf.matmul(new_v, W), tf.transpose(new_u))
W_bar = W / sigma
with tf.control_dependencies([self.u.assign(new_u)]):
W_bar = tf.reshape(W_bar, W_shape)
return W_bar
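
    # Standard Conv2D forward pass, but with the spectrally normalized kernel; relies on
    # the parent layer's private `_convolution_op` (TF/Keras version dependent).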
def call(self, inputs):
W_shape = self.kernel.shape.as_list()
W_reshaped = tf.reshape(self.kernel, (-1, W_shape[-1]))
new_kernel = self.compute_spectral_norm(W_reshaped, self.u, W_shape)
outputs = self._convolution_op(inputs, new_kernel)
if self.use_bias:
if self.data_format == "channels_first":
outputs = tf.nn.bias_add(outputs, self.bias, data_format="NCHW")
else:
outputs = tf.nn.bias_add(outputs, self.bias, data_format="NHWC")
if self.activation is not None:
return self.activation(outputs)
return outputs
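

# Dense (fully connected) layer with the same spectral-normalization scheme,
# using a single power-iteration step per call.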
class DenseSN(tf.keras.layers.Dense):
def __init__(self, datatype=tf.float32, **kwargs):
super(DenseSN, self).__init__(**kwargs)
self.datatype = datatype

    def build(self, input_shape):
super(DenseSN, self).build(input_shape)
self.u = self.add_weight(
self.name + "_u",
shape=tuple([1, self.kernel.shape.as_list()[-1]]),
initializer=tf.initializers.RandomNormal(0, 1),
trainable=False,
dtype=self.datatype,
)
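
    # Single power-iteration step: update u, compute sigma, and return the
    # normalized kernel W / sigma reshaped to W_shape.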
def compute_spectral_norm(self, W, new_u, W_shape):
new_v = l2normalize(tf.matmul(new_u, tf.transpose(W)))
new_u = l2normalize(tf.matmul(new_v, W))
sigma = tf.matmul(tf.matmul(new_v, W), tf.transpose(new_u))
W_bar = W / sigma
with tf.control_dependencies([self.u.assign(new_u)]):
W_bar = tf.reshape(W_bar, W_shape)
return W_bar
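
    # Mirrors tf.keras.layers.Dense.call, but with the spectrally normalized kernel;
    # rank > 2 inputs go through tensordot and sparse inputs through a sparse matmul.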
def call(self, inputs):
W_shape = self.kernel.shape.as_list()
W_reshaped = tf.reshape(self.kernel, (-1, W_shape[-1]))
new_kernel = self.compute_spectral_norm(W_reshaped, self.u, W_shape)
rank = len(inputs.shape)
if rank > 2:
outputs = standard_ops.tensordot(inputs, new_kernel, [[rank - 1], [0]])
if not context.executing_eagerly():
shape = inputs.shape.as_list()
output_shape = shape[:-1] + [self.units]
outputs.set_shape(output_shape)
else:
inputs = math_ops.cast(inputs, self._compute_dtype)
if K.is_sparse(inputs):
outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, new_kernel)
else:
outputs = gen_math_ops.mat_mul(inputs, new_kernel)
if self.use_bias:
outputs = tf.nn.bias_add(outputs, self.bias)
if self.activation is not None:
return self.activation(outputs)
return outputs
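

# Adds Gaussian noise, scaled by a single trainable scalar, to the input. The noise is
# sampled per call with shape [batch, H, W, 1] and broadcast over channels (in the
# spirit of StyleGAN-style noise injection).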
class AddNoise(tf.keras.layers.Layer):
def __init__(self, datatype=tf.float32, **kwargs):
super(AddNoise, self).__init__(**kwargs)
self.datatype = datatype

    def build(self, input_shape):
self.b = self.add_weight(
shape=[
1,
],
initializer=tf.keras.initializers.zeros(),
trainable=True,
name="noise_weight",
)

    def call(self, inputs):
rand = tf.random.normal(
[tf.shape(inputs)[0], inputs.shape[1], inputs.shape[2], 1],
mean=0.0,
stddev=1.0,
dtype=self.datatype,
)
output = inputs + self.b * rand
return output
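

# Appends a positional-encoding channel: the index along axis -3, normalized to [0, 1),
# broadcast over the remaining axes and concatenated to the input on the channel axis.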
class PosEnc(tf.keras.layers.Layer):
def __init__(self, datatype=tf.float32, **kwargs):
super(PosEnc, self).__init__(**kwargs)
self.datatype = datatype

    def call(self, inputs):
pos = tf.repeat(
tf.reshape(tf.range(inputs.shape[-3], dtype=tf.int32), [1, -1, 1, 1]),
inputs.shape[-2],
-2,
)
        pos = tf.cast(tf.repeat(pos, tf.shape(inputs)[0], 0), self.datatype) / tf.cast(inputs.shape[-3], self.datatype)
return tf.concat([inputs, pos], -1) # [bs,1,hop,2]
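

# Collapse the spatial (H, W) axes of a feature map into a single axis,
# converting channels_last input to channels_first first.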
def flatten_hw(x, data_format="channels_last"):
if data_format == "channels_last":
x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first`
old_shape = tf.shape(x)