import diffusers
import numpy as np
import PIL.Image
import torch
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer


class BlendedLatentDiffusion:
    """Mask-guided, text-driven image editing performed in VAE latent space.

    The constructor loads a VAE, a CLIP text encoder/tokenizer, a UNet and a
    DPM-Solver++ scheduler. ``blended_latent_diffusion`` then denoises the
    masked region under a text prompt while the unmasked region is repeatedly
    replaced by a re-noised copy of the original image latents.
    """

    # Used only when the loaded VAE config does not expose ``scaling_factor``.
    DEFAULT_VAE_SCALING_FACTOR = 0.18215
    # Pixel resolution the input image is resized to before VAE encoding.
    IMAGE_SIZE = 512
    # Spatial size the mask is resized to (512 / 8 = 64).
    LATENT_SIZE = 64

    def __init__(
        self,
        model_name: str = "CompVis/stable-diffusion-v1-4",
        text_model_name: str = "openai/clip-vit-large-patch14",
    ) -> None:
        """Load all model components and move them to the target device.

        Args:
            model_name: Checkpoint passed to ``from_pretrained`` for the VAE
                (``vae`` subfolder), UNet (``unet``) and scheduler
                (``scheduler``).
            text_model_name: Checkpoint passed to ``from_pretrained`` for the
                CLIP text encoder and tokenizer.
        """
        # 1. Hardware configuration
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use float16 for speed and lower VRAM if running on GPU, otherwise float32
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        # 2. Load all model components
        self.autoencoder = diffusers.AutoencoderKL.from_pretrained(model_name, subfolder="vae")
        self.text_encoder = CLIPTextModel.from_pretrained(text_model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(text_model_name)
        self.unet = diffusers.UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")

        # 3. Load the scheduler.
        # NOTE: diffusers.DDIMScheduler (same ``from_pretrained`` call without
        # the two extra keyword arguments) is the alternative to experiment with.
        self.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
            model_name,
            subfolder="scheduler",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )

        # Latent scaling factor applied after encoding and undone before decoding.
        self.vae_scaling_factor = getattr(
            self.autoencoder.config, "scaling_factor", self.DEFAULT_VAE_SCALING_FACTOR
        )

        # 4. Cast and move models to target device
        self.autoencoder.to(device=self.device, dtype=self.dtype).eval()
        self.text_encoder.to(device=self.device, dtype=self.dtype).eval()
        self.unet.to(device=self.device, dtype=self.dtype).eval()

    @staticmethod
    def _check_schedule_args(num_inference_steps: int, strength: float) -> None:
        """Reject step counts / strengths that would yield a wrong timestep slice."""
        if num_inference_steps < 1:
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")
        # strength > 1 would produce a negative slice index and silently run
        # only the last few steps; strength < 0 would slice past the end.
        if not 0.0 <= strength <= 1.0:
            raise ValueError(f"strength must be within [0, 1], got {strength}")

    def blended_latent_diffusion(
        self,
        init_image: PIL.Image.Image,
        mask_image: PIL.Image.Image,
        prompt: str,
        num_inference_steps: int = 25,
        strength: float = 0.8,
        guidance_scale: float = 7.5,
    ) -> PIL.Image.Image:
        """Apply blended latent diffusion to an image using a mask and a prompt.

        Args:
            init_image: The initial image to be modified.
            mask_image: Mask image; white areas indicate regions to modify.
                It is converted to grayscale and used as a per-pixel blend
                weight (it is not thresholded here).
            prompt: The text prompt guiding the diffusion process.
            num_inference_steps: Number of scheduler steps configured before
                the ``strength``-based slice is taken.
            strength: Value in [0, 1]; higher values start denoising from an
                earlier (noisier) timestep.
            guidance_scale: Classifier-free guidance scale.

        Returns:
            The modified image, decoded from the blended latents.

        Raises:
            ValueError: If ``strength`` is outside [0, 1] or
                ``num_inference_steps`` is smaller than 1.
        """
        # Fail before any expensive encoding work is done.
        self._check_schedule_args(num_inference_steps, strength)

        print(f"🚀 Starting Blended Latent Diffusion | Prompt: '{prompt}'")
        print(f"📦 Configurations | Steps: {num_inference_steps} | CFG Scale: {guidance_scale}")

        # Step 1: Preprocess the input images
        init_image = init_image.convert("RGB")
        mask_image = mask_image.convert("L")  # Convert to grayscale for masking

        print("⏳ Encoding initial image to latent space...")
        # Step 2: Encode the initial image into latent space and transform the mask
        latent_init = self.encode_to_latent(init_image)
        mask_transform = self.preprocess_mask(mask_image)
        print(f"✅ Latents Prepared | Shape: {list(latent_init.shape)} | Mask Shape: {list(mask_transform.shape)}")

        # Step 3: Generate noise based on the prompt
        print("⏳ Processing text prompt and creating base noise...")
        noise = self.generate_noise_from_prompt(prompt, latent_init.shape)
        print(f"✅ Text Embeddings Configured | Embedded Shape: {list(noise[1].shape)}")

        # Step 4: Blend the noise with the latent representation using the mask
        print(f"⏳ Entering Denoising Loop ({num_inference_steps} steps via {self.scheduler.__class__.__name__})...")
        blended_latent = self.blend_latent_with_mask(
            latent_init, noise, mask_transform, strength, num_inference_steps, guidance_scale
        )
        print("✅ Latent optimization sequence complete.")

        print("⏳ Decoding final blended latents back to image pixels...")
        # Step 5: Decode the blended latent representation back to an image
        output_image = self.decode_from_latent(blended_latent)
        print("✨ Process complete! Returning output image.")
        return output_image

    def encode_to_latent(self, init_image: PIL.Image.Image) -> torch.Tensor:
        """Resize the image to 512x512, normalize to [-1, 1] and VAE-encode it.

        Returns:
            A sample from the VAE latent distribution multiplied by the VAE
            scaling factor, with a leading batch dimension of 1.
        """
        preprocess = transforms.Compose([
            transforms.Resize((self.IMAGE_SIZE, self.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        input_tensor = preprocess(init_image).unsqueeze(0).to(self.device, dtype=self.dtype)
        with torch.no_grad():
            latents = self.autoencoder.encode(input_tensor).latent_dist.sample()
        return latents * self.vae_scaling_factor

    def preprocess_mask(self, mask_image: PIL.Image.Image) -> torch.Tensor:
        """Resize the mask to latent resolution and return it as a 4-D tensor.

        Returns:
            Tensor of shape [1, 1, 64, 64] (for a single-channel mask) on the
            model device/dtype, so it broadcasts over the latent channels.
        """
        # Resize to latent space size (512 / 8 = 64)
        mask = mask_image.resize(
            (self.LATENT_SIZE, self.LATENT_SIZE), resample=PIL.Image.NEAREST
        )
        mask = transforms.ToTensor()(mask).to(self.device, dtype=self.dtype)  # Shape: [1, 64, 64]
        # Add a batch dimension to make it [1, 1, 64, 64] for clean broadcasting
        return mask.unsqueeze(0)

    def _encode_prompt(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` and return the text encoder's last hidden state."""
        tokens = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            # Without truncation an over-long prompt yields more tokens than
            # the encoder's maximum sequence length.
            truncation=True,
            return_tensors="pt",
        )
        # Inference only: avoid building an autograd graph for the encoder.
        with torch.no_grad():
            return self.text_encoder(tokens.input_ids.to(self.device)).last_hidden_state

    def generate_noise_from_prompt(
        self, prompt: str, latent_shape: torch.Size
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare text context with CFG support and create the base noise.

        Args:
            prompt: Conditional text prompt.
            latent_shape: Shape of the noise tensor to draw.

        Returns:
            ``(init_noise, text_embeddings)`` where ``text_embeddings`` is the
            unconditional (empty prompt) embedding concatenated with the
            prompt embedding along the batch dimension, in that order.
        """
        # 1. Encode the positive conditional prompt
        cond_embeddings = self._encode_prompt(prompt)
        # 2. Encode the unconditional empty prompt (negative guidance)
        uncond_embeddings = self._encode_prompt("")
        # 3. Concatenate them into a single batch for parallel UNet processing
        #    (e.g. [2, 77, 768] with the default CLIP text encoder)
        text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
        # 4. Generate the single static base noise layout
        init_noise = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
        return init_noise, text_embeddings

    def _timestep_batch(self, t) -> torch.Tensor:
        """Return timestep ``t`` as a 1-D ``torch.long`` tensor on the model device."""
        value = t.item() if isinstance(t, torch.Tensor) else t
        return torch.tensor([value], device=self.device, dtype=torch.long)

    def blend_latent_with_mask(
        self,
        latent_init: torch.Tensor,
        noise_package: tuple[torch.Tensor, torch.Tensor],
        mask_tensor: torch.Tensor,
        strength: float,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
    ) -> torch.Tensor:
        """Run the masked denoising loop and return the blended latents.

        At every step the unmasked region is overwritten with the original
        latents noised to the current timestep, the UNet predicts noise with
        classifier-free guidance, and the scheduler advances one step. After
        the loop the unmasked region is replaced by the clean original latents.

        Args:
            latent_init: Scaled VAE latents of the original image.
            noise_package: ``(init_noise, text_embeddings)`` as returned by
                ``generate_noise_from_prompt``.
            mask_tensor: Blend weights broadcastable to ``latent_init``;
                1 selects generated content, 0 selects the original.
            strength: Value in [0, 1] selecting how many of the configured
                steps are skipped (``int(num_inference_steps * (1 - strength))``).
            num_inference_steps: Number of steps configured on the scheduler.
            guidance_scale: Classifier-free guidance scale.

        Returns:
            The blended latents; ``latent_init`` itself when the selected
            schedule contains no steps.

        Raises:
            ValueError: If ``strength`` is outside [0, 1] or
                ``num_inference_steps`` is smaller than 1.
        """
        self._check_schedule_args(num_inference_steps, strength)
        init_noise, text_embeddings = noise_package

        # 1. Initialize full steps on the scheduler
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)

        # 2. Slice timesteps based on strength parameter
        init_timestep_idx = int(num_inference_steps * (1 - strength))
        timesteps = self.scheduler.timesteps[init_timestep_idx:]
        if len(timesteps) == 0:
            # strength ~ 0: nothing to denoise, so the image is left unchanged.
            return latent_init

        # 3. Configure multi-step tracking properties (not present on every
        #    scheduler / diffusers version, hence the guard)
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(init_timestep_idx)

        # 4. Initialize foreground latents with properly scaled starting noise;
        #    add_noise is given a 1-D timestep tensor.
        latents_fg = self.scheduler.add_noise(
            latent_init, init_noise, self._timestep_batch(timesteps[0])
        )

        # 5. One fixed background noise sample, reused at every timestep
        fresh_bg_noise = torch.randn_like(latent_init)

        # 6. Core denoising loop
        for t in timesteps:
            # A. Noise the original latents to the current timestep
            latents_bg = self.scheduler.add_noise(
                latent_init, fresh_bg_noise, self._timestep_batch(t)
            )

            # B. Spatial blending: keep generated content inside the mask and
            #    the re-noised original outside it
            latents_fg = mask_tensor * latents_fg + (1.0 - mask_tensor) * latents_bg

            # C. Duplicate inputs for Classifier-Free Guidance (CFG) processing
            latent_model_input = torch.cat([latents_fg] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # D. Predict noise with the UNet
            with torch.no_grad():
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=text_embeddings
                ).sample

            # E. Split predictions and extrapolate along the guidance direction
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # F. Step foreground latents backward one step in time
            latents_fg = self.scheduler.step(noise_pred, t, latents_fg).prev_sample

        # 7. Final blending pass so the unmasked region is the clean original
        return mask_tensor * latents_fg + (1.0 - mask_tensor) * latent_init

    def decode_from_latent(self, blended_latent: torch.Tensor) -> PIL.Image.Image:
        """Decode latents with the VAE and return the first batch item as an image."""
        # Undo the VAE scaling factor
        latents = blended_latent / self.vae_scaling_factor
        with torch.no_grad():
            image_tensor = self.autoencoder.decode(latents).sample

        # Rescale from [-1, 1] back to [0, 1]
        image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1)
        # NCHW -> NHWC, float32 on CPU for NumPy
        image_array = image_tensor.cpu().permute(0, 2, 3, 1).float().numpy()
        # Round before the integer cast; a bare cast truncates toward zero.
        image_numpy = (image_array * 255).round().astype("uint8")[0]
        return PIL.Image.fromarray(image_numpy)


def main():
    """Run a sample edit on fixed input files and save the result as a JPEG."""
    blended_diffusion = BlendedLatentDiffusion()

    # Context managers close the underlying files once the edit has consumed them.
    with PIL.Image.open("/home/aviad/interview/mobileye/messi.jpg") as init_image, \
            PIL.Image.open("/home/aviad/interview/mobileye/messi_mask.png") as mask_image:
        output = blended_diffusion.blended_latent_diffusion(
            init_image=init_image,
            mask_image=mask_image,
            prompt="fluffy white clouds in a bright blue sky, highly detailed",
            num_inference_steps=25,
            strength=0.95,  # High strength allows completely overwriting the target area
            guidance_scale=12.0,  # Slightly higher scale forces strong prompt adhesion over background textures
        )

    output.save("output_image.jpg")


if __name__ == "__main__":
    main()