import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from nets.spg.vqvae_modules import VectorQuantizerEMA, ConvNormRelu, Res_CNR_Stack
from nets.spg.vqvae_1d import AudioEncoder


class EncoderSC(nn.Module):
    """1-D convolutional encoder that also exposes its intermediate features.

    The network is three residual stacks separated by two ``sample='down'``
    ``ConvNormRelu`` layers, followed by a 1x1 ``Conv1d`` that maps to
    ``embedding_dim`` channels. The output of every residual stack is kept so
    that ``DecoderSC`` can add it back as a skip connection.

    Args:
        in_dim: Number of input channels.
        embedding_dim: Number of channels produced by ``pre_vq_conv``.
        num_hiddens: Channel width of the deepest stage; the two earlier
            stages use ``num_hiddens // 4`` and ``num_hiddens // 2``.
        num_residual_layers: Passed to every ``Res_CNR_Stack``.
        num_residual_hiddens: Stored on the instance but not otherwise used
            in this class.
    """

    def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        # Stage 1: num_hiddens // 4 channels.
        self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
        self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)

        # Stage 2: num_hiddens // 2 channels.
        self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
                                    sample='down')
        self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)

        # Stage 3: num_hiddens channels.
        self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True,
                                    sample='down')
        self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)

        # 1x1 convolution (kernel_size=1, stride=1) onto the embedding width.
        self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)

    def forward(self, x):
        """Encode ``x`` and collect the per-stage features.

        Args:
            x: Input tensor; presumably laid out as (batch, in_dim, time),
                which is what ``nn.Conv1d`` expects -- TODO confirm against
                ``ConvNormRelu``.

        Returns:
            A tuple ``(h, out)`` where ``h`` is the output of ``pre_vq_conv``
            and ``out`` is a dict mapping the stage number (1, 2, 3) to the
            output of the corresponding residual stack.
        """
        # Must be a dict: the stages are stored under the keys 1..3, and item
        # assignment on an empty list raises IndexError.
        out = {}

        h = self.project(x)
        h = self._enc_1(h)
        out[1] = h

        h = self._down_1(h)
        h = self._enc_2(h)
        out[2] = h

        h = self._down_2(h)
        h = self._enc_3(h)
        out[3] = h

        h = self.pre_vq_conv(h)
        return h, out


class DecoderSC(nn.Module):
    """1-D convolutional decoder that consumes the encoder's skip features.

    Mirrors ``EncoderSC``: a 1x1 ``Conv1d`` from ``embedding_dim`` to
    ``num_hiddens`` channels, then three residual stacks separated by two
    ``sample='up'`` ``ConvNormRelu`` layers, then a 1x1 ``Conv1d`` to
    ``out_dim`` channels. Before each residual stack the matching encoder
    feature is added element-wise.

    Args:
        out_dim: Number of output channels.
        embedding_dim: Number of channels of the incoming latent.
        num_hiddens: Channel width of the deepest stage; later stages use
            ``num_hiddens // 2`` and ``num_hiddens // 4``.
        num_residual_layers: Passed to every ``Res_CNR_Stack``.
        num_residual_hiddens: Stored on the instance but not otherwise used
            in this class.
        ae: Accepted for interface compatibility; not used in this class.
    """

    def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, ae=False):
        super().__init__()
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        # 1x1 convolution (kernel_size=1, stride=1) from the embedding width.
        self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)

        # Stage 1: num_hiddens channels.
        self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)

        # Stage 2: num_hiddens // 2 channels.
        self._up_2 = ConvNormRelu(self._num_hiddens, self._num_hiddens // 2, leaky=True, residual=True, sample='up')
        self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)

        # Stage 3: num_hiddens // 4 channels.
        self._up_3 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True, residual=True,
                                  sample='up')
        self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)

        self.project = nn.Conv1d(self._num_hiddens // 4, out_dim, 1, 1)

    def forward(self, h, out):
        """Decode ``h`` using the encoder features in ``out``.

        Args:
            h: Latent tensor with ``embedding_dim`` channels.
            out: Mapping indexed by 1, 2 and 3 holding the encoder features,
                as returned by ``EncoderSC.forward``. Each entry must be
                addable to the decoder activation at the matching stage.

        Returns:
            The output of the final 1x1 projection (``out_dim`` channels).
        """
        # Skip connections are consumed deepest-first (3, 2, 1), the reverse
        # of the order in which the encoder produced them.
        h = self.aft_vq_conv(h)
        h = h + out[3]
        h = self._dec_1(h)

        h = self._up_2(h)
        h = h + out[2]
        h = self._dec_2(h)

        h = self._up_3(h)
        h = h + out[1]
        h = self._dec_3(h)

        recon = self.project(h)
        return recon


class VQVAE_SC(nn.Module):
    """VQ-VAE whose decoder receives skip connections from the encoder.

    Combines an ``EncoderSC``, a ``VectorQuantizerEMA`` and a ``DecoderSC``.
    Quantisation happens inside ``decode``, not ``encode``.

    Args:
        in_dim: Channel count of the data; used as the encoder input width
            and the decoder output width.
        embedding_dim: Width of each codebook vector / latent channel count.
        num_embeddings: Passed to ``VectorQuantizerEMA``.
        num_hiddens: Passed to the encoder and decoder.
        num_residual_layers: Passed to the encoder and decoder.
        num_residual_hiddens: Passed to the encoder and decoder.
        commitment_cost: Passed to ``VectorQuantizerEMA``.
        decay: Passed to ``VectorQuantizerEMA``.
        share: Accepted for interface compatibility; not used in this class.
    """

    def __init__(self, in_dim, embedding_dim, num_embeddings,
                 num_hiddens, num_residual_layers, num_residual_hiddens,
                 commitment_cost=0.25, decay=0.99, share=False):
        super().__init__()
        self.in_dim = in_dim
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        self.encoder = EncoderSC(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
        self.vq_layer = VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay)
        self.decoder = DecoderSC(in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)

    def encode(self, gt_poses):
        """Run the encoder on ``gt_poses``.

        Args:
            gt_poses: Input tensor. Dimensions 1 and 2 are swapped before the
                encoder is applied; presumably the input is
                (batch, time, in_dim) -- TODO confirm against callers.

        Returns:
            A tuple ``(z, enc_feats)``: the un-quantised latent and the
            per-stage encoder features.
        """
        z, enc_feats = self.encoder(gt_poses.transpose(1, 2))
        return z, enc_feats

    def decode(self, z, enc_feats):
        """Quantise ``z`` and decode it with the encoder's skip features.

        Args:
            z: Un-quantised latent, as returned by ``encode``.
            enc_feats: Per-stage encoder features, as returned by ``encode``.

        Returns:
            A tuple ``(e, e_q_loss, x)``: the two values returned by the
            vector quantiser, followed by the decoder output with dimensions
            1 and 2 swapped back.
        """
        e, e_q_loss = self.vq_layer(z)
        x = self.decoder(e, enc_feats)
        return e, e_q_loss, x.transpose(1, 2)