# VibeToken/modeling/modules/fuzzy_embedding.py
# Web-viewer residue captured with this file (kept for provenance):
#   "APGASU's picture" / "scripts" / "7bef20f verified"
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
import math
class FuzzyEmbedding(nn.Module):
    """Learnable square positional-embedding table resampled to any token grid.

    The module stores ``grid_size`` embeddings of dimension ``width`` and
    treats them as a ``sqrt(grid_size) x sqrt(grid_size)`` feature map.
    ``forward`` samples that map with ``F.grid_sample`` at a regular
    ``grid_height x grid_width`` lattice (optionally jittered during
    training) and prepends the single-row ``class_positional_embedding``.

    Args:
        grid_size: Number of entries in the base table; must be a positive
            perfect square (e.g. 1024 for a 32 x 32 map).
        scale: Multiplier applied to the random normal initialisation.
        width: Embedding dimension.
        apply_fuzzy: If True, training-time calls perturb the sampling
            coordinates with small uniform noise.

    Raises:
        ValueError: If ``grid_size`` is not a positive perfect square.
    """

    def __init__(self, grid_size, scale, width, apply_fuzzy=False):
        super().__init__()
        # Explicit exception instead of `assert`, which is stripped under -O.
        if grid_size < 1 or math.isqrt(grid_size) ** 2 != grid_size:
            raise ValueError(
                f"grid_size must be a positive perfect square, got {grid_size}"
            )
        self.grid_size = grid_size
        self.scale = scale
        self.width = width
        self.apply_fuzzy = apply_fuzzy
        # grid_size is the minimum possible token size; grid_sample is then
        # used to derive the embedding for any requested resolution.
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(grid_size, width))
        self.class_positional_embedding = nn.Parameter(
            scale * torch.randn(1, width))

    @staticmethod
    def _axis_coords(length, device, dtype):
        """Return ``length`` evenly spaced coordinates spanning [-1, 1]."""
        if length == 1:
            # (length - 1) would be zero and produce NaN; a lone sample is
            # placed at the centre of the table instead.
            return torch.zeros(1, device=device, dtype=dtype)
        steps = torch.arange(length, device=device).to(dtype)
        return 2 * (steps / (length - 1)) - 1

    # Autocast is switched off so sampling runs in the explicitly requested dtype.
    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, grid_height, grid_width, train=True, dtype=torch.float32):
        """Sample the embedding table on a ``grid_height x grid_width`` lattice.

        Args:
            grid_height: Number of rows of the target token grid.
            grid_width: Number of columns of the target token grid.
            train: When True and ``apply_fuzzy`` is set, uniform noise in
                [-0.0004, 0.0004] is added to every sampling coordinate.
            dtype: dtype used for the coordinates and for sampling.

        Returns:
            Tensor of shape ``(1 + grid_height * grid_width, width)``: the
            class embedding row followed by the sampled embeddings in
            row-major order.
        """
        device = self.positional_embedding.device
        lattice = (grid_height, grid_width)

        # Row coordinate varies along dim 0, column coordinate along dim 1
        # (the same layout as torch.meshgrid(..., indexing="ij")).
        row_coords = self._axis_coords(grid_height, device, dtype)
        col_coords = self._axis_coords(grid_width, device, dtype)
        meshx = row_coords[:, None].expand(*lattice)
        meshy = col_coords[None, :].expand(*lattice)

        if self.apply_fuzzy and train:
            # Uniform jitter in [-0.0004, 0.0004]; drawn for rows first,
            # then columns.
            meshx = meshx + (torch.rand(lattice, device=device, dtype=dtype) * 0.0008 - 0.0004)
            meshy = meshy + (torch.rand(lattice, device=device, dtype=dtype) * 0.0008 - 0.0004)

        # grid_sample expects (x, y) in the last dim, where x indexes the
        # input's width axis and y its height axis.
        grid = torch.stack((meshy, meshx), dim=2).unsqueeze(0)

        # (side * side, d) -> (1, d, side, side)
        side = math.isqrt(self.grid_size)
        table = self.positional_embedding.reshape(side, side, -1).permute(2, 0, 1)
        table = table.to(dtype).unsqueeze(0)

        sampled = F.grid_sample(table, grid, align_corners=False).to(dtype)
        # (1, d, h, w) -> (h * w, d)
        tokens = sampled.squeeze(0).flatten(1).transpose(0, 1)

        # NOTE(review): the class embedding is not cast to `dtype`; with a
        # reduced-precision `dtype` torch.cat presumably promotes the result
        # to the parameter's dtype - confirm this is what callers expect.
        return torch.cat([self.class_positional_embedding, tokens], dim=0)
if __name__ == "__main__":
    # Smoke test: build the module and print the shape of one sampled
    # embedding for a non-square 16 x 32 token grid.
    # 1024 table entries, i.e. a 32 x 32 base grid.
    embedder = FuzzyEmbedding(1024, 1.0, 1024)
    grid_height = 16
    grid_width = 32
    # Separate name for the output so the module object is not overwritten.
    sampled_embedding = embedder(grid_height, grid_width, dtype=torch.bfloat16)
    # One class row plus grid_height * grid_width sampled rows.
    print(sampled_embedding.shape)