""" |
|
WaveGRU model: melspectrogram => mu-law encoded waveform |
|
""" |
|
|
|
from typing import Tuple |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import pax |
|
from pax import GRUState |
|
from tqdm.cli import tqdm |
|
|
|
|
|
class ReLU(pax.Module):
    """A thin module wrapper around jax.nn.relu, usable inside pax.Sequential."""

    def __call__(self, x):
        return jax.nn.relu(x)


def dilated_residual_conv_block(dim, kernel, stride, dilation):
    """
    Use dilated convolutions to enlarge the receptive field.
    """
    return pax.Sequential(
        pax.Conv1D(dim, dim, kernel, stride, dilation, "VALID", with_bias=False),
        pax.LayerNorm(dim, -1, True, True),
        ReLU(),
        pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False),
        pax.LayerNorm(dim, -1, True, True),
        ReLU(),
    )


def tile_1d(x, factor):
    """
    Tile a tensor of shape (N, L, D) into shape (N, L * factor, D).
    """
    N, L, D = x.shape
    x = x[:, :, None, :]
    x = jnp.tile(x, (1, 1, factor, 1))
    x = jnp.reshape(x, (N, L * factor, D))
    return x


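# For illustration, tile_1d repeats each time step `factor` times in place
# (hypothetical shapes):
#
#     x = jnp.zeros((2, 10, 4))    # (N, L, D)
#     y = tile_1d(x, factor=3)     # y.shape == (2, 30, 4)

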
def up_block(in_dim, out_dim, factor, relu=True):
    """
    Tile >> Conv >> LayerNorm >> ReLU
    """
    f = pax.Sequential(
        lambda x: tile_1d(x, factor),
        pax.Conv1D(
            in_dim, out_dim, 2 * factor, stride=1, padding="VALID", with_bias=False
        ),
        pax.LayerNorm(out_dim, -1, True, True),
    )
    if relu:
        f >>= ReLU()
    return f


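# Length sketch (for illustration): with input length L, tiling yields
# L * factor frames, and the VALID conv with kernel size 2 * factor keeps
# L * factor - (2 * factor - 1) of them, i.e. roughly factor-times upsampling
# with a small edge trim that the callers compensate for by cropping.

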
class Upsample(pax.Module):
    """
    Upsample melspectrogram to match raw audio sample rate.
    """

    def __init__(
        self, input_dim, hidden_dim, rnn_dim, upsample_factors, has_linear_output=False
    ):
        super().__init__()
        self.input_conv = pax.Sequential(
            pax.Conv1D(input_dim, hidden_dim, 1, with_bias=False),
            pax.LayerNorm(hidden_dim, -1, True, True),
        )
        self.upsample_factors = upsample_factors
        self.dilated_convs = [
            dilated_residual_conv_block(hidden_dim, 3, 1, 2**i) for i in range(5)
        ]
        self.up_factors = upsample_factors[:-1]
        self.up_blocks = [
            up_block(hidden_dim, hidden_dim, x) for x in self.up_factors[:-1]
        ]
        # the last up block has no ReLU and, without a linear output head,
        # projects directly to the GRU's (z, r, h) input size
        self.up_blocks.append(
            up_block(
                hidden_dim,
                hidden_dim if has_linear_output else 3 * rnn_dim,
                self.up_factors[-1],
                relu=False,
            )
        )
        if has_linear_output:
            self.x2zrh_fc = pax.Linear(hidden_dim, rnn_dim * 3)
        self.has_linear_output = has_linear_output

        self.final_tile = upsample_factors[-1]

    def __call__(self, x, no_repeat=False):
        x = self.input_conv(x)
        for residual in self.dilated_convs:
            y = residual(x)
            # VALID convolutions shorten y, so crop x to match before adding
            pad = (x.shape[1] - y.shape[1]) // 2
            x = x[:, pad:-pad, :] + y

        for f in self.up_blocks:
            x = f(x)

        if self.has_linear_output:
            x = self.x2zrh_fc(x)

        if no_repeat:
            return x
        x = tile_1d(x, self.final_tile)
        return x


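# Rate sketch (assuming the default upsample_factors=(5, 3, 20) used by
# WaveGRU below): the up blocks upsample by 5 * 3 = 15 (up to edge trims from
# the VALID convolutions) and the final tiling repeats each frame 20 more
# times, i.e. 5 * 3 * 20 = 300 output frames per mel frame, matching a hop
# length of 300 samples.

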
class GRU(pax.Module):
    """
    A customized GRU module.

    The input is expected to be pre-projected to the (z, r, h) components,
    so only the hidden-to-hidden projection is computed inside the cell.
    """

    hidden_dim: int

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.h_zrh_fc = pax.Linear(
            hidden_dim,
            hidden_dim * 3,
            w_init=jax.nn.initializers.variance_scaling(
                1, "fan_out", "truncated_normal"
            ),
        )

    def initial_state(self, batch_size: int) -> GRUState:
        """Create an all-zeros initial state."""
        return GRUState(jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32))

    def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]:
        hidden = state.hidden
        x_zrh = x  # input is already the projected (z, r, h) components
        h_zrh = self.h_zrh_fc(hidden)
        x_zr, x_h = jnp.split(x_zrh, [2 * self.hidden_dim], axis=-1)
        h_zr, h_h = jnp.split(h_zrh, [2 * self.hidden_dim], axis=-1)

        # update gate z and reset gate r
        zr = x_zr + h_zr
        zr = jax.nn.sigmoid(zr)
        z, r = jnp.split(zr, 2, axis=-1)

        # candidate state
        h_hat = x_h + r * h_h
        h_hat = jnp.tanh(h_hat)

        # here z interpolates toward the candidate state
        h = (1 - z) * hidden + z * h_hat
        return GRUState(h), h


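# In the usual GRU notation, the recurrence above is:
#
#     z     = sigmoid(x_z + U_z h)        (update gate)
#     r     = sigmoid(x_r + U_r h)        (reset gate)
#     h_hat = tanh(x_h + r * (U_h h))     (candidate state)
#     h'    = (1 - z) * h + z * h_hat
#
# with the input projections x_z, x_r, x_h computed outside the cell
# (by the embedding and/or the upsampling network).

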
class Pruner(pax.Module):
    """
    Base class for pruners
    """

    def compute_sparsity(self, step):
        # cubic schedule: 0% sparsity at step 1,000, rising to 95% at step 201,000
        t = jnp.power(1 - (step * 1.0 - 1_000) / 200_000, 3)
        z = 0.95 * jnp.clip(1.0 - t, a_min=0, a_max=1)
        return z

    def prune(self, step, weights):
        """
        Return a boolean mask of the weights to keep.
        """
        z = self.compute_sparsity(step)
        x = weights
        H, W = x.shape
        # score 4x4 blocks of weights by their L1 norm
        x = x.reshape(H // 4, 4, W // 4, 4)
        x = jnp.abs(x)
        x = jnp.sum(x, axis=(1, 3), keepdims=True)
        # keep blocks whose score is above the sparsity quantile
        q = jnp.quantile(jnp.reshape(x, (-1,)), z)
        x = x >= q
        x = jnp.tile(x, (1, 4, 1, 4))
        x = jnp.reshape(x, (H, W))
        return x


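# A quick sanity check of the schedule, for a `pruner` instance
# (illustrative values):
#
#     pruner.compute_sparsity(1_000)    # 0.0  -- pruning starts
#     pruner.compute_sparsity(101_000)  # 0.95 * (1 - 0.5**3) ~= 0.83
#     pruner.compute_sparsity(201_000)  # 0.95 -- final sparsity

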
class GRUPruner(Pruner):
    def __init__(self, gru):
        super().__init__()
        self.h_zrh_fc_mask = jnp.ones_like(gru.h_zrh_fc.weight) == 1

    def __call__(self, gru: GRU):
        """
        Apply mask after an optimization step
        """
        zrh_masked_weights = jnp.where(self.h_zrh_fc_mask, gru.h_zrh_fc.weight, 0)
        gru = gru.replace_node(gru.h_zrh_fc.weight, zrh_masked_weights)
        return gru

    def update_mask(self, step, gru: GRU):
        """
        Update internal masks
        """
        # prune each of the z, r, h submatrices separately
        z_weight, r_weight, h_weight = jnp.split(gru.h_zrh_fc.weight, 3, axis=1)
        z_mask = self.prune(step, z_weight)
        r_mask = self.prune(step, r_weight)
        h_mask = self.prune(step, h_weight)
        self.h_zrh_fc_mask *= jnp.concatenate((z_mask, r_mask, h_mask), axis=1)


class LinearPruner(Pruner):
    def __init__(self, linear):
        super().__init__()
        self.mask = jnp.ones_like(linear.weight) == 1

    def __call__(self, linear: pax.Linear):
        """
        Apply mask after an optimization step
        """
        return linear.replace(weight=jnp.where(self.mask, linear.weight, 0))

    def update_mask(self, step, linear: pax.Linear):
        """
        Update internal masks
        """
        self.mask *= self.prune(step, linear.weight)


class WaveGRU(pax.Module):
    """
    WaveGRU vocoder model.
    """

    def __init__(
        self,
        mel_dim=80,
        rnn_dim=1024,
        upsample_factors=(5, 3, 20),
        has_linear_output=False,
    ):
        super().__init__()
        self.embed = pax.Embed(256, 3 * rnn_dim)
        self.upsample = Upsample(
            input_dim=mel_dim,
            hidden_dim=512,
            rnn_dim=rnn_dim,
            upsample_factors=upsample_factors,
            has_linear_output=has_linear_output,
        )
        self.rnn = GRU(rnn_dim)
        self.o1 = pax.Linear(rnn_dim, rnn_dim)
        self.o2 = pax.Linear(rnn_dim, 256)
        self.gru_pruner = GRUPruner(self.rnn)
        self.o1_pruner = LinearPruner(self.o1)
        self.o2_pruner = LinearPruner(self.o2)

    def output(self, x):
        x = self.o1(x)
        x = jax.nn.relu(x)
        x = self.o2(x)
        return x

    def inference(self, mel, no_gru=False, seed=42):
        """
        Generate a waveform from a melspectrogram.
        """

        @jax.jit
        def step(rnn_state, mel, rng_key, x):
            x = self.embed(x)
            x = x + mel
            rnn_state, x = self.rnn(rnn_state, x)
            x = self.output(x)
            rng_key, next_rng_key = jax.random.split(rng_key, 2)
            # sample the next mu-law sample from the output logits
            x = jax.random.categorical(rng_key, x, axis=-1)
            return rnn_state, next_rng_key, x

        y = self.upsample(mel, no_repeat=no_gru)
        if no_gru:
            return y
        # start from the middle of the mu-law range (silence)
        x = jnp.array([127], dtype=jnp.int32)
        rnn_state = self.rnn.initial_state(1)
        output = []
        rng_key = jax.random.PRNGKey(seed)
        for i in tqdm(range(y.shape[1])):
            rnn_state, rng_key, x = step(rnn_state, y[:, i], rng_key, x)
            output.append(x)
        x = jnp.concatenate(output, axis=0)
        return x

    def __call__(self, mel, x):
        x = self.embed(x)
        y = self.upsample(mel)
        # crop the embedded waveform to match the length of the upsampled mel
        pad_left = (x.shape[1] - y.shape[1]) // 2
        pad_right = x.shape[1] - y.shape[1] - pad_left
        x = x[:, pad_left:-pad_right]
        x = x + y
        _, x = pax.scan(
            self.rnn,
            self.rnn.initial_state(x.shape[0]),
            x,
            time_major=False,
        )
        x = self.output(x)
        return x
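

if __name__ == "__main__":
    # A minimal smoke test (an illustrative sketch, not part of the original
    # module; the shapes assume the default upsample_factors=(5, 3, 20), i.e.
    # 300 samples per mel frame, and enough mel frames to survive the VALID
    # dilated convolutions).
    net = WaveGRU(mel_dim=80, rnn_dim=1024)
    mel = jnp.zeros((1, 100, 80), dtype=jnp.float32)  # (N, frames, mel_dim)
    wav = jnp.zeros((1, 100 * 300), dtype=jnp.int32)  # mu-law encoded samples
    logits = net(mel, wav)
    print(logits.shape)  # (N, T, 256) with T < 100 * 300 due to the VALID convs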