TalkSHOW / nets /inpainting /vqvae_1d_sc.py
feifeifeiliu's picture
first version
865fd8a
raw
history blame
3.98 kB
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.spg.vqvae_modules import VectorQuantizerEMA, ConvNormRelu, Res_CNR_Stack
from nets.spg.vqvae_1d import AudioEncoder
class EncoderSC(nn.Module):
    """Three-stage 1-D convolutional encoder that also exposes its stage features.

    The input is projected to ``num_hiddens // 4`` channels, then passed through
    three residual stacks running at ``num_hiddens // 4``, ``num_hiddens // 2``
    and ``num_hiddens`` channels, with a ``sample='down'`` ``ConvNormRelu``
    between consecutive stacks. The output of every residual stack is kept so a
    decoder can add it back in as a skip connection.

    Args:
        in_dim: Number of input channels.
        embedding_dim: Number of channels of the returned latent.
        num_hiddens: Channel width of the last (widest) stage.
        num_residual_layers: Forwarded to every ``Res_CNR_Stack``.
        num_residual_hiddens: Stored on the instance; not otherwise used here.
    """

    def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(EncoderSC, self).__init__()
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        # Stage 1: num_hiddens // 4 channels.
        self.project = ConvNormRelu(in_dim, self._num_hiddens // 4, leaky=True)
        self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True)

        # Stage 2: num_hiddens // 2 channels.
        # NOTE(review): sample='down' presumably reduces the temporal length —
        # confirm against ConvNormRelu.
        self._down_1 = ConvNormRelu(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, residual=True,
                                    sample='down')
        self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True)

        # Stage 3: num_hiddens channels.
        self._down_2 = ConvNormRelu(self._num_hiddens // 2, self._num_hiddens, leaky=True, residual=True, sample='down')
        self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True)

        # Kernel-size-1, stride-1 convolution mapping to the embedding width.
        self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)

    def forward(self, x):
        """Encode ``x`` and collect the output of each residual stack.

        Args:
            x: Input tensor; presumably ``(batch, in_dim, time)`` — TODO confirm
                against ConvNormRelu.

        Returns:
            A pair ``(h, out)`` where ``h`` is the output of ``pre_vq_conv`` and
            ``out`` is a dict mapping the stage number (1, 2, 3) to the output
            of ``_enc_1``, ``_enc_2`` and ``_enc_3`` respectively.
        """
        # Must be a dict: the features are stored under the keys 1..3, and
        # item assignment on an empty list raises IndexError.
        out = {}
        h = self.project(x)
        h = self._enc_1(h)
        out[1] = h
        h = self._down_1(h)
        h = self._enc_2(h)
        out[2] = h
        h = self._down_2(h)
        h = self._enc_3(h)
        out[3] = h
        h = self.pre_vq_conv(h)
        return h, out
