# Commit 172c37c by aviadcohz: "Initial commit: Added blended diffusion boilerplate"
import PIL.Image
import diffusers
import numpy as np
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer
from torchvision import transforms
import torch
class BlendedLatentDiffusion:
    """Text-guided local image editing with Blended Latent Diffusion.

    Loads a Stable Diffusion VAE and UNet, a CLIP text encoder/tokenizer and a
    DPM-Solver++ multistep scheduler. Only the masked region of an image is
    regenerated. Before every denoising step, the unmasked region of the latents
    is overwritten with the original image's latents, noised to the current
    timestep.
    """

    def __init__(
        self,
        model_name: str = "CompVis/stable-diffusion-v1-4",
        text_model_name: str = "openai/clip-vit-large-patch14",
    ):
        """Load all model components and move them to the available device.

        Args:
            model_name: Hugging Face repo providing the ``vae``, ``unet`` and
                ``scheduler`` subfolders.
            text_model_name: Hugging Face repo of the CLIP text encoder and
                tokenizer used to embed prompts.
        """
        # 1. Hardware configuration
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # float16 gives speed and lower VRAM on GPU; fall back to float32 on CPU.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        # 2. Load all model components
        self.autoencoder = diffusers.AutoencoderKL.from_pretrained(model_name, subfolder="vae")
        self.text_encoder = CLIPTextModel.from_pretrained(text_model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(text_model_name)
        self.unet = diffusers.UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")

        # 3. Scheduler: DPM-Solver++ with Karras sigmas.
        #    diffusers.DDIMScheduler is a drop-in alternative for experimentation.
        self.scheduler = diffusers.DPMSolverMultistepScheduler.from_pretrained(
            model_name,
            subfolder="scheduler",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )

        # 4. Cast and move models to the target device, in eval mode.
        for module in (self.autoencoder, self.text_encoder, self.unet):
            module.to(device=self.device, dtype=self.dtype).eval()

    def blended_latent_diffusion(
        self,
        init_image: PIL.Image.Image,
        mask_image: PIL.Image.Image,
        prompt: str,
        num_inference_steps: int = 25,
        strength: float = 0.8,
        guidance_scale: float = 7.5,
    ) -> PIL.Image.Image:
        """
        Applies blended latent diffusion to an input image using a mask and a text prompt.

        Args:
            init_image (PIL.Image.Image): The initial image to be modified.
            mask_image (PIL.Image.Image): A mask image where white areas indicate regions to modify.
            prompt (str): The text prompt guiding the diffusion process.
            num_inference_steps (int): Number of scheduler steps for the full schedule (>= 1).
            strength (float): Fraction of the schedule actually run, in (0, 1].
                Higher values start from noisier latents and change the masked area more.
            guidance_scale (float): Classifier-free guidance scale, controlling the influence of the prompt.

        Returns:
            PIL.Image.Image: The modified image after applying blended latent diffusion.

        Raises:
            ValueError: If ``num_inference_steps`` < 1 or ``strength`` is outside (0, 1].
        """
        # Validate before any expensive model work.
        self._start_index(num_inference_steps, strength)

        print(f"🚀 Starting Blended Latent Diffusion | Prompt: '{prompt}'")
        print(f"📦 Configurations | Steps: {num_inference_steps} | CFG Scale: {guidance_scale}")

        # Step 1: Preprocess the input images
        init_image = init_image.convert("RGB")
        mask_image = mask_image.convert("L")  # Grayscale: one mask channel

        print("⏳ Encoding initial image to latent space...")
        # Step 2: Encode the initial image into latent space and transform the mask
        latent_init = self.encode_to_latent(init_image)
        mask_transform = self.preprocess_mask(mask_image)
        print(f"✅ Latents Prepared | Shape: {list(latent_init.shape)} | Mask Shape: {list(mask_transform.shape)}")

        # Step 3: Embed the prompt (with CFG) and draw the starting noise
        print("⏳ Processing text prompt and creating base noise...")
        noise = self.generate_noise_from_prompt(prompt, latent_init.shape)
        print(f"✅ Text Embeddings Configured | Embedded Shape: {list(noise[1].shape)}")

        # Step 4: Denoise the masked region while pinning the background
        print(f"⏳ Entering Denoising Loop ({num_inference_steps} steps via {self.scheduler.__class__.__name__})...")
        blended_latent = self.blend_latent_with_mask(
            latent_init, noise, mask_transform, strength, num_inference_steps, guidance_scale)
        print("✅ Latent optimization sequence complete.")

        print("⏳ Decoding final blended latents back to image pixels...")
        # Step 5: Decode the blended latent representation back to an image
        output_image = self.decode_from_latent(blended_latent)
        print("✨ Process complete! Returning output image.")
        return output_image

    def _vae_scaling_factor(self) -> float:
        """Return the VAE latent scaling factor from its config (0.18215 fallback, the SD v1 value)."""
        return getattr(self.autoencoder.config, "scaling_factor", 0.18215)

    def encode_to_latent(self, init_image: PIL.Image.Image) -> torch.Tensor:
        """Encode an RGB image into scaled VAE latents.

        The image is resized to 512x512 and normalized to [-1, 1] before
        encoding. The latent is *sampled* from the VAE posterior, so repeated
        calls are not bit-identical.
        """
        preprocess = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        input_tensor = preprocess(init_image).unsqueeze(0).to(self.device, dtype=self.dtype)
        with torch.no_grad():
            latents = self.autoencoder.encode(input_tensor).latent_dist.sample()
        return latents * self._vae_scaling_factor()

    def preprocess_mask(self, mask_image: PIL.Image.Image) -> torch.Tensor:
        """Resize a grayscale mask to the latent grid and return it as [1, 1, 64, 64].

        Values are in [0, 1]; 1 (white) marks latents that will be regenerated.
        """
        # Latent spatial size for a 512x512 input: 512 / 8 = 64.
        mask = mask_image.resize((64, 64), resample=PIL.Image.NEAREST)
        mask = transforms.ToTensor()(mask).to(self.device, dtype=self.dtype)  # [1, 64, 64]
        # Add a batch dimension so the mask broadcasts over the latent channels.
        return mask.unsqueeze(0)

    def _embed_text(self, text: str) -> torch.Tensor:
        """Return the CLIP last hidden state for one string, padded/truncated to max length."""
        tokens = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            # Without truncation, over-long prompts yield more than max_length
            # tokens, which the text encoder cannot consume.
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():  # inference only; avoid building an autograd graph
            return self.text_encoder(tokens.input_ids.to(self.device)).last_hidden_state

    def generate_noise_from_prompt(self, prompt: str, latent_shape: torch.Size) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepares text context with CFG support and creates the baseline noise vector.

        Returns:
            (init_noise, text_embeddings): standard-normal noise of ``latent_shape``,
            and the unconditional + conditional embeddings concatenated along
            the batch axis (unconditional first, matching the CFG split in the
            denoising loop).
        """
        uncond_embeddings = self._embed_text("")
        text_embeddings = self._embed_text(prompt)
        # One batch so the UNet evaluates both branches in a single forward pass.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        init_noise = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
        return init_noise, text_embeddings

    @staticmethod
    def _start_index(num_inference_steps: int, strength: float) -> int:
        """Return the index into the scheduler timesteps where denoising starts.

        Raises:
            ValueError: If ``num_inference_steps`` < 1 or ``strength`` is not in (0, 1].
        """
        if num_inference_steps < 1:
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")
        if not 0.0 < strength <= 1.0:
            raise ValueError(f"strength must be in (0, 1], got {strength}")
        # Skip the first (1 - strength) fraction of the schedule, but always keep
        # at least one step so the timestep slice is never empty.
        return min(int(num_inference_steps * (1 - strength)), num_inference_steps - 1)

    def _as_1d_timestep(self, t) -> torch.Tensor:
        """Wrap a scalar timestep (0-d tensor or int) as a 1-element long tensor.

        ``scheduler.add_noise`` reads ``timesteps.shape[0]``, which a 0-d tensor lacks.
        """
        value = t.item() if isinstance(t, torch.Tensor) else t
        return torch.tensor([value], device=self.device, dtype=torch.long)

    def blend_latent_with_mask(
        self,
        latent_init: torch.Tensor,
        noise_package: tuple[torch.Tensor, torch.Tensor],
        mask_tensor: torch.Tensor,
        strength: float,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5
    ) -> torch.Tensor:
        """
        Executes Blended Latent Diffusion with the configured scheduler.

        Each step works as follows:
        1. The unmasked latents are replaced by ``latent_init`` noised to the
           current timestep (one background noise sample is reused every step).
        2. The UNet predicts noise with classifier-free guidance.
        3. The scheduler steps the latents backward.

        After the loop, the unmasked region is restored to the clean
        ``latent_init``.

        Raises:
            ValueError: If ``num_inference_steps`` < 1 or ``strength`` is not in (0, 1].
        """
        init_noise, text_embeddings = noise_package

        # 1. Full schedule, then keep only the last `strength` fraction of it.
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        init_timestep_idx = self._start_index(num_inference_steps, strength)
        timesteps = self.scheduler.timesteps[init_timestep_idx:]

        # 2. Tell multistep schedulers where the truncated schedule begins.
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(init_timestep_idx)

        with torch.no_grad():
            # 3. Foreground starts from the original latents noised to the first kept timestep.
            latents_fg = self.scheduler.add_noise(
                latent_init, init_noise, self._as_1d_timestep(timesteps[0]))
            # A single background noise sample keeps the background trajectory consistent.
            bg_noise = torch.randn_like(latent_init)

            for t in timesteps:
                # A. Background at the current noise level.
                latents_bg = self.scheduler.add_noise(latent_init, bg_noise, self._as_1d_timestep(t))
                # B. Spatial blend: regenerate only where mask == 1.
                latents_fg = mask_tensor * latents_fg + (1.0 - mask_tensor) * latents_bg
                # C. Duplicate for the unconditional/conditional CFG branches.
                latent_model_input = self.scheduler.scale_model_input(torch.cat([latents_fg] * 2), t)
                # D. Predict noise for both branches.
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                # E. Classifier-free guidance.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                # F. Step one timestep backward.
                latents_fg = self.scheduler.step(noise_pred, t, latents_fg).prev_sample

        # Final blend at t=0 keeps unmasked latents exactly equal to the original.
        return mask_tensor * latents_fg + (1.0 - mask_tensor) * latent_init

    def decode_from_latent(self, blended_latent: torch.Tensor) -> PIL.Image.Image:
        """Decode scaled VAE latents (first batch item) into an 8-bit RGB PIL image."""
        # Undo the VAE scaling factor applied in encode_to_latent.
        latents = blended_latent / self._vae_scaling_factor()
        with torch.no_grad():
            image_tensor = self.autoencoder.decode(latents).sample
        # VAE output is in [-1, 1]; rescale to [0, 1] and clip.
        image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1)
        image_tensor = image_tensor.cpu().permute(0, 2, 3, 1).float().numpy()  # NCHW -> NHWC
        image_numpy = (image_tensor * 255).astype("uint8")[0]
        return PIL.Image.fromarray(image_numpy)
def main(
    init_image_path: str = "/home/aviad/interview/mobileye/messi.jpg",
    mask_image_path: str = "/home/aviad/interview/mobileye/messi_mask.png",
    output_path: str = "output_image.jpg",
) -> None:
    """Run a sample edit that fills the masked region of an image with sky and clouds.

    Args:
        init_image_path: Path of the image to edit.
        mask_image_path: Path of the mask image (white = region to regenerate).
        output_path: Where the edited image is saved.
    """
    blended_diffusion = BlendedLatentDiffusion()
    # PIL.Image.open loads lazily and keeps the file handle open; the context
    # managers close both files once the edit (which reads the pixels) is done.
    with PIL.Image.open(init_image_path) as init_image, PIL.Image.open(mask_image_path) as mask_image:
        output = blended_diffusion.blended_latent_diffusion(
            init_image=init_image,
            mask_image=mask_image,
            prompt="fluffy white clouds in a bright blue sky, highly detailed",
            num_inference_steps=25,
            strength=0.95,  # High strength allows completely overwriting the target area
            guidance_scale=12.0,  # Higher scale forces strong prompt adherence over background textures
        )
    output.save(output_path)


if __name__ == "__main__":
    main()