Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
import torchvision | |
from scipy.spatial import Delaunay | |
import torch | |
import numpy as np | |
from torch.nn import functional as nnf | |
from easydict import EasyDict | |
from shapely.geometry import Point | |
from shapely.geometry.polygon import Polygon | |
from torchvision import transforms | |
from PIL import Image | |
class SDSLoss(nn.Module): | |
def __init__(self, cfg, device, model): | |
super(SDSLoss, self).__init__() | |
self.cfg = cfg | |
self.device = device | |
self.pipe = model | |
self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device) | |
self.sigmas = (1 - self.pipe.scheduler.alphas_cumprod).to(self.device) | |
self.text_embeddings = None | |
self.embed_text() | |
def embed_text(self): | |
# tokenizer and embed text | |
text_input = self.pipe.tokenizer(self.cfg.caption, padding="max_length", | |
max_length=self.pipe.tokenizer.model_max_length, | |
truncation=True, return_tensors="pt") | |
uncond_input = self.pipe.tokenizer([""], padding="max_length", | |
max_length=text_input.input_ids.shape[-1], | |
return_tensors="pt") | |
with torch.no_grad(): | |
text_embeddings = self.pipe.text_encoder(text_input.input_ids.to(self.device))[0] | |
uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.device))[0] | |
self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | |
self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0) | |
def forward(self, x_aug): | |
sds_loss = 0 | |
# encode rendered image | |
x = x_aug * 2. - 1. | |
with torch.cuda.amp.autocast(): | |
init_latent_z = (self.pipe.vae.encode(x).latent_dist.sample()) | |
latent_z = 0.18215 * init_latent_z # scaling_factor * init_latents | |
with torch.inference_mode(): | |
# sample timesteps | |
timestep = torch.randint( | |
low=50, | |
high=min(950, self.cfg.diffusion.timesteps) - 1, # avoid highest timestep | diffusion.timesteps=1000 | |
size=(latent_z.shape[0],), | |
device=self.device, dtype=torch.long) | |
# add noise | |
eps = torch.randn_like(latent_z) | |
# zt = alpha_t * latent_z + sigma_t * eps | |
noised_latent_zt = self.pipe.scheduler.add_noise(latent_z, eps, timestep) | |
# denoise | |
z_in = torch.cat([noised_latent_zt] * 2) # expand latents for classifier free guidance | |
timestep_in = torch.cat([timestep] * 2) | |
with torch.autocast(device_type="cuda", dtype=torch.float16): | |
eps_t_uncond, eps_t = self.pipe.unet(z_in, timestep, encoder_hidden_states=self.text_embeddings).sample.float().chunk(2) | |
eps_t = eps_t_uncond + self.cfg.diffusion.guidance_scale * (eps_t - eps_t_uncond) | |
# w = alphas[timestep]^0.5 * (1 - alphas[timestep]) = alphas[timestep]^0.5 * sigmas[timestep] | |
grad_z = self.alphas[timestep]**0.5 * self.sigmas[timestep] * (eps_t - eps) | |
assert torch.isfinite(grad_z).all() | |
grad_z = torch.nan_to_num(grad_z.detach().float(), 0.0, 0.0, 0.0) | |
sds_loss = grad_z.clone() * latent_z | |
del grad_z | |
sds_loss = sds_loss.sum(1).mean() | |
return sds_loss | |
class ToneLoss(nn.Module): | |
def __init__(self, cfg): | |
super(ToneLoss, self).__init__() | |
self.dist_loss_weight = cfg.loss.tone.dist_loss_weight | |
self.im_init = None | |
self.cfg = cfg | |
self.mse_loss = nn.MSELoss() | |
self.blurrer = torchvision.transforms.GaussianBlur(kernel_size=(cfg.loss.tone.pixel_dist_kernel_blur, | |
cfg.loss.tone.pixel_dist_kernel_blur), sigma=(cfg.loss.tone.pixel_dist_sigma)) | |
def set_image_init(self, im_init): | |
self.im_init = im_init.permute(2, 0, 1).unsqueeze(0) | |
self.init_blurred = self.blurrer(self.im_init) | |
def get_scheduler(self, step=None): | |
if step is not None: | |
return self.dist_loss_weight * np.exp(-(1/5)*((step-300)/(20)) ** 2) | |
else: | |
return self.dist_loss_weight | |
def forward(self, cur_raster, step=None): | |
blurred_cur = self.blurrer(cur_raster) | |
return self.mse_loss(self.init_blurred.detach(), blurred_cur) * self.get_scheduler(step) | |
class ConformalLoss: | |
def __init__(self, parameters: EasyDict, device: torch.device, target_letter: str, shape_groups): | |
self.parameters = parameters | |
self.target_letter = target_letter | |
self.shape_groups = shape_groups | |
self.faces = self.init_faces(device) | |
self.faces_roll_a = [torch.roll(self.faces[i], 1, 1) for i in range(len(self.faces))] | |
with torch.no_grad(): | |
self.angles = [] | |
self.reset() | |
def get_angles(self, points: torch.Tensor) -> torch.Tensor: | |
angles_ = [] | |
for i in range(len(self.faces)): | |
triangles = points[self.faces[i]] | |
triangles_roll_a = points[self.faces_roll_a[i]] | |
edges = triangles_roll_a - triangles | |
length = edges.norm(dim=-1) | |
edges = edges / (length + 1e-1)[:, :, None] | |
edges_roll = torch.roll(edges, 1, 1) | |
cosine = torch.einsum('ned,ned->ne', edges, edges_roll) | |
angles = torch.arccos(cosine) | |
angles_.append(angles) | |
return angles_ | |
def get_letter_inds(self, letter_to_insert): | |
for group, l in zip(self.shape_groups, self.target_letter): | |
if l == letter_to_insert: | |
letter_inds = group.shape_ids | |
return letter_inds[0], letter_inds[-1], len(letter_inds) | |
def reset(self): | |
points = torch.cat([point.clone().detach() for point in self.parameters.point]) | |
self.angles = self.get_angles(points) | |
def init_faces(self, device: torch.device) -> torch.tensor: | |
faces_ = [] | |
num_shapes = 0 | |
for j, c in enumerate(self.target_letter): | |
points_np = [self.parameters.point[i].clone().detach().cpu().numpy() for i in range(len(self.parameters.point))] | |
start_ind, end_ind, shapes_per_letter = self.get_letter_inds(c) | |
print(c, start_ind, end_ind, shapes_per_letter) | |
holes = [] | |
if shapes_per_letter > 1: | |
holes = points_np[start_ind+1:end_ind] | |
poly = Polygon(points_np[start_ind], holes=holes) | |
poly = poly.buffer(0) | |
points_np = np.concatenate(points_np) | |
faces = Delaunay(points_np).simplices | |
is_intersect = np.array([poly.contains(Point(points_np[face].mean(0))) for face in faces], dtype=np.bool_) | |
faces_.append(torch.from_numpy(faces[is_intersect]).to(device, dtype=torch.int64)) | |
num_shapes += shapes_per_letter | |
if num_shapes >= len(self.target_letter): | |
break | |
return faces_ | |
def __call__(self) -> torch.Tensor: | |
loss_angles = 0 | |
points = torch.cat(self.parameters.point) | |
angles = self.get_angles(points) | |
for i in range(len(self.faces)): | |
loss_angles += (nnf.mse_loss(angles[i], self.angles[i])) | |
return loss_angles |