class DecoderSC(nn.Module):
    """Three-stage 1-D convolutional decoder with additive skip inputs.

    Mirrors a three-stage encoder: the latent is lifted to ``num_hiddens``
    channels, then run through residual stacks at ``num_hiddens``,
    ``num_hiddens // 2`` and ``num_hiddens // 4`` channels, with a
    ``sample='up'`` ``ConvNormRelu`` between consecutive stacks. Before each
    residual stack a caller-supplied feature is added to the activations.

    Args:
        out_dim: Number of channels of the reconstruction.
        embedding_dim: Number of channels of the incoming latent.
        num_hiddens: Channel width of the first (widest) stage.
        num_residual_layers: Forwarded to every ``Res_CNR_Stack``.
        num_residual_hiddens: Stored on the instance; not otherwise used here.
        ae: Accepted but not used by this class.
    """

    def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens, ae=False):
        super().__init__()
        self._num_hiddens = num_hiddens
        self._num_residual_layers = num_residual_layers
        self._num_residual_hiddens = num_residual_hiddens

        # Channel widths of the three stages, widest first.
        wide = num_hiddens
        medium = num_hiddens // 2
        narrow = num_hiddens // 4
        depth = num_residual_layers

        # Kernel-size-1, stride-1 convolution lifting the latent to `wide` channels.
        self.aft_vq_conv = nn.Conv1d(embedding_dim, wide, kernel_size=1, stride=1)

        self._dec_1 = Res_CNR_Stack(wide, depth, leaky=True)
        self._up_2 = ConvNormRelu(wide, medium, leaky=True, residual=True, sample='up')

        self._dec_2 = Res_CNR_Stack(medium, depth, leaky=True)
        self._up_3 = ConvNormRelu(medium, narrow, leaky=True, residual=True, sample='up')

        self._dec_3 = Res_CNR_Stack(narrow, depth, leaky=True)

        # Kernel-size-1, stride-1 convolution producing the output channels.
        self.project = nn.Conv1d(narrow, out_dim, kernel_size=1, stride=1)

    def forward(self, h, out):
        """Decode the latent ``h``, adding a skip feature before each stage.

        Args:
            h: Latent tensor fed to ``aft_vq_conv``.
            out: Container indexed by 1, 2 and 3. ``out[3]`` is added before
                the first residual stack, ``out[2]`` before the second and
                ``out[1]`` before the third.

        Returns:
            The output of the final ``project`` convolution.
        """
        x = self.aft_vq_conv(h) + out[3]
        x = self._up_2(self._dec_1(x)) + out[2]
        x = self._up_3(self._dec_2(x)) + out[1]
        return self.project(self._dec_3(x))
class VQVAE_SC(nn.Module):
    """VQ-VAE whose decoder also receives the encoder's per-stage features.

    Wires together an ``EncoderSC``, a ``VectorQuantizerEMA`` and a
    ``DecoderSC``. Encoding and decoding are exposed as separate steps.

    Args:
        in_dim: Channel count of the encoder input and of the decoder output.
        embedding_dim: Channel count of the latent passed to the quantizer.
        num_embeddings: Forwarded to ``VectorQuantizerEMA``.
        num_hiddens: Forwarded to the encoder and decoder.
        num_residual_layers: Forwarded to the encoder and decoder.
        num_residual_hiddens: Forwarded to the encoder and decoder.
        commitment_cost: Forwarded to ``VectorQuantizerEMA``.
        decay: Forwarded to ``VectorQuantizerEMA``.
        share: Accepted but not used by this class.
    """

    def __init__(self, in_dim, embedding_dim, num_embeddings,
                 num_hiddens, num_residual_layers, num_residual_hiddens,
                 commitment_cost=0.25, decay=0.99, share=False):
        super().__init__()
        self.in_dim = in_dim
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # Encoder and decoder are built from the same size arguments.
        shared_args = (embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
        self.encoder = EncoderSC(in_dim, *shared_args)
        self.vq_layer = VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay)
        self.decoder = DecoderSC(in_dim, *shared_args)

    def encode(self, gt_poses):
        """Run the encoder on ``gt_poses`` with axes 1 and 2 swapped.

        Args:
            gt_poses: Input tensor; presumably ``(batch, time, in_dim)`` so the
                swap yields channels-first input — TODO confirm with callers.

        Returns:
            The pair ``(z, enc_feats)`` produced by the encoder.
        """
        channels_first = gt_poses.transpose(1, 2)
        latent, skips = self.encoder(channels_first)
        return latent, skips

    def decode(self, z, enc_feats):
        """Quantize ``z`` and decode it together with the encoder features.

        Args:
            z: Latent passed to the quantizer.
            enc_feats: Encoder features handed to the decoder as its second
                argument.

        Returns:
            A triple ``(e, e_q_loss, recon)``: the two values returned by the
            quantizer, followed by the decoder output with axes 1 and 2
            swapped back.
        """
        quantized, vq_loss = self.vq_layer(z)
        recon = self.decoder(quantized, enc_feats)
        return quantized, vq_loss, recon.transpose(1, 2)