| import PIL.Image |
| import diffusers |
| import numpy as np |
| import diffusers |
| from transformers import CLIPTextModel, CLIPTokenizer |
| from torchvision import transforms |
| import torch |
|
|
class BlendedLatentDiffusion:
    """Local, text-guided image editing with Blended Latent Diffusion.

    Wraps a Stable Diffusion VAE and UNet, a CLIP text encoder/tokenizer and a
    DPM-Solver++ scheduler. Inside the mask the latent is denoised under the text
    prompt. Outside the mask it is overwritten before every step by the original
    image latent noised to the current level. After the last step it is replaced
    by the clean original latent.
    """

    def __init__(
        self,
        model_name: str = "CompVis/stable-diffusion-v1-4",
        text_model_name: str = "openai/clip-vit-large-patch14",
        image_size: int = 512,
    ):
        """Load the model components and move them to the available device.

        Args:
            model_name: Hugging Face repo providing the ``vae``, ``unet`` and
                ``scheduler`` subfolders.
            text_model_name: Hugging Face repo of the CLIP text encoder and tokenizer.
            image_size: Side length, in pixels, that input images are resized to
                before VAE encoding. Must be a positive multiple of the VAE
                downsampling factor.

        Raises:
            ValueError: If ``image_size`` is not a positive multiple of the VAE
                downsampling factor.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision on GPU, full precision on CPU.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.autoencoder = diffusers.AutoencoderKL.from_pretrained(model_name, subfolder="vae")
        self.text_encoder = CLIPTextModel.from_pretrained(text_model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(text_model_name)
        self.unet = diffusers.UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
        self.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
            model_name,
            subfolder="scheduler",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )

        for module in (self.autoencoder, self.text_encoder, self.unet):
            module.to(device=self.device, dtype=self.dtype).eval()

        # Read the VAE's spatial downsampling factor and latent scaling factor from
        # its config instead of hard-coding 512 -> 64 and 0.18215. This keeps the
        # image size, mask size and latent scale consistent with each other.
        self.vae_scale_factor = 2 ** (len(self.autoencoder.config.block_out_channels) - 1)
        self.latent_scaling_factor = getattr(self.autoencoder.config, "scaling_factor", 0.18215)
        if image_size <= 0 or image_size % self.vae_scale_factor:
            raise ValueError(
                f"image_size must be a positive multiple of {self.vae_scale_factor}, got {image_size}"
            )
        self.image_size = image_size

    def blended_latent_diffusion(self,
                                 init_image: PIL.Image.Image,
                                 mask_image: PIL.Image.Image,
                                 prompt: str,
                                 num_inference_steps: int = 25,
                                 strength: float = 0.8,
                                 guidance_scale: float = 7.5
                                 ) -> PIL.Image.Image:
        """
        Applies blended latent diffusion to an input image using a mask and a text prompt.

        Args:
            init_image (PIL.Image.Image): The initial image to be modified.
            mask_image (PIL.Image.Image): Mask image whose white areas mark the regions
                to modify. It is converted to grayscale and binarized at 50% intensity.
            prompt (str): The text prompt guiding the diffusion process.
            num_inference_steps (int): Length of the scheduler's timestep schedule.
            strength (float): Fraction of that schedule actually run, in (0, 1].
                ``int(num_inference_steps * strength)`` denoising steps are executed,
                starting from the correspondingly noised original image.
            guidance_scale (float): Classifier-free guidance scale controlling the
                influence of the text prompt.
        Returns:
            PIL.Image.Image: The edited image, decoded at the working resolution
            (``image_size`` pixels square).
        Raises:
            ValueError: If ``strength`` is outside (0, 1] or too small for
                ``num_inference_steps`` to leave at least one denoising step.
        """
        # Validate up front so bad arguments fail before the expensive model passes.
        self._start_step(num_inference_steps, strength)

        print(f"🚀 Starting Blended Latent Diffusion | Prompt: '{prompt}'")
        print(f"📦 Configurations | Steps: {num_inference_steps} | CFG Scale: {guidance_scale}")

        init_image = init_image.convert("RGB")
        mask_image = mask_image.convert("L")

        print("⏳ Encoding initial image to latent space...")
        latent_init = self.encode_to_latent(init_image)
        mask_transform = self.preprocess_mask(mask_image)
        print(f"✅ Latents Prepared | Shape: {list(latent_init.shape)} | Mask Shape: {list(mask_transform.shape)}")

        print("⏳ Processing text prompt and creating base noise...")
        noise_package = self.generate_noise_from_prompt(prompt, latent_init.shape)
        print(f"✅ Text Embeddings Configured | Embedded Shape: {list(noise_package[1].shape)}")

        print(f"⏳ Entering Denoising Loop ({num_inference_steps} steps via {self.scheduler.__class__.__name__})...")
        blended_latent = self.blend_latent_with_mask(
            latent_init, noise_package, mask_transform, strength, num_inference_steps, guidance_scale)
        print("✅ Latent optimization sequence complete.")

        print("⏳ Decoding final blended latents back to image pixels...")
        output_image = self.decode_from_latent(blended_latent)
        print("✨ Process complete! Returning output image.")
        return output_image

    @staticmethod
    def _start_step(num_inference_steps: int, strength: float) -> int:
        """Return the index of the first schedule step to run for ``strength``.

        Follows the diffusers img2img convention. ``int(steps * strength)`` steps
        are run and the noisiest remainder is skipped. Computing the number of steps
        *run* avoids the float error of ``int(steps * (1 - strength))``: for example,
        ``25 * (1 - 0.8)`` evaluates to 4.999…, which would silently run one extra step.

        Raises:
            ValueError: If no denoising step would be run.
        """
        if num_inference_steps < 1:
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")
        if not 0.0 < strength <= 1.0:
            raise ValueError(f"strength must be in (0, 1], got {strength}")
        steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
        if steps_to_run == 0:
            raise ValueError(
                f"strength={strength} with num_inference_steps={num_inference_steps} "
                "leaves no denoising step to run"
            )
        return num_inference_steps - steps_to_run

    @torch.no_grad()
    def encode_to_latent(self, init_image: PIL.Image.Image) -> torch.Tensor:
        """Encode an RGB image into scaled VAE latents.

        The image is resized to ``image_size`` x ``image_size`` and mapped from
        [0, 1] to [-1, 1]. A sample of the VAE latent distribution is then
        multiplied by the VAE scaling factor.
        """
        preprocess = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        pixels = preprocess(init_image).unsqueeze(0).to(self.device, dtype=self.dtype)
        latents = self.autoencoder.encode(pixels).latent_dist.sample()
        return latents * self.latent_scaling_factor

    def preprocess_mask(self, mask_image: PIL.Image.Image) -> torch.Tensor:
        """Downsample a mask to latent resolution and binarize it.

        Returns:
            Tensor of shape ``(1, 1, S, S)``, where ``S = image_size // vae_scale_factor``.
            Values are 1.0 where the grayscale mask is above 50% intensity, else 0.0.
        """
        side = self.image_size // self.vae_scale_factor
        # convert("L") guarantees a single channel even when called directly with RGB.
        mask = mask_image.convert("L").resize((side, side), resample=PIL.Image.NEAREST)
        mask = transforms.ToTensor()(mask)
        # Binarize: a fractional weight would average two independently-noised
        # latents and shrink their noise variance below what the UNet expects.
        mask = (mask > 0.5).to(self.device, dtype=self.dtype)
        return mask.unsqueeze(0)

    @torch.no_grad()
    def generate_noise_from_prompt(self, prompt: str, latent_shape: torch.Size) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepares text context with CFG support and creates the baseline noise vector.

        Returns:
            ``(init_noise, text_embeddings)``. ``init_noise`` is standard normal noise
            of ``latent_shape``. ``text_embeddings`` holds the unconditional (empty
            prompt) embedding at index 0 and the prompt embedding at index 1, matching
            the ``chunk(2)`` order used during denoising.
        """
        # Encode both prompts in one batch. truncation=True stops prompts longer than
        # the tokenizer's maximum length from overflowing the text encoder.
        tokens = self.tokenizer(
            ["", prompt],
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(tokens.input_ids.to(self.device)).last_hidden_state

        init_noise = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
        return init_noise, text_embeddings

    @torch.no_grad()
    def blend_latent_with_mask(
        self,
        latent_init: torch.Tensor,
        noise_package: tuple[torch.Tensor, torch.Tensor],
        mask_tensor: torch.Tensor,
        strength: float,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5
    ) -> torch.Tensor:
        """
        Run the Blended Latent Diffusion denoising loop with ``self.scheduler``.

        Before each step, the region where ``mask_tensor`` is 0 is replaced by
        ``latent_init`` noised to the current timestep. After the final step it is
        replaced by the clean ``latent_init``.

        Args:
            latent_init: Scaled VAE latents of the original image.
            noise_package: ``(init_noise, text_embeddings)`` as returned by
                ``generate_noise_from_prompt``.
            mask_tensor: Latent-resolution mask; 1 marks the region to edit.
            strength: Fraction of the schedule to run, in (0, 1].
            num_inference_steps: Length of the scheduler's timestep schedule.
            guidance_scale: Classifier-free guidance scale.

        Returns:
            The blended latents at noise level 0.

        Raises:
            ValueError: If no denoising step would be run (see ``_start_step``).
        """
        init_noise, text_embeddings = noise_package

        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        # Higher-order schedulers hold `order` entries per step (1 for DPM-Solver multistep).
        begin_index = self._start_step(num_inference_steps, strength) * getattr(self.scheduler, "order", 1)
        timesteps = self.scheduler.timesteps[begin_index:]

        # Newer diffusers schedulers locate the starting sigma through the begin
        # index. Older versions do not have this method.
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(begin_index)

        # Noise the whole original latent to the first timestep (1-D timestep tensor).
        latents = self.scheduler.add_noise(latent_init, init_noise, timesteps[:1])

        # A separate noise draw, reused at every step, for re-noising the background.
        bg_noise = torch.randn_like(latent_init)

        for t in timesteps:
            # Background: original latent noised to the current level.
            latents_bg = self.scheduler.add_noise(latent_init, bg_noise, t.reshape(1))
            latents = mask_tensor * latents + (1.0 - mask_tensor) * latents_bg

            # Classifier-free guidance: unconditional and conditional passes in one batch.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Final composite: restore the exact (noise-free) original background.
        return mask_tensor * latents + (1.0 - mask_tensor) * latent_init

    @torch.no_grad()
    def decode_from_latent(self, blended_latent: torch.Tensor) -> PIL.Image.Image:
        """Decode scaled VAE latents (batch index 0) into an 8-bit RGB PIL image."""
        latents = blended_latent / self.latent_scaling_factor
        image_tensor = self.autoencoder.decode(latents).sample

        # VAE output is in [-1, 1]; map it to [0, 1] and clip out-of-range values.
        image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1)
        image_numpy = image_tensor[0].permute(1, 2, 0).float().cpu().numpy()
        # Round instead of truncating so values map to the nearest 8-bit level.
        return PIL.Image.fromarray((image_numpy * 255).round().astype(np.uint8))
