import logging
import math
from typing import Dict, List, Optional, Tuple

import PIL
import PIL.Image
import torch
from diffusers import DiffusionPipeline

from rewards import clip_img_transform
from rewards.base_reward import BaseRewardLoss


class LatentNoiseTrainer:
    """Trainer for optimizing latents with reward losses.

    Each iteration renders an image from the latents through ``model.apply``,
    scores the preprocessed image with every reward loss, optionally adds a
    regularizer on the latent norm, and takes one optimizer step on the
    latents. The image with the lowest weighted reward loss is kept.
    """

    def __init__(
        self,
        reward_losses: List[BaseRewardLoss],
        model: DiffusionPipeline,
        n_iters: int,
        n_inference_steps: int,
        seed: int,
        no_optim: bool = False,
        regularize: bool = True,
        regularization_weight: float = 0.01,
        grad_clip: float = 0.1,
        log_metrics: bool = True,
        save_all_images: bool = False,
        imageselect: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        """Store the training configuration.

        Args:
            reward_losses: Callables invoked as ``loss(image, prompt)``; each
                must expose ``name`` and ``weighting`` attributes.
            model: Pipeline object; ``train`` calls its ``apply`` method.
            n_iters: Number of iterations ``train`` runs.
            n_inference_steps: Forwarded to ``model.apply`` as
                ``num_inference_steps``.
            seed: Seed for the generator that is re-created every iteration.
            no_optim: If True, ``train`` returns the first generated image
                without computing any reward.
            regularize: If True, add the latent-norm regularizer to the loss.
            regularization_weight: Scale factor of the regularizer.
            grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.
            log_metrics: If True, log per-iteration metrics.
            save_all_images: If True, save every iteration's image to
                ``save_dir`` (see ``train``).
            imageselect: If True, sample fresh random latents every iteration
                and never step the optimizer (selection among random samples).
            device: Device used for the random generator and for freshly
                sampled latents.
        """
        self.reward_losses = reward_losses
        self.model = model
        self.n_iters = n_iters
        self.n_inference_steps = n_inference_steps
        self.seed = seed
        self.no_optim = no_optim
        self.regularize = regularize
        self.regularization_weight = regularization_weight
        self.grad_clip = grad_clip
        self.log_metrics = log_metrics
        self.save_all_images = save_all_images
        self.imageselect = imageselect
        self.device = device
        self.preprocess_fn = clip_img_transform(224)

    @staticmethod
    def _to_pil(image: torch.Tensor) -> PIL.Image.Image:
        """Convert the first image of a batched tensor to a PIL image."""
        # The permute moves dim 1 to the end; presumably this turns
        # (batch, channels, H, W) into (batch, H, W, channels) -- confirm
        # against what ``model.apply`` returns.
        image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        return DiffusionPipeline.numpy_to_pil(image_numpy)[0]

    def train(
        self,
        latents: torch.Tensor,
        prompt: str,
        optimizer: torch.optim.Optimizer,
        save_dir: Optional[str] = None,
        progress_callback=None,
    ) -> Tuple[PIL.Image.Image, Dict[str, float], Dict[str, float]]:
        """Optimize ``latents`` for ``prompt`` and return the best image.

        Args:
            latents: Latent tensor to optimize; the optimizer is expected to
                hold it as a parameter (it is the tensor whose gradient is
                clipped).
            prompt: Text prompt passed to the model and to each reward loss.
            optimizer: Optimizer stepped once per iteration, except on the
                last iteration and in ``imageselect`` mode.
            save_dir: Directory for per-iteration images; required when
                ``save_all_images`` is set and ``no_optim`` is not.
            progress_callback: Optional callable invoked with the 1-based
                iteration number after each completed iteration.

        Returns:
            ``(best_image, initial_rewards, best_rewards)``. The reward dicts
            map each reward name (plus ``"total"`` and, when regularizing,
            ``"norm"``) to floats. Both dicts are ``None`` when ``no_optim``
            is set, because no reward is evaluated in that mode.

        Raises:
            ValueError: If ``n_iters`` is smaller than 1, or if images should
                be saved but ``save_dir`` is ``None``.
        """
        if self.n_iters < 1:
            # Without an iteration there is no image to return.
            raise ValueError(f"n_iters must be at least 1, got {self.n_iters}.")
        if self.save_all_images and not self.no_optim and save_dir is None:
            # Fail before the expensive forward pass rather than writing to
            # the literal path "None/<iteration>.png".
            raise ValueError("save_dir must be provided when save_all_images is True.")

        logging.info(f"Optimizing latents for prompt '{prompt}'.")
        best_loss = torch.inf
        best_image = None
        initial_rewards = None
        best_rewards = None
        latent_dim = math.prod(latents.shape[1:])
        for iteration in range(self.n_iters):
            to_log = ""
            rewards = {}
            optimizer.zero_grad()
            # Re-seed every iteration so sampling noise inside the pipeline
            # is identical across iterations. Use the configured device
            # instead of a hard-coded "cuda".
            generator = torch.Generator(device=self.device).manual_seed(self.seed)
            if self.imageselect:
                # Selection mode: score a fresh random latent each iteration.
                new_latents = torch.randn_like(
                    latents, device=self.device, dtype=latents.dtype
                )
                image = self.model.apply(
                    new_latents,
                    prompt,
                    generator=generator,
                    num_inference_steps=self.n_inference_steps,
                )
            else:
                image = self.model.apply(
                    latents,
                    prompt,
                    generator=generator,
                    num_inference_steps=self.n_inference_steps,
                )
            if self.no_optim:
                best_image = image
                break

            total_loss = 0
            preprocessed_image = self.preprocess_fn(image)
            for reward_loss in self.reward_losses:
                loss = reward_loss(preprocessed_image, prompt)
                to_log += f"{reward_loss.name}: {loss.item():.4f}, "
                total_loss += loss * reward_loss.weighting
                rewards[reward_loss.name] = loss.item()
            rewards["total"] = total_loss.item()
            to_log += f"Total: {total_loss.item():.4f}"
            # Best-image selection uses the reward loss only, captured before
            # the regularizer is added.
            total_reward_loss = total_loss.item()

            if self.regularize:
                # compute in fp32 to avoid overflow
                latent_norm = torch.linalg.vector_norm(latents).to(torch.float32)
                log_norm = torch.log(latent_norm)
                # 0.5 * r^2 - (d - 1) * log(r) equals, up to an additive
                # constant, the negative log-density of a chi distribution
                # with d degrees of freedom evaluated at r.
                regularization = self.regularization_weight * (
                    0.5 * latent_norm**2 - (latent_dim - 1) * log_norm
                )
                to_log += f", Latent norm: {latent_norm.item()}"
                rewards["norm"] = latent_norm.item()
                total_loss += regularization.to(total_loss.dtype)

            if self.log_metrics:
                logging.info(f"Iteration {iteration}: {to_log}")

            if initial_rewards is None:
                initial_rewards = rewards
            if total_reward_loss < best_loss:
                best_loss = total_reward_loss
                # Detach so the stored best image does not keep this
                # iteration's autograd graph alive in later iterations.
                best_image = image.detach()
                best_rewards = rewards

            # No update after the final iteration: its result could never be
            # rendered, and selection mode never optimizes.
            if iteration != self.n_iters - 1 and not self.imageselect:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(latents, self.grad_clip)
                optimizer.step()

            if self.save_all_images:
                self._to_pil(image).save(f"{save_dir}/{iteration}.png")
            if progress_callback:
                progress_callback(iteration + 1)

        return self._to_pil(best_image), initial_rewards, best_rewards