Font-To-Sketch / code /losses.py
bkhmsi's picture
bug fix + cleaned losses
4435e47
raw
history blame
7.57 kB
import torch.nn as nn
import torchvision
from scipy.spatial import Delaunay
import torch
import numpy as np
from torch.nn import functional as nnf
from easydict import EasyDict
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon
from torchvision import transforms
from PIL import Image
class SDSLoss(nn.Module):
    """Score Distillation Sampling (SDS) loss from a text-conditioned diffusion pipeline.

    The pipeline passed as ``model`` must expose ``scheduler`` (with
    ``alphas_cumprod`` and ``add_noise``), ``tokenizer``, ``text_encoder``,
    ``vae`` and ``unet``. Config fields read: ``cfg.caption``,
    ``cfg.batch_size``, ``cfg.diffusion.timesteps`` and
    ``cfg.diffusion.guidance_scale``.
    """

    def __init__(self, cfg, device, model):
        super(SDSLoss, self).__init__()
        self.cfg = cfg
        self.device = device
        self.pipe = model
        self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device)
        # Named "sigmas" but holds 1 - alphas_cumprod (a variance, not a std-dev).
        self.sigmas = (1 - self.pipe.scheduler.alphas_cumprod).to(self.device)
        self.text_embeddings = None
        self.embed_text()

    def embed_text(self):
        """Encode the caption and the empty prompt once and cache them.

        The cached tensor stacks the unconditional embeddings first and the
        caption embeddings second, each repeated ``cfg.batch_size`` times. This
        matches the ``[uncond, cond]`` order that ``forward`` splits with
        ``chunk(2)``. It assumes ``cfg.caption`` yields a single prompt; confirm
        for list captions.
        """
        text_input = self.pipe.tokenizer(self.cfg.caption, padding="max_length",
                                         max_length=self.pipe.tokenizer.model_max_length,
                                         truncation=True, return_tensors="pt")
        uncond_input = self.pipe.tokenizer([""], padding="max_length",
                                           max_length=text_input.input_ids.shape[-1],
                                           return_tensors="pt")
        with torch.no_grad():
            text_embeddings = self.pipe.text_encoder(text_input.input_ids.to(self.device))[0]
            uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.device))[0]
        self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)

    def forward(self, x_aug):
        """Return a scalar surrogate loss whose gradient w.r.t. the latents is the SDS gradient.

        ``x_aug`` is mapped with ``x * 2 - 1`` before VAE encoding. Presumably
        it is an image batch in [0, 1]; confirm with callers.
        """
        # Encode the rendered image; gradients flow through this part.
        x = x_aug * 2. - 1.
        with torch.autocast(device_type="cuda"):
            init_latent_z = self.pipe.vae.encode(x).latent_dist.sample()
        latent_z = 0.18215 * init_latent_z  # scaling_factor * init_latents

        with torch.inference_mode():
            # torch.randint's `high` is exclusive; this avoids both very low and the highest timesteps.
            timestep = torch.randint(
                low=50,
                high=min(950, self.cfg.diffusion.timesteps) - 1,  # diffusion.timesteps=1000
                size=(latent_z.shape[0],),
                device=self.device, dtype=torch.long)

            # zt = alpha_t * latent_z + sigma_t * eps
            eps = torch.randn_like(latent_z)
            noised_latent_zt = self.pipe.scheduler.add_noise(latent_z, eps, timestep)

            # Classifier-free guidance: the latent batch is doubled, so the
            # timestep batch must be doubled with it (timestep_in, not timestep).
            z_in = torch.cat([noised_latent_zt] * 2)
            timestep_in = torch.cat([timestep] * 2)
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                eps_t_uncond, eps_t = self.pipe.unet(
                    z_in, timestep_in, encoder_hidden_states=self.text_embeddings
                ).sample.float().chunk(2)
            eps_t = eps_t_uncond + self.cfg.diffusion.guidance_scale * (eps_t - eps_t_uncond)

            # w = alphas[timestep]^0.5 * (1 - alphas[timestep]) = alphas[timestep]^0.5 * sigmas[timestep]
            # The weight has one value per sample, shape (B,). Reshape it to
            # (B, 1, ..., 1) so it scales each sample instead of broadcasting
            # against the last spatial axis.
            weight = self.alphas[timestep] ** 0.5 * self.sigmas[timestep]
            weight = weight.reshape(-1, *([1] * (eps_t.dim() - 1)))
            grad_z = weight * (eps_t - eps)
            assert torch.isfinite(grad_z).all()
            grad_z = torch.nan_to_num(grad_z.detach().float(), 0.0, 0.0, 0.0)

        # grad_z is a constant (detached, inference-mode tensor). The clone
        # yields a normal tensor usable in autograd. Backprop of grad_z * latent_z
        # therefore injects grad_z into latent_z.
        sds_loss = grad_z.clone() * latent_z
        del grad_z

        sds_loss = sds_loss.sum(1).mean()
        return sds_loss
