import asyncio
import gc
import sys
import traceback
from functools import partial

import torch
import timm
from torch import nn, tensor
from torchvision import transforms
from PIL import Image
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler, AutoencoderKL
from tqdm.auto import tqdm

try:
    import spaces
    from spaces import GPU
except ImportError:
    # Fallback no-op decorator for environments without the HF `spaces` package.
    def GPU(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            # Used as @GPU without parameters
            return args[0]

        # Used as @GPU() with parameters
        def decorator(func):
            async def wrapper(*func_args, **func_kwargs):
                if asyncio.iscoroutinefunction(func):
                    return await func(*func_args, **func_kwargs)
                return func(*func_args, **func_kwargs)
            return wrapper
        return decorator

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DIMENSION = 512
MODEL_ID = "stabilityai/stable-diffusion-2-1"

# Helper Classes
class Hook:
    """Registers a forward hook on a module and removes it on deletion."""
    def __init__(self, m, f):
        self.hook = m.register_forward_hook(partial(f, self))

    def remove(self):
        self.hook.remove()

    def __del__(self):
        self.remove()


class Hooks(list):
    """A context-managed list of Hooks over several modules."""
    def __init__(self, ms, f):
        super().__init__([Hook(m, f) for m in ms])

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()

    def __del__(self):
        self.remove()

    def __delitem__(self, i):
        self[i].remove()
        super().__delitem__(i)

    def remove(self):
        for h in self:
            h.remove()

# Helper Functions
def get_features(hook, mod, inp, outp):
    """Forward-hook callback: stash a copy of the module's output on the hook."""
    hook.features = outp.clone()


def normalize(im):
    """Normalize a channels-first image tensor with ImageNet statistics."""
    imagenet_mean = tensor([0.485, 0.456, 0.406])[:, None, None].to(im.device)
    imagenet_std = tensor([0.229, 0.224, 0.225])[:, None, None].to(im.device)
    return (im - imagenet_mean) / imagenet_std


def pil_to_latent(input_im, vae):
    """Encode a PIL image into scaled Stable Diffusion latents."""
    with torch.no_grad():
        latent = vae.encode(transforms.ToTensor()(input_im).unsqueeze(0).to(DEVICE).half() * 2 - 1)
    return 0.18215 * latent.latent_dist.sample()


def latents_to_pil(latents, vae):
    """Decode scaled Stable Diffusion latents back to a list of PIL images."""
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]


def calc_grams(img):
    """Gram matrix of a (C, H, W) feature map, normalized by spatial size."""
    return torch.einsum('chw, dhw -> cd', img, img) / (img.shape[-2] * img.shape[-1])


def clean_mem():
    """Free Python and CUDA memory, clearing any lingering traceback frames."""
    if hasattr(sys, 'last_traceback'):
        traceback.clear_frames(sys.last_traceback)
    gc.collect()
    with torch.cuda.device(DEVICE):
        torch.cuda.empty_cache()

# Model Setup Functions
def init_models():
    scheduler = EulerDiscreteScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=torch.float16
    ).to(DEVICE)
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        scheduler=scheduler,
        revision="fp16",
        torch_dtype=torch.float16,
        safety_checker=None
    ).to(DEVICE)
    return pipe, vae, scheduler


def setup_vgg():
    vgg16 = timm.create_model('vgg16', pretrained=True).to(DEVICE).features
    # Hook the layer immediately before each max-pool (one per VGG block).
    layers = [i - 1 for i, m in enumerate(vgg16.children()) if isinstance(m, nn.MaxPool2d)]
    vgg16_layers = [m for i, m in enumerate(vgg16) if i in layers]
    return vgg16, vgg16_layers
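# A minimal sketch (illustrative, not called by the app) of how the Hook/Hooks
# helpers above are used: run one image through VGG16 and collect the feature
# maps captured just before each max-pool. `_demo_feature_extraction` is a
# hypothetical name; shapes assume a 512x512 RGB input, passed unbatched (3D)
# as the loss classes below do.
def _demo_feature_extraction():
    vgg16, vgg16_layers = setup_vgg()
    x = normalize(torch.rand(3, DIMENSION, DIMENSION, device=DEVICE))
    with torch.no_grad(), Hooks(vgg16_layers, get_features) as hooks:
        vgg16(x)
        feats = [h.features for h in hooks]
    # Five feature maps, one per VGG block: 64, 128, 256, 512, 512 channels,
    # with spatial size halving after each block.
    for f in feats:
        print(tuple(f.shape))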
# Loss Classes
class ContentLossToTarget:
    """Weighted L1 distance between VGG feature maps of input and target."""
    def __init__(self, target_im, vgg16, vgg16_layers, layer_weights=(1, 1, 0, 0, 0)):
        self.vgg16 = vgg16
        self.vgg16_layers = vgg16_layers
        self.layer_weights = layer_weights
        with torch.no_grad():
            x = normalize(target_im.squeeze())
            with Hooks(vgg16_layers, get_features) as hooks:
                vgg16(x)
                self.target_features = [h.features for h in hooks]

    def __call__(self, input_im):
        with Hooks(self.vgg16_layers, get_features) as hooks:
            x = normalize(input_im.squeeze())
            self.vgg16(x)
            image_features = [h.features for h in hooks]
        return sum(abs(f1 - f2).mean() * w
                   for f1, f2, w in zip(image_features, self.target_features, self.layer_weights))


class StyleLossToTarget:
    """Weighted L1 distance between Gram matrices of VGG features of input and target."""
    def __init__(self, target_im, vgg16, vgg16_layers, layer_weights=(1, 1, 1, 1, 1)):
        self.vgg16 = vgg16
        self.vgg16_layers = vgg16_layers
        self.layer_weights = layer_weights
        with torch.no_grad():
            x = normalize(target_im.squeeze())
            with Hooks(vgg16_layers, get_features) as hooks:
                vgg16(x)
                self.target_features = [h.features for h in hooks]

    def __call__(self, input_im):
        with Hooks(self.vgg16_layers, get_features) as hooks:
            x = normalize(input_im.squeeze())
            self.vgg16(x)
            image_features = [h.features for h in hooks]
        return sum(abs(calc_grams(f1) - calc_grams(f2)).mean() * w
                   for f1, f2, w in zip(image_features, self.target_features, self.layer_weights))
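# Sanity-check sketch for the style loss above: calc_grams computes the
# standard Gram matrix via einsum, normalized by the number of spatial
# positions. The reshape+matmul form below is mathematically equivalent;
# `_gram_reference` is a hypothetical name used only for illustration.
def _gram_reference(img):
    c, h, w = img.shape
    flat = img.reshape(c, h * w)       # (C, H*W)
    return (flat @ flat.T) / (h * w)   # (C, C): G[c, d] = <feat_c, feat_d> / HW

# For any 3D feature map f, torch.allclose(calc_grams(f), _gram_reference(f))
# should hold up to floating-point tolerance.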
# Main Processing Function
@GPU
def process_images(init_image, style_image, prompt, negative_prompt, inference_steps,
                   strength, style_g1, style_g2, style_g3, style_g4, style_g5,
                   content_g1, content_g2, content_g3, content_g4, content_g5,
                   latent_guidance):
    try:
        # Initialize models
        pipe, vae, scheduler = init_models()
        vgg16, vgg16_layers = setup_vgg()

        # Resize both inputs to the working resolution
        init_image = init_image.resize((DIMENSION, DIMENSION))
        style_image = style_image.resize((DIMENSION, DIMENSION))

        # Convert to tensors
        to_tensor = transforms.ToTensor()
        style_tensor = to_tensor(style_image)
        init_tensor = to_tensor(init_image)  # currently unused below

        # Encode both images into latents
        style_latents = pil_to_latent(style_image, vae)
        init_image_latents = pil_to_latent(init_image, vae)

        # ImageNet-normalize the style image for the VGG losses
        mean_tensor = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(-1, 1, 1)
        std_tensor = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(-1, 1, 1)
        norm_style_tensor = ((style_tensor.to(DEVICE) - mean_tensor) / std_tensor).unsqueeze(dim=0)

        # Set up losses with individual layer weights.
        # Note: as written, both the style loss and the content loss target the
        # style image.
        style_loss = StyleLossToTarget(
            norm_style_tensor,
            vgg16,
            vgg16_layers,
            layer_weights=(
                (style_g1 * 5) ** 2,
                (style_g2 * 5) ** 2,
                (style_g3 * 5) ** 2,
                (style_g4 * 5) ** 2,
                (style_g5 * 5) ** 2
            )
        )
        content_loss = ContentLossToTarget(
            norm_style_tensor,
            vgg16,
            vgg16_layers,
            layer_weights=(
                content_g1,
                content_g2,
                content_g3,
                content_g4,
                content_g5
            )
        )

        # Prepare for inference
        scheduler.set_timesteps(inference_steps)
        offset = scheduler.config.get("steps_offset", 0)
        start_step = int(inference_steps * strength) + offset

        # Generate initial noise with a fixed seed for reproducibility
        generator = torch.Generator(device=DEVICE)
        generator.manual_seed(42)
        noise = torch.randn(
            init_image_latents.shape,
            generator=generator,
            device=DEVICE,
            dtype=torch.float16
        )

        # Noise the input-image latents to the start step's noise level
        latents = scheduler.add_noise(
            init_image_latents,
            noise,
            timesteps=torch.tensor([scheduler.timesteps[start_step]])
        )

        # Encode text embeddings (uncond + cond, for classifier-free guidance)
        text_embeddings = pipe._encode_prompt(
            prompt,
            DEVICE,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt
        )

        mae_loss = torch.nn.L1Loss()

        # Denoising loop
        for i, t in enumerate(tqdm(scheduler.timesteps)):
            # Expand latents for classifier-free guidance
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # Predict noise
            with torch.no_grad():
                noise_pred = pipe.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings
                ).sample

            # Classifier-free guidance (scale 7.5), then rescale the guided
            # prediction back to the unconditional prediction's norm
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
            noise_pred = noise_pred / noise_pred.norm() * noise_pred_uncond.norm()

            # Keep the scheduler's internal step index in sync with the loop
            pipe.scheduler._step_index = i

            # Apply loss guidance only between start_step and 80% of the schedule
            if start_step < i < int(0.8 * inference_steps):
                latents = latents.detach().requires_grad_()
                current_step = pipe.scheduler._step_index

                # Get the scheduler's prediction of the fully denoised sample,
                # then restore the step index (scheduler.step advances it)
                step_output = scheduler.step(noise_pred, t, latents)
                latents_x0 = step_output.pred_original_sample
                pipe.scheduler._step_index = current_step

                # Decode the predicted x0 through the VAE
                latents_x0_vae = latents_x0.half()
                denoised_images = vae.decode((1 / 0.18215) * latents_x0_vae).sample / 2 + 0.5
                denoised_images = denoised_images.clamp(0, 1)

                # Calculate losses on the ImageNet-normalized decoded image
                norm_image_tensor = ((denoised_images.squeeze() - mean_tensor) / std_tensor).unsqueeze(dim=0)

                content_loss_scale = 17.6  # empirically chosen weight
                loss = content_loss(norm_image_tensor) * content_loss_scale
                loss += style_loss(norm_image_tensor) * 0.5
                loss += mae_loss(latents_x0, style_latents) * latent_guidance

                # Calculate gradients and nudge the latents against them,
                # scaled by the current sigma squared
                cond_grad = torch.autograd.grad(loss, latents)[0]
                latents = latents.detach() - cond_grad * scheduler.sigmas[i].to(DEVICE) ** 2
                torch.cuda.empty_cache()

            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Decode final image
        with torch.no_grad():
            image = pipe.decode_latents(latents)
            image = pipe.numpy_to_pil(image)[0]

        clean_mem()
        return image

    except Exception as e:
        clean_mem()
        raise RuntimeError(f"Error during processing: {str(e)}") from e
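# Hypothetical usage sketch (file names and parameter values are illustrative,
# not part of the app): run the guided img2img loop on a content/style pair.
if __name__ == "__main__":
    content = Image.open("content.jpg").convert("RGB")
    style = Image.open("style.jpg").convert("RGB")
    result = process_images(
        content, style,
        prompt="an oil painting",
        negative_prompt="blurry, low quality",
        inference_steps=30,
        strength=0.5,
        style_g1=1.0, style_g2=1.0, style_g3=1.0, style_g4=0.0, style_g5=0.0,
        content_g1=1.0, content_g2=1.0, content_g3=0.0, content_g4=0.0, content_g5=0.0,
        latent_guidance=0.1,
    )
    result.save("output.png")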