| import numpy as np |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from .layers import MLPLayers |
| from .rq import ResidualVectorQuantizer |
|
|
|
|
class RQVAE(nn.Module):
    """Autoencoder with a residual vector quantizer between an MLP encoder and decoder.

    The data path is ``encoder -> rq -> decoder``. The decoder mirrors the
    encoder: its layer dimensions are the encoder's dimensions in reverse
    order.

    Args:
        in_dim: Width of the input features; first entry of the encoder dims.
        num_emb_list: Passed unchanged as the first positional argument of
            ``ResidualVectorQuantizer`` (presumably one codebook size per
            residual level -- confirm against ``rq.py``).
        e_dim: Latent width; last entry of the encoder dims. It is also passed
            to the quantizer.
        layers: Hidden-layer widths between ``in_dim`` and ``e_dim``. ``None``
            is treated as an empty list, so the encoder dims become
            ``[in_dim, e_dim]``.
        dropout_prob: Forwarded as ``dropout`` to both ``MLPLayers``.
        bn: Forwarded as ``bn`` to both ``MLPLayers``.
        loss_type: Reconstruction loss used by ``compute_loss``. Must be
            ``"mse"`` or ``"l1"``.
        quant_loss_weight: Multiplier on the quantizer loss in ``compute_loss``.
        kmeans_init, kmeans_iters, sk_epsilons, sk_iters: Forwarded unchanged
            to ``ResidualVectorQuantizer``.
    """

    def __init__(self,
                 in_dim=768,
                 num_emb_list=None,
                 e_dim=64,
                 layers=None,
                 dropout_prob=0.0,
                 bn=False,
                 loss_type="mse",
                 quant_loss_weight=1.0,
                 kmeans_init=False,
                 kmeans_iters=100,
                 sk_epsilons=None,
                 sk_iters=100,
                 ):
        super().__init__()

        self.in_dim = in_dim
        self.num_emb_list = num_emb_list
        self.e_dim = e_dim

        # The default is None, but the dims list below is built by list
        # concatenation. Normalizing here avoids a TypeError and also accepts
        # tuples or other iterables of hidden widths.
        self.layers = [] if layers is None else list(layers)
        self.dropout_prob = dropout_prob
        self.bn = bn
        self.loss_type = loss_type
        self.quant_loss_weight = quant_loss_weight
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.sk_epsilons = sk_epsilons
        self.sk_iters = sk_iters

        # Encoder: in_dim -> *layers -> e_dim
        self.encode_layer_dims = [self.in_dim] + self.layers + [self.e_dim]
        self.encoder = MLPLayers(layers=self.encode_layer_dims,
                                 dropout=self.dropout_prob, bn=self.bn)

        self.rq = ResidualVectorQuantizer(num_emb_list, e_dim,
                                          kmeans_init=self.kmeans_init,
                                          kmeans_iters=self.kmeans_iters,
                                          sk_epsilons=self.sk_epsilons,
                                          sk_iters=self.sk_iters)

        # Decoder mirrors the encoder: e_dim -> *reversed(layers) -> in_dim
        self.decode_layer_dims = self.encode_layer_dims[::-1]
        self.decoder = MLPLayers(layers=self.decode_layer_dims,
                                 dropout=self.dropout_prob, bn=self.bn)

    def forward(self, x, use_sk=True):
        """Encode, quantize and decode ``x``.

        Args:
            x: Encoder input (assumed to have trailing width ``in_dim`` --
                confirm against ``MLPLayers``).
            use_sk: Forwarded to ``self.rq``. It defaults to ``True`` here,
                unlike ``get_indices``.

        Returns:
            Tuple ``(out, rq_loss, indices)``, where:
            - ``out`` is the decoder output for the quantized latent.
            - ``rq_loss`` and ``indices`` are returned as-is by ``self.rq``.
        """
        x = self.encoder(x)
        x_q, rq_loss, indices = self.rq(x, use_sk=use_sk)
        out = self.decoder(x_q)
        return out, rq_loss, indices

    @torch.no_grad()
    def get_indices(self, xs, use_sk=False):
        """Return the quantizer's code indices for ``xs`` without tracking gradients.

        Only the encoder and quantizer run; the decoder is skipped.
        ``use_sk`` is forwarded to ``self.rq`` and defaults to ``False`` here.
        """
        x_e = self.encoder(xs)
        _, _, indices = self.rq(x_e, use_sk=use_sk)
        return indices

    def compute_loss(self, out, quant_loss, xs=None):
        """Combine the reconstruction loss with the weighted quantizer loss.

        Args:
            out: Reconstruction produced by ``forward``.
            quant_loss: Quantizer loss (the ``rq_loss`` returned by ``forward``).
            xs: Reconstruction target. Despite the ``None`` default it is
                required; the default is kept only for signature compatibility.

        Returns:
            Tuple ``(loss_total, loss_recon)``, where
            ``loss_total = loss_recon + quant_loss_weight * quant_loss``.

        Raises:
            ValueError: If ``xs`` is ``None`` or ``self.loss_type`` is not
                ``"mse"`` or ``"l1"``.
        """
        if xs is None:
            # Without this check, F.mse_loss / F.l1_loss would fail with an
            # opaque error deep inside torch.
            raise ValueError('compute_loss requires the reconstruction target xs')

        if self.loss_type == 'mse':
            loss_recon = F.mse_loss(out, xs, reduction='mean')
        elif self.loss_type == 'l1':
            loss_recon = F.l1_loss(out, xs, reduction='mean')
        else:
            raise ValueError('incompatible loss type')

        loss_total = loss_recon + self.quant_loss_weight * quant_loss
        return loss_total, loss_recon