Spaces:
Configuration error
Configuration error
File size: 402 Bytes
a23872f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from gradient_flow_ops import ReplaceGrad
# Functional alias for ReplaceGrad (imported from gradient_flow_ops); the
# `.apply` entry point suggests it is a torch.autograd.Function subclass.
# NOTE(review): judging by its use in vector_quantize as
# replace_grad(x_q, x), it presumably outputs its first argument in the
# forward pass while routing incoming gradients to its second argument
# (a straight-through estimator) — confirm against gradient_flow_ops.
replace_grad = ReplaceGrad.apply
def vector_quantize(x, codebook):
    """Snap each vector in ``x`` to its nearest codebook entry.

    "Nearest" is by squared Euclidean distance, expanded as
    ``|x|^2 + |c|^2 - 2 x.c`` so the full distance matrix comes from a
    single matmul.

    Args:
        x: Tensor whose last dimension is the embedding size ``D``; any
            leading dimensions are treated as batch dimensions.
        codebook: 2-D tensor of shape ``(K, D)`` holding ``K`` code vectors.

    Returns:
        Tensor with the same shape as ``x`` in which every vector has been
        replaced by its nearest code vector, passed through
        ``replace_grad(x_q, x)``. NOTE(review): presumably that keeps the
        quantized values in the forward pass while letting gradients flow
        straight through to ``x`` — confirm in gradient_flow_ops.
    """
    # (..., K): squared distance from every input vector to every code vector.
    d = (
        x.pow(2).sum(dim=-1, keepdim=True)
        + codebook.pow(2).sum(dim=1)
        - 2 * x @ codebook.T
    )
    indices = d.argmin(-1)
    # Gather the chosen rows directly instead of multiplying a (..., K)
    # one-hot matrix by the codebook: no one-hot tensor is materialized, and
    # the codebook entries come back exactly (a matmul may round them, e.g.
    # under TF32, or produce NaN if any code vector contains inf).
    x_q = codebook[indices]
    return replace_grad(x_q, x)