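# Font-To-Sketch / code/losses.py
# Last commit: "config bug fix" (e169248) by Badr AlKhamissi.
# The remaining header lines ("raw", "history blame", "No virus", "7.78 kB")
# were file-viewer page chrome, not part of the module.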
import torch.nn as nn
import torchvision
from scipy.spatial import Delaunay
import torch
import numpy as np
from torch.nn import functional as nnf
from easydict import EasyDict
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon
class SDSLoss(nn.Module):
    """SDS (score-distillation style) loss driven by a latent diffusion pipeline.

    The rendered image is encoded by the pipeline's VAE, noised at a randomly
    sampled timestep and passed through the UNet with classifier-free
    guidance. The weighted difference between the predicted and the injected
    noise is treated as a constant (detached) gradient on the latents.

    Args:
        cfg: Configuration object. Reads ``cfg.caption``, ``cfg.batch_size``,
            ``cfg.diffusion.timesteps`` and ``cfg.diffusion.guidance_scale``.
        device: Device on which the schedule tensors, token ids and sampled
            timesteps are placed.
        model: Pipeline object exposing ``tokenizer``, ``text_encoder``,
            ``vae``, ``unet`` and ``scheduler``.
    """

    def __init__(self, cfg, device, model):
        super().__init__()
        self.cfg = cfg
        self.device = device
        self.pipe = model

        # default scheduler: PNDMScheduler(beta_start=0.00085, beta_end=0.012,
        # beta_schedule="scaled_linear", num_train_timesteps=1000)
        self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device)
        self.sigmas = (1 - self.pipe.scheduler.alphas_cumprod).to(self.device)

        self.text_embeddings = None
        self.embed_text()

    def embed_text(self):
        """Encode ``cfg.caption`` and the empty prompt into ``self.text_embeddings``.

        The result is laid out as ``batch_size`` unconditional embeddings
        followed by ``batch_size`` caption embeddings, which matches the
        ``torch.cat([latents] * 2)`` / ``chunk(2)`` layout used in ``forward``.
        """
        tokenizer = self.pipe.tokenizer
        text_input = tokenizer(self.cfg.caption, padding="max_length",
                               max_length=tokenizer.model_max_length,
                               truncation=True, return_tensors="pt")
        uncond_input = tokenizer([""], padding="max_length",
                                 max_length=text_input.input_ids.shape[-1],
                                 return_tensors="pt")
        with torch.no_grad():
            text_embeddings = self.pipe.text_encoder(text_input.input_ids.to(self.device))[0]
            uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.device))[0]
        self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)

    def forward(self, x_aug):
        """Return the scalar SDS loss for a batch of rendered images.

        Args:
            x_aug: Batch of rendered images. It is rescaled with ``x * 2 - 1``
                before VAE encoding (presumably values in [0, 1] -- confirm
                against the caller). Its batch size has to equal
                ``cfg.batch_size`` because the text embeddings were repeated
                that many times in ``embed_text``.

        Returns:
            A scalar tensor: ``(grad_z * latent_z).sum(1).mean()`` where
            ``grad_z`` is detached, so back-propagation reaches ``x_aug`` only
            through the VAE encoding.
        """
        # encode rendered image
        x = x_aug * 2. - 1.
        with torch.autocast(device_type="cuda"):
            init_latent_z = self.pipe.vae.encode(x).latent_dist.sample()
        latent_z = 0.18215 * init_latent_z  # scaling_factor * init_latents

        with torch.inference_mode():
            # sample timesteps
            timestep = torch.randint(
                low=50,
                high=min(950, self.cfg.diffusion.timesteps) - 1,  # avoid highest timestep | diffusion.timesteps=1000
                size=(latent_z.shape[0],),
                device=self.device, dtype=torch.long)

            # add noise
            eps = torch.randn_like(latent_z)
            # zt = alpha_t * latent_z + sigma_t * eps
            noised_latent_zt = self.pipe.scheduler.add_noise(latent_z, eps, timestep)

            # denoise
            z_in = torch.cat([noised_latent_zt] * 2)  # expand latents for classifier free guidance
            # The timesteps must be doubled as well so that every row of z_in
            # gets its own timestep (required as soon as batch_size > 1).
            timestep_in = torch.cat([timestep] * 2)
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                eps_t_uncond, eps_t = self.pipe.unet(
                    z_in, timestep_in, encoder_hidden_states=self.text_embeddings
                ).sample.float().chunk(2)
                eps_t = eps_t_uncond + self.cfg.diffusion.guidance_scale * (eps_t - eps_t_uncond)

            # w = alphas[timestep]^0.5 * (1 - alphas[timestep]) = alphas[timestep]^0.5 * sigmas[timestep]
            weight = self.alphas[timestep] ** 0.5 * self.sigmas[timestep]
            # One weight per sample: add trailing singleton axes so it
            # broadcasts over the batch axis rather than the last axis.
            weight = weight.reshape(-1, *([1] * (eps.dim() - 1)))
            grad_z = weight * (eps_t - eps)
            assert torch.isfinite(grad_z).all()
            grad_z = torch.nan_to_num(grad_z.detach().float(), 0.0, 0.0, 0.0)

        # clone() outside inference mode turns grad_z into a normal tensor
        # that may take part in autograd.
        sds_loss = grad_z.clone() * latent_z
        del grad_z

        return sds_loss.sum(1).mean()
class ToneLoss(nn.Module):
    """MSE between Gaussian-blurred versions of a reference image and the current raster.

    Args:
        cfg: Configuration object. Reads ``cfg.loss.tone.dist_loss_weight``,
            ``cfg.loss.tone.pixel_dist_kernel_blur`` (used for both kernel
            dimensions) and ``cfg.loss.tone.pixel_dist_sigma``.
    """

    def __init__(self, cfg):
        super(ToneLoss, self).__init__()
        tone_cfg = cfg.loss.tone
        self.dist_loss_weight = tone_cfg.dist_loss_weight
        self.im_init = None
        self.cfg = cfg
        self.mse_loss = nn.MSELoss()
        kernel = tone_cfg.pixel_dist_kernel_blur
        self.blurrer = torchvision.transforms.GaussianBlur(
            kernel_size=(kernel, kernel),
            sigma=tone_cfg.pixel_dist_sigma,
        )

    def set_image_init(self, im_init):
        """Store the reference image and its blurred version.

        ``im_init`` is permuted with ``(2, 0, 1)`` and given a leading batch
        axis, i.e. a channels-last image is presumably expected -- confirm
        against the caller.
        """
        batched = im_init.permute(2, 0, 1).unsqueeze(0)
        self.im_init = batched
        self.init_blurred = self.blurrer(batched)

    def get_scheduler(self, step=None):
        """Return the loss weight, optionally modulated by a bump centred at step 300.

        Without ``step`` the plain ``dist_loss_weight`` is returned; otherwise
        it is scaled by ``exp(-(1/5) * ((step - 300) / 20) ** 2)``.
        """
        if step is None:
            return self.dist_loss_weight
        scaled_offset = (step - 300) / (20)
        return self.dist_loss_weight * np.exp(-(1/5) * scaled_offset ** 2)

    def forward(self, cur_raster, step=None):
        """Return the weighted MSE between the blurred reference and blurred ``cur_raster``."""
        blurred_cur = self.blurrer(cur_raster)
        reference = self.init_blurred.detach()  # the reference never receives gradients
        weight = self.get_scheduler(step)
        return self.mse_loss(reference, blurred_cur) * weight
class ConformalLoss:
    """Angle-preservation loss over a triangulation of the letter outlines.

    At construction time all control points are Delaunay-triangulated and,
    per letter, the triangles whose centroid lies inside the letter polygon
    are kept. The triangle angles of the initial points are stored; calling
    the object returns the summed MSE between current and stored angles.

    Args:
        parameters: Object whose ``point`` attribute is a sequence of point
            tensors (one per shape).
        device: Device on which the face index tensors are placed.
        target_letter: Sequence of letters, paired position-wise with
            ``shape_groups``.
        shape_groups: Sequence of objects exposing ``shape_ids``.
    """

    def __init__(self, parameters: "EasyDict", device: torch.device, target_letter: str, shape_groups):
        self.parameters = parameters
        self.target_letter = target_letter
        self.shape_groups = shape_groups

        self.faces = self.init_faces(device)
        # Vertex indices rolled by one position along each triangle, so that
        # (rolled - original) yields the three edge vectors.
        self.faces_roll_a = [torch.roll(face, 1, 1) for face in self.faces]

        with torch.no_grad():
            self.angles = []
            self.reset()

    def get_angles(self, points: torch.Tensor) -> "list[torch.Tensor]":
        """Return, per letter, the arccos of the cosine between consecutive triangle edges."""
        angles_ = []
        for face, face_rolled in zip(self.faces, self.faces_roll_a):
            triangles = points[face]
            triangles_roll_a = points[face_rolled]
            edges = triangles_roll_a - triangles
            length = edges.norm(dim=-1)
            # The 1e-1 offset avoids division by zero for degenerate edges.
            edges = edges / (length + 1e-1)[:, :, None]
            edges_roll = torch.roll(edges, 1, 1)
            cosine = torch.einsum('ned,ned->ne', edges, edges_roll)
            angles_.append(torch.arccos(cosine))
        return angles_

    def get_letter_inds(self, letter_to_insert):
        """Return ``(first_shape_id, last_shape_id, shape_count)`` for a letter.

        The first position in ``target_letter`` equal to ``letter_to_insert``
        selects the shape group.

        Raises:
            ValueError: If the letter has no matching shape group.
        """
        for group, letter in zip(self.shape_groups, self.target_letter):
            if letter == letter_to_insert:
                letter_inds = group.shape_ids
                return letter_inds[0], letter_inds[-1], len(letter_inds)
        raise ValueError(
            f"letter {letter_to_insert!r} has no shape group in target_letter {self.target_letter!r}"
        )

    def reset(self):
        """Re-record the reference angles from the current (detached) points."""
        points = torch.cat([point.clone().detach() for point in self.parameters.point])
        self.angles = self.get_angles(points)

    def init_faces(self, device: torch.device) -> "list[torch.Tensor]":
        """Return one ``int64`` tensor of triangle vertex indices per letter.

        All control points are triangulated once; for each letter the
        triangles whose centroid is contained in the letter polygon are kept.
        """
        if not self.target_letter:
            return []

        # Everything up to the centroid list is identical for all letters,
        # so it is computed once instead of once per letter.
        shape_points = [point.clone().detach().cpu().numpy() for point in self.parameters.point]
        all_points = np.concatenate(shape_points)
        simplices = Delaunay(all_points).simplices
        centroids = [Point(all_points[face].mean(0)) for face in simplices]

        faces_ = []
        for c in self.target_letter:
            start_ind, end_ind, shapes_per_letter = self.get_letter_inds(c)
            print(c, start_ind, end_ind)
            holes = []
            if shapes_per_letter > 1:
                # NOTE(review): end_ind is the last shape id, so this slice
                # leaves the last contour out of the holes -- confirm whether
                # that is intended before changing it.
                holes = shape_points[start_ind+1:end_ind]
            poly = Polygon(shape_points[start_ind], holes=holes)
            poly = poly.buffer(0)
            is_intersect = np.array([poly.contains(centroid) for centroid in centroids], dtype=np.bool_)
            faces_.append(torch.from_numpy(simplices[is_intersect]).to(device, dtype=torch.int64))
        return faces_

    def __call__(self) -> torch.Tensor:
        """Return the sum over letters of the MSE between current and reference angles."""
        loss_angles = 0
        points = torch.cat(self.parameters.point)
        angles = self.get_angles(points)
        for current, reference in zip(angles, self.angles):
            loss_angles += nnf.mse_loss(current, reference)
        return loss_angles