class ToneLoss(nn.Module):
    """Weighted MSE between Gaussian-blurred versions of a reference image and the current raster.

    Call ``set_image_init`` once with the reference image before using ``forward``.
    Config fields read: ``cfg.loss.tone.dist_loss_weight``,
    ``cfg.loss.tone.pixel_dist_kernel_blur`` and ``cfg.loss.tone.pixel_dist_sigma``.
    """

    def __init__(self, cfg):
        super(ToneLoss, self).__init__()
        tone_cfg = cfg.loss.tone
        self.dist_loss_weight = tone_cfg.dist_loss_weight
        self.im_init = None
        # Blurred reference image; stays None until set_image_init() is called.
        self.init_blurred = None
        self.cfg = cfg
        self.mse_loss = nn.MSELoss()
        kernel = tone_cfg.pixel_dist_kernel_blur
        self.blurrer = torchvision.transforms.GaussianBlur(
            kernel_size=(kernel, kernel), sigma=tone_cfg.pixel_dist_sigma)

    def set_image_init(self, im_init):
        """Store the reference image and cache its blurred version.

        The image is permuted with ``(2, 0, 1)`` and given a leading batch axis.
        This presumably turns an H x W x C image into 1 x C x H x W; confirm
        with callers.
        """
        self.im_init = im_init.permute(2, 0, 1).unsqueeze(0)
        self.init_blurred = self.blurrer(self.im_init)

    def get_scheduler(self, step=None):
        """Return the loss weight for ``step``.

        Without a step the constant ``dist_loss_weight`` is returned. With a step
        the weight is modulated by ``exp(-((step - 300) / 20) ** 2 / 5)``, a
        Gaussian bump that peaks (factor 1) at step 300.
        """
        if step is None:
            return self.dist_loss_weight
        return self.dist_loss_weight * np.exp(-(1 / 5) * ((step - 300) / 20) ** 2)

    def forward(self, cur_raster, step=None):
        """Return the scheduled MSE between the blurred reference and the blurred ``cur_raster``.

        Raises:
            RuntimeError: if ``set_image_init`` has not been called yet.
        """
        if self.init_blurred is None:
            raise RuntimeError("ToneLoss.set_image_init() must be called before forward()")
        blurred_cur = self.blurrer(cur_raster)
        return self.mse_loss(self.init_blurred.detach(), blurred_cur) * self.get_scheduler(step)