| |
|
|
|
|
def _parse_args(argv=None):
    """Parse command-line options for the demo.

    Every default reproduces the values that were previously hard-coded in
    ``main``, so running the script with no arguments behaves as before.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Run Blended Latent Diffusion on an image/mask pair.")
    parser.add_argument("--image", default="/home/aviad/interview/mobileye/messi.jpg",
                        help="Path to the input image.")
    parser.add_argument("--mask", default="/home/aviad/interview/mobileye/messi_mask.png",
                        help="Path to the mask image (white = region to edit).")
    parser.add_argument("--prompt", default="fluffy white clouds in a bright blue sky, highly detailed",
                        help="Text prompt describing the edit.")
    parser.add_argument("--steps", type=int, default=25, help="Number of inference steps.")
    parser.add_argument("--strength", type=float, default=0.95, help="Edit strength in (0, 1].")
    parser.add_argument("--guidance-scale", type=float, default=12.0, help="Classifier-free guidance scale.")
    parser.add_argument("--output", default="output_image.jpg", help="Where to save the edited image.")
    return parser.parse_args(argv)


def main(argv=None):
    """Load the inputs, run the edit, and save the result.

    Args:
        argv: Optional argument list (see ``_parse_args``). ``None`` reads the
            real command line.
    """
    args = _parse_args(argv)

    # Open the inputs first so a bad path fails before the slow model load. The
    # context manager closes the file handles that PIL.Image.open keeps open.
    with PIL.Image.open(args.image) as init_image, PIL.Image.open(args.mask) as mask_image:
        blended_diffusion = BlendedLatentDiffusion()
        output = blended_diffusion.blended_latent_diffusion(
            init_image=init_image,
            mask_image=mask_image,
            prompt=args.prompt,
            num_inference_steps=args.steps,
            strength=args.strength,
            guidance_scale=args.guidance_scale,
        )
    output.save(args.output)


if __name__ == "__main__":
    main()