# TalkSHOW: nets/spg/vqvae_modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class CasualCT(nn.Module):
def __init__(self,
in_channels,
out_channels,
leaky=False,
p=0,
groups=1, ):
        '''
        Causal transposed-convolution block (upsamples the time axis by 2):
        conv_transpose - dropout - batchnorm - relu
        '''
super(CasualCT, self).__init__()
padding = 0
kernel_size = 2
stride = 2
in_channels = in_channels * groups
out_channels = out_channels * groups
self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
groups=groups)
self.norm = nn.BatchNorm1d(out_channels)
self.dropout = nn.Dropout(p=p)
if leaky:
self.relu = nn.LeakyReLU(negative_slope=0.2)
else:
self.relu = nn.ReLU()
def forward(self, x, **kwargs):
out = self.norm(self.dropout(self.conv(x)))
return self.relu(out)
class CasualConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
leaky=False,
p=0,
groups=1,
downsample=False):
        '''
        Causal convolution block: conv - dropout - batchnorm - relu.
        With downsample=True it halves the time axis (kernel 2, stride 2);
        otherwise it left-pads one frame (or the cached `pre_state`) so the
        output keeps its length and frame t never sees future frames.
        '''
super(CasualConv, self).__init__()
padding = 0
kernel_size = 2
stride = 1
self.downsample = downsample
if self.downsample:
kernel_size = 2
stride = 2
in_channels = in_channels * groups
out_channels = out_channels * groups
self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
groups=groups)
self.norm = nn.BatchNorm1d(out_channels)
self.dropout = nn.Dropout(p=p)
if leaky:
self.relu = nn.LeakyReLU(negative_slope=0.2)
else:
self.relu = nn.ReLU()
def forward(self, x, pre_state=None):
if not self.downsample:
if pre_state is not None:
x = torch.cat([pre_state, x], dim=-1)
else:
zeros = torch.zeros([x.shape[0], x.shape[1], 1], device=x.device)
x = torch.cat([zeros, x], dim=-1)
out = self.norm(self.dropout(self.conv(x)))
return self.relu(out)
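# Usage sketch (illustrative only; the sizes below are assumptions, not values from
# the TalkSHOW configs). A non-downsampling CasualConv left-pads one frame, so the
# output has the same length as the input and frame t depends only on frames <= t.
# Passing the last frame of the previous chunk as `pre_state` replaces that zero
# padding, which lets the block continue causally across chunk boundaries.
# CasualCT is the matching transposed convolution (kernel 2, stride 2) that doubles
# the number of frames.
#
#   conv = CasualConv(in_channels=64, out_channels=64).eval()
#   x = torch.randn(2, 64, 30)                                # (batch, channels, frames)
#   y = conv(x)                                               # (2, 64, 30)
#   y_next = conv(torch.randn(2, 64, 10), pre_state=x[..., -1:])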
class ConvNormRelu(nn.Module):
    '''
    1-D conv block, (B, C_in, T) -> (B, C_out, T): conv - dropout - batchnorm - relu.
    Note: for some kernel/stride combinations the output length is not exactly T / stride.
    # TODO: there might be some problems with the residual path
    '''
def __init__(self,
in_channels,
out_channels,
leaky=False,
sample='none',
p=0,
groups=1,
residual=False,
norm='bn'):
'''
conv-bn-relu
'''
super(ConvNormRelu, self).__init__()
self.residual = residual
self.norm_type = norm
padding = 1
if sample == 'none':
kernel_size = 3
stride = 1
elif sample == 'one':
padding = 0
kernel_size = stride = 1
else:
kernel_size = 4
stride = 2
if self.residual:
if sample == 'down':
self.residual_layer = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding)
elif sample == 'up':
self.residual_layer = nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding)
else:
if in_channels == out_channels:
self.residual_layer = nn.Identity()
else:
self.residual_layer = nn.Sequential(
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
)
in_channels = in_channels * groups
out_channels = out_channels * groups
if sample == 'up':
self.conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
groups=groups)
else:
self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
groups=groups)
self.norm = nn.BatchNorm1d(out_channels)
self.dropout = nn.Dropout(p=p)
if leaky:
self.relu = nn.LeakyReLU(negative_slope=0.2)
else:
self.relu = nn.ReLU()
def forward(self, x, **kwargs):
out = self.norm(self.dropout(self.conv(x)))
if self.residual:
residual = self.residual_layer(x)
out += residual
return self.relu(out)
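# Usage sketch (illustrative only; sizes are assumptions): with the settings above,
# sample='down' uses a kernel-4 / stride-2 / padding-1 convolution that halves an
# even-length sequence, and sample='up' uses the matching transposed convolution
# that doubles it.
#
#   down = ConvNormRelu(64, 128, sample='down')
#   up = ConvNormRelu(128, 64, sample='up')
#   x = torch.randn(2, 64, 32)
#   print(down(x).shape)        # torch.Size([2, 128, 16])
#   print(up(down(x)).shape)    # torch.Size([2, 64, 32])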
class Res_CNR_Stack(nn.Module):
def __init__(self,
channels,
layers,
sample='none',
leaky=False,
casual=False,
):
super(Res_CNR_Stack, self).__init__()
        if casual:
            kernel_size = 1
            padding = 0
            conv = CasualConv
        else:
            kernel_size = 3
            padding = 1
            conv = ConvNormRelu
        if sample == 'one':
            kernel_size = 1
            padding = 0
        self._layers = nn.ModuleList()
        for i in range(layers):
            if casual:
                # CasualConv has no `sample` argument; the stacked layers never resample.
                self._layers.append(conv(channels, channels, leaky=leaky))
            else:
                self._layers.append(conv(channels, channels, leaky=leaky, sample=sample))
        self.conv = nn.Conv1d(channels, channels, kernel_size, 1, padding)
self.norm = nn.BatchNorm1d(channels)
self.relu = nn.ReLU()
    def forward(self, x, pre_state=None):
        cur_state = []
        h = x
        for i in range(len(self._layers)):
            # cache the last frame of each layer's input so a later chunk can
            # continue causally from it (used by Casual_Decoder at inference)
            cur_state.append(h[..., -1:])
            h = self._layers[i](h, pre_state=pre_state[i] if pre_state is not None else None)
        h = self.norm(self.conv(h))
        return self.relu(h + x), cur_state
class ExponentialMovingAverage(nn.Module):
"""Maintains an exponential moving average for a value.
This module keeps track of a hidden exponential moving average that is
initialized as a vector of zeros which is then normalized to give the average.
This gives us a moving average which isn't biased towards either zero or the
initial value. Reference (https://arxiv.org/pdf/1412.6980.pdf)
Initially:
hidden_0 = 0
Then iteratively:
hidden_i = hidden_{i-1} - (hidden_{i-1} - value) * (1 - decay)
average_i = hidden_i / (1 - decay^i)
"""
def __init__(self, init_value, decay):
super().__init__()
self.decay = decay
self.counter = 0
self.register_buffer("hidden", torch.zeros_like(init_value))
def forward(self, value):
self.counter += 1
self.hidden.sub_((self.hidden - value) * (1 - self.decay))
average = self.hidden / (1 - self.decay ** self.counter)
return average
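# Usage sketch (illustrative only): the hidden buffer starts at zero and the division
# by (1 - decay**step) removes the bias toward that zero start, mirroring the
# bias-corrected moment estimates in Adam.
#
#   ema = ExponentialMovingAverage(torch.zeros(4), decay=0.99)
#   for _ in range(10):
#       avg = ema(torch.randn(4))      # bias-corrected running average of the inputs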
class VectorQuantizerEMA(nn.Module):
"""
VQ-VAE layer: Input any tensor to be quantized. Use EMA to update embeddings.
Args:
embedding_dim (int): the dimensionality of the tensors in the
quantized space. Inputs to the modules must be in this format as well.
num_embeddings (int): the number of vectors in the quantized space.
commitment_cost (float): scalar which controls the weighting of the loss terms (see
equation 4 in the paper - this variable is Beta).
decay (float): decay for the moving averages.
epsilon (float): small float constant to avoid numerical instability.
"""
def __init__(self, embedding_dim, num_embeddings, commitment_cost, decay,
epsilon=1e-5):
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.commitment_cost = commitment_cost
self.epsilon = epsilon
# initialize embeddings as buffers
embeddings = torch.empty(self.num_embeddings, self.embedding_dim)
nn.init.xavier_uniform_(embeddings)
self.register_buffer("embeddings", embeddings)
self.ema_dw = ExponentialMovingAverage(self.embeddings, decay)
        # also maintain ema_cluster_size, which records how often each embedding is used
self.ema_cluster_size = ExponentialMovingAverage(torch.zeros((self.num_embeddings,)), decay)
def forward(self, x):
        # [B, C, T] -> [B, T, C]
        x = x.permute(0, 2, 1).contiguous()
        # [B, T, C] -> [B*T, C]
        flat_x = x.reshape(-1, self.embedding_dim)
encoding_indices = self.get_code_indices(flat_x)
quantized = self.quantize(encoding_indices)
        quantized = quantized.view_as(x)  # [B, T, C]
if not self.training:
quantized = quantized.permute(0, 2, 1).contiguous()
return quantized, encoding_indices.view(quantized.shape[0], quantized.shape[2])
# update embeddings with EMA
with torch.no_grad():
encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
updated_ema_cluster_size = self.ema_cluster_size(torch.sum(encodings, dim=0))
n = torch.sum(updated_ema_cluster_size)
updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
(n + self.num_embeddings * self.epsilon) * n)
dw = torch.matmul(encodings.t(), flat_x) # sum encoding vectors of each cluster
updated_ema_dw = self.ema_dw(dw)
normalised_updated_ema_w = (
updated_ema_dw / updated_ema_cluster_size.reshape(-1, 1))
self.embeddings.data = normalised_updated_ema_w
# commitment loss
e_latent_loss = F.mse_loss(x, quantized.detach())
loss = self.commitment_cost * e_latent_loss
# Straight Through Estimator
quantized = x + (quantized - x).detach()
quantized = quantized.permute(0, 2, 1).contiguous()
return quantized, loss
def get_code_indices(self, flat_x):
# compute L2 distance
distances = (
torch.sum(flat_x ** 2, dim=1, keepdim=True) +
torch.sum(self.embeddings ** 2, dim=1) -
2. * torch.matmul(flat_x, self.embeddings.t())
) # [N, M]
encoding_indices = torch.argmin(distances, dim=1) # [N,]
return encoding_indices
def quantize(self, encoding_indices):
"""Returns embedding tensor for a batch of indices."""
return F.embedding(encoding_indices, self.embeddings)
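# Usage sketch (illustrative only; sizes are assumptions). The layer has two return
# signatures: in train() mode it returns (quantized, commitment_loss) and updates the
# codebook in place via EMA; in eval() mode it returns (quantized, encoding_indices)
# and leaves the codebook untouched.
#
#   vq = VectorQuantizerEMA(embedding_dim=256, num_embeddings=512,
#                           commitment_cost=0.25, decay=0.99)
#   z = torch.randn(2, 256, 8)          # (batch, embedding_dim, frames)
#   z_q, vq_loss = vq(z)                # training path, z_q: (2, 256, 8)
#   vq.eval()
#   z_q, codes = vq(z)                  # inference path, codes: (2, 8)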
class Casual_Encoder(nn.Module):
def __init__(self, in_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
super(Casual_Encoder, self).__init__()
self._num_hiddens = num_hiddens
self._num_residual_layers = num_residual_layers
self._num_residual_hiddens = num_residual_hiddens
self.project = nn.Conv1d(in_dim, self._num_hiddens // 4, 1, 1)
self._enc_1 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True)
self._down_1 = CasualConv(self._num_hiddens // 4, self._num_hiddens // 2, leaky=True, downsample=True)
self._enc_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True)
self._down_2 = CasualConv(self._num_hiddens // 2, self._num_hiddens, leaky=True, downsample=True)
self._enc_3 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True)
# self.pre_vq_conv = nn.Conv1d(self._num_hiddens, embedding_dim, 1, 1)
def forward(self, x):
h = self.project(x)
h, _ = self._enc_1(h)
h = self._down_1(h)
h, _ = self._enc_2(h)
h = self._down_2(h)
h, _ = self._enc_3(h)
# h = self.pre_vq_conv(h)
return h
class Casual_Decoder(nn.Module):
def __init__(self, out_dim, embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens):
super(Casual_Decoder, self).__init__()
self._num_hiddens = num_hiddens
self._num_residual_layers = num_residual_layers
self._num_residual_hiddens = num_residual_hiddens
# self.aft_vq_conv = nn.Conv1d(embedding_dim, self._num_hiddens, 1, 1)
self._dec_1 = Res_CNR_Stack(self._num_hiddens, self._num_residual_layers, leaky=True, casual=True)
self._up_2 = CasualCT(self._num_hiddens, self._num_hiddens // 2, leaky=True)
self._dec_2 = Res_CNR_Stack(self._num_hiddens // 2, self._num_residual_layers, leaky=True, casual=True)
self._up_3 = CasualCT(self._num_hiddens // 2, self._num_hiddens // 4, leaky=True)
self._dec_3 = Res_CNR_Stack(self._num_hiddens // 4, self._num_residual_layers, leaky=True, casual=True)
self.project = nn.Conv1d(self._num_hiddens//4, out_dim, 1, 1)
def forward(self, h, pre_state=None):
cur_state = []
# h = self.aft_vq_conv(x)
h, s = self._dec_1(h, pre_state[0] if pre_state is not None else None)
cur_state.append(s)
h = self._up_2(h)
h, s = self._dec_2(h, pre_state[1] if pre_state is not None else None)
cur_state.append(s)
h = self._up_3(h)
h, s = self._dec_3(h, pre_state[2] if pre_state is not None else None)
cur_state.append(s)
recon = self.project(h)
return recon, cur_state
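

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file): the hyperparameters
    # below are assumptions chosen only to exercise the encoder -> VQ -> decoder path.
    B, C_in, T = 2, 64, 32
    num_hiddens, n_res_layers = 256, 2
    enc = Casual_Encoder(C_in, num_hiddens, num_hiddens, n_res_layers, num_hiddens)
    vq = VectorQuantizerEMA(embedding_dim=num_hiddens, num_embeddings=512,
                            commitment_cost=0.25, decay=0.99)
    dec = Casual_Decoder(C_in, num_hiddens, num_hiddens, n_res_layers, num_hiddens)
    x = torch.randn(B, C_in, T)
    z = enc(x)                    # (B, num_hiddens, T // 4) after two causal downsamplings
    z_q, vq_loss = vq(z)          # train mode: quantized latents + commitment loss
    recon, _ = dec(z_q)           # (B, C_in, T) after two causal upsamplings
    print(x.shape, z.shape, z_q.shape, recon.shape, float(vq_loss))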