class ConformalLoss:
    """Angle-preservation regulariser on a triangulation of the glyph control points.

    At construction all control points in ``parameters.point`` are Delaunay
    triangulated. For each letter, the triangles whose centroid lies inside that
    letter's outline polygon are kept (``self.faces``), and their reference
    angles are captured (``self.angles``). Calling the instance returns the sum
    over letters of the MSE between the current and reference angles.
    """

    def __init__(self, parameters: "EasyDict", device: torch.device, target_letter: str, shape_groups):
        self.parameters = parameters
        self.target_letter = target_letter
        self.shape_groups = shape_groups
        self.faces = self.init_faces(device)
        # Same faces with vertex columns rotated by one, so that
        # points[rolled] - points[faces] yields the triangle edge vectors.
        self.faces_roll_a = [torch.roll(faces, 1, 1) for faces in self.faces]
        with torch.no_grad():
            self.angles = []
            self.reset()

    def get_angles(self, points: torch.Tensor) -> list:
        """Return one tensor per face set holding an angle (radians) per triangle corner.

        Each angle is the arccos of the dot product of two consecutive edge
        vectors of the triangle. Each edge is divided by (its length + 0.1)
        rather than its exact length. ``points`` is indexed with the integer
        face tensors, so it must be the concatenation of ``parameters.point``.
        """
        angles_ = []
        for faces, faces_rolled in zip(self.faces, self.faces_roll_a):
            triangles = points[faces]
            triangles_roll_a = points[faces_rolled]
            edges = triangles_roll_a - triangles
            length = edges.norm(dim=-1)
            # The +0.1 damping avoids division by zero and keeps |cosine| < 1,
            # so arccos (and its gradient) stays finite.
            edges = edges / (length + 1e-1)[:, :, None]
            edges_roll = torch.roll(edges, 1, 1)
            cosine = torch.einsum('ned,ned->ne', edges, edges_roll)
            angles_.append(torch.arccos(cosine))
        return angles_

    def get_letter_inds(self, letter_to_insert, occurrence=None):
        """Return ``(first_shape_id, last_shape_id, n_shapes)`` for a letter's shape group.

        Groups are paired with letters positionally (``zip(shape_groups,
        target_letter)``). ``occurrence`` selects which match to use when the
        letter appears several times (0-based). When it is None or out of
        range, the last match is used, which is the historical behaviour.

        Raises:
            ValueError: if no group is paired with ``letter_to_insert``.
        """
        matches = [group.shape_ids
                   for group, letter in zip(self.shape_groups, self.target_letter)
                   if letter == letter_to_insert]
        if not matches:
            raise ValueError(
                f"letter {letter_to_insert!r} has no shape group in {self.target_letter!r}")
        if occurrence is None or occurrence >= len(matches):
            letter_inds = matches[-1]
        else:
            letter_inds = matches[occurrence]
        return letter_inds[0], letter_inds[-1], len(letter_inds)

    def reset(self):
        """Re-capture the reference angles from the current (detached) control points."""
        points = torch.cat([point.clone().detach() for point in self.parameters.point])
        self.angles = self.get_angles(points)

    def init_faces(self, device: torch.device) -> list:
        """Build, per letter, the int64 face-index tensor of triangles inside that letter's outline."""
        faces_ = []
        if not self.target_letter:
            return faces_
        # The control points do not change while faces are built, so extract,
        # triangulate and compute the centroids once instead of once per letter.
        shape_points = [point.clone().detach().cpu().numpy() for point in self.parameters.point]
        all_points = np.concatenate(shape_points)
        faces = Delaunay(all_points).simplices
        centroids = all_points[faces].mean(1)

        seen = {}  # letter -> occurrences processed so far, so repeated letters use their own group
        num_shapes = 0
        for c in self.target_letter:
            occurrence = seen.get(c, 0)
            seen[c] = occurrence + 1
            start_ind, end_ind, shapes_per_letter = self.get_letter_inds(c, occurrence)
            print(c, start_ind, end_ind, shapes_per_letter)
            holes = []
            if shapes_per_letter > 1:
                # NOTE(review): end_ind is shape_ids[-1] (inclusive), so this slice
                # stops before the letter's last shape. For example, a two-shape
                # letter gets no hole. Confirm whether that is intended.
                holes = shape_points[start_ind + 1:end_ind]
            poly = Polygon(shape_points[start_ind], holes=holes).buffer(0)
            is_intersect = np.array([poly.contains(Point(centroid)) for centroid in centroids],
                                    dtype=np.bool_)
            faces_.append(torch.from_numpy(faces[is_intersect]).to(device, dtype=torch.int64))
            num_shapes += shapes_per_letter
            # NOTE(review): this compares a running shape count with the number
            # of letters, so later letters are skipped when earlier ones have
            # several shapes. Confirm intent.
            if num_shapes >= len(self.target_letter):
                break
        return faces_

    def __call__(self) -> torch.Tensor:
        """Return the summed per-letter MSE between current and reference angles (0 if no faces)."""
        points = torch.cat(self.parameters.point)
        current = self.get_angles(points)
        loss_angles = 0
        for cur, ref in zip(current, self.angles):
            loss_angles += nnf.mse_loss(cur, ref)
        return loss_angles