# Face-editor / vqgan_latent_ops.py
# Author: Erwann Millon (initial commit a23872f)
import torch
import torch.nn as nn
import torch.nn.functional as F
from gradient_flow_ops import ReplaceGrad
# Bind the callable form of ReplaceGrad once so call sites can write
# replace_grad(a, b). NOTE(review): ReplaceGrad is presumably a
# torch.autograd.Function (hence `.apply`) that returns its first argument in
# the forward pass and routes the incoming gradient to its second argument in
# the backward pass (straight-through estimator) — confirm in gradient_flow_ops.
replace_grad = ReplaceGrad.apply
def vector_quantize(x, codebook):
    """Replace each vector in ``x`` with its nearest codebook row (squared L2).

    Args:
        x: Tensor whose last dimension is the embedding size ``D``; any leading
            dimensions are treated as a batch of independent vectors.
        codebook: 2-D tensor of shape ``(K, D)`` holding ``K`` code vectors
            (``.T`` and the ``sum(dim=1)`` below assume exactly two dims).

    Returns:
        ``replace_grad(x_q, x)``, where ``x_q`` has shape ``x.shape[:-1] + (D,)``
        and holds the selected codebook rows. How gradients flow is defined by
        ``ReplaceGrad``; presumably the forward value is ``x_q`` and gradients
        pass straight through to ``x`` — confirm in gradient_flow_ops.
    """
    # Expand ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c so every (vector, code)
    # distance comes from one matmul instead of materialising pairwise diffs.
    d = (
        x.pow(2).sum(dim=-1, keepdim=True)
        + codebook.pow(2).sum(dim=1)
        - 2 * x @ codebook.T
    )
    indices = d.argmin(dim=-1)
    # Gather the winning rows directly. This yields the same rows as
    # one_hot(indices) @ codebook, but skips the (..., K) one-hot buffer and
    # the O(N*K*D) matmul. It is also exact: a matmul may round codebook
    # values under TF32/autocast kernels, and turns inf entries into NaN
    # (0 * inf).
    x_q = codebook[indices]
    return replace_grad(x_q, x)