try:
    import spaces
    from spaces import GPU
except ImportError:
    import asyncio

    # Fallback no-op decorator for local runs without the `spaces` package
    def GPU(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            # Used as @GPU without parameters
            return args[0]

        # Used as @GPU(...) with parameters
        def decorator(func):
            async def wrapper(*func_args, **func_kwargs):
                if asyncio.iscoroutinefunction(func):
                    return await func(*func_args, **func_kwargs)
                return func(*func_args, **func_kwargs)
            return wrapper
        return decorator
import torch
import timm
from torch import nn, tensor
from torchvision import transforms
from functools import partial
import fastcore.all as fc
from PIL import Image
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler, AutoencoderKL
from pathlib import Path
import torch.nn.functional as F
import gc
import sys
import traceback
from tqdm.auto import tqdm
import logging
import numpy as np
# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DIMENSION = 512
MODEL_ID = "stabilityai/stable-diffusion-2-1"
# Helper Classes
class Hook():
    """Registers a forward hook on module `m` that calls `f(hook, module, input, output)`."""
    def __init__(self, m, f): self.hook = m.register_forward_hook(partial(f, self))
    def remove(self): self.hook.remove()
    def __del__(self): self.remove()

class Hooks(list):
    """A list of Hooks over modules `ms`, usable as a context manager."""
    def __init__(self, ms, f): super().__init__([Hook(m, f) for m in ms])
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()
    def __del__(self): self.remove()
    def __delitem__(self, i):
        self[i].remove()
        super().__delitem__(i)
    def remove(self):
        for h in self: h.remove()
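# Usage sketch (illustrative only):
#   with Hooks(modules, get_features) as hooks:
#       model(x)                              # forward pass fills hook.features
#       feats = [h.features for h in hooks]   # one activation tensor per module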
# Helper Functions
def get_features(hook, mod, inp, outp):
    # Stash a copy of the module's output activations on the hook
    hook.features = outp.clone()

def normalize(im):
    # ImageNet channel statistics, broadcast over (C, H, W)
    imagenet_mean = tensor([0.485, 0.456, 0.406])[:, None, None].to(im.device)
    imagenet_std = tensor([0.229, 0.224, 0.225])[:, None, None].to(im.device)
    return (im - imagenet_mean) / imagenet_std

def pil_to_latent(input_im, vae):
    # Encode a PIL image to scaled SD latents; 0.18215 is the SD VAE scaling factor
    with torch.no_grad():
        latent = vae.encode(transforms.ToTensor()(input_im).unsqueeze(0).to(DEVICE).half() * 2 - 1)
    return 0.18215 * latent.latent_dist.sample()

def latents_to_pil(latents, vae):
    # Decode scaled SD latents back to a list of PIL images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images

def calc_grams(img):
    # Gram matrix of channel features, averaged over spatial positions
    return torch.einsum('chw, dhw -> cd', img, img) / (img.shape[-2] * img.shape[-1])
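# Gram matrix sketch (illustrative): for f of shape (C, H, W),
# calc_grams(f)[c, d] == (f[c] * f[d]).sum() / (H * W),
# i.e. channel-feature correlations, the standard style-transfer statistic.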
def clean_mem():
    # Drop lingering traceback references, run GC, and release cached GPU memory
    if hasattr(sys, 'last_traceback'):
        traceback.clear_frames(sys.last_traceback)
    gc.collect()
    if torch.cuda.is_available():
        with torch.cuda.device(DEVICE):
            torch.cuda.empty_cache()
# Model Setup Functions
def init_models():
    model_id = MODEL_ID
    scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(
        model_id,
        subfolder="vae",
        torch_dtype=torch.float16
    ).to(DEVICE)
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        scheduler=scheduler,
        revision="fp16",
        torch_dtype=torch.float16,
        safety_checker=None
    ).to(DEVICE)
    return pipe, vae, scheduler
def setup_vgg():
    # Hook the conv layer immediately before each of VGG16's five MaxPool layers
    vgg16 = timm.create_model('vgg16', pretrained=True).to(DEVICE).features
    layers = [i - 1 for i, m in enumerate(vgg16.children()) if isinstance(m, nn.MaxPool2d)]
    vgg16_layers = [m for i, m in enumerate(vgg16) if i in layers]
    return vgg16, vgg16_layers
# Loss Classes
class ContentLossToTarget():
    def __init__(self, target_im, vgg16, vgg16_layers, layer_weights=(1, 1, 0, 0, 0)):
        self.vgg16 = vgg16
        self.vgg16_layers = vgg16_layers
        self.layer_weights = layer_weights
        # Cache the target's VGG features once
        with torch.no_grad():
            x = normalize(target_im.squeeze())
            with Hooks(vgg16_layers, get_features) as hooks:
                vgg16(x)
                self.target_features = [h.features for h in hooks]

    def __call__(self, input_im):
        # Weighted L1 distance between input and target VGG features, per layer
        with Hooks(self.vgg16_layers, get_features) as hooks:
            x = normalize(input_im.squeeze())
            self.vgg16(x)
            image_features = [h.features for h in hooks]
        return sum(abs(f1 - f2).mean() * w for f1, f2, w in
                   zip(image_features, self.target_features, self.layer_weights))
class StyleLossToTarget():
    def __init__(self, target_im, vgg16, vgg16_layers, layer_weights=(1, 1, 1, 1, 1)):
        self.vgg16 = vgg16
        self.vgg16_layers = vgg16_layers
        self.layer_weights = layer_weights
        # Cache the target's VGG features once
        with torch.no_grad():
            x = normalize(target_im.squeeze())
            with Hooks(vgg16_layers, get_features) as hooks:
                vgg16(x)
                self.target_features = [h.features for h in hooks]

    def __call__(self, input_im):
        # Weighted L1 distance between Gram matrices of input and target features
        with Hooks(self.vgg16_layers, get_features) as hooks:
            x = normalize(input_im.squeeze())
            self.vgg16(x)
            image_features = [h.features for h in hooks]
        return sum(abs(calc_grams(f1) - calc_grams(f2)).mean() * w for f1, f2, w in
                   zip(image_features, self.target_features, self.layer_weights))
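# Usage sketch (illustrative): both losses expect an image tensor of shape
# (1, 3, H, W) and return a scalar against the stored target:
#   style_loss = StyleLossToTarget(target_tensor, vgg16, vgg16_layers)
#   loss = style_loss(candidate_tensor)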
# Main Processing Function
def process_images(init_image, style_image, prompt, negative_prompt, inference_steps, strength,
                   style_g1, style_g2, style_g3, style_g4, style_g5,
                   content_g1, content_g2, content_g3, content_g4, content_g5,
                   latent_guidance):
    try:
        # Initialize models
        pipe, vae, scheduler = init_models()
        vgg16, vgg16_layers = setup_vgg()

        # Resize inputs to the working resolution
        init_image = init_image.resize((DIMENSION, DIMENSION))
        style_image = style_image.resize((DIMENSION, DIMENSION))

        # Convert the style image to a tensor
        style_tensor = transforms.ToTensor()(style_image)

        # Encode both images to SD latents
        style_latents = pil_to_latent(style_image, vae)
        init_image_latents = pil_to_latent(init_image, vae)

        # ImageNet normalization (the loss classes apply normalize() again
        # internally; targets and generated images follow the same path)
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        mean_tensor = torch.Tensor(mean).view(1, 1, -1).permute(2, 0, 1).to(DEVICE)
        std_tensor = torch.Tensor(std).view(1, 1, -1).permute(2, 0, 1).to(DEVICE)
        norm_style_tensor = (style_tensor.to(DEVICE) - mean_tensor) / std_tensor
        norm_style_tensor = norm_style_tensor.unsqueeze(dim=0)
        # Set up losses with individual per-layer weights
        style_loss = StyleLossToTarget(
            norm_style_tensor,
            vgg16,
            vgg16_layers,
            layer_weights=(
                (style_g1 * 5)**2,
                (style_g2 * 5)**2,
                (style_g3 * 5)**2,
                (style_g4 * 5)**2,
                (style_g5 * 5)**2
            )
        )
        content_loss = ContentLossToTarget(
            norm_style_tensor,
            vgg16,
            vgg16_layers,
            layer_weights=(
                content_g1,
                content_g2,
                content_g3,
                content_g4,
                content_g5
            )
        )
        # Prepare for inference
        scheduler.set_timesteps(inference_steps)
        offset = scheduler.config.get("steps_offset", 0)
        start_step = int(inference_steps * strength) + offset

        # Generate initial noise with a fixed seed for reproducibility
        generator = torch.Generator(device=DEVICE)
        generator.manual_seed(42)
        noise = torch.randn(
            init_image_latents.shape,
            generator=generator,
            device=DEVICE,
            dtype=torch.float16
        )

        # Noise the input-image latents to the starting timestep
        latents = scheduler.add_noise(
            init_image_latents,
            noise,
            timesteps=torch.tensor([scheduler.timesteps[start_step]])
        )

        # Encode text embeddings (uncond + cond for classifier-free guidance)
        text_embeddings = pipe._encode_prompt(
            prompt,
            DEVICE,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt
        )

        # L1 loss for latent-space guidance toward the style image's latents
        mae_loss = torch.nn.L1Loss()
        # Denoising loop
        for i, t in enumerate(tqdm(scheduler.timesteps)):
            # Expand latents for classifier-free guidance
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # Predict noise
            with torch.no_grad():
                noise_pred = pipe.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings
                ).sample

            # Classifier-free guidance (scale 7.5), then rescale to the
            # unconditional prediction's norm to avoid over-saturation
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
            noise_pred = noise_pred / noise_pred.norm() * noise_pred_uncond.norm()

            # Keep the scheduler's internal step index in sync with the loop
            pipe.scheduler._step_index = i
            if i > start_step:
                if i < int(0.8 * inference_steps):
                    latents = latents.detach().requires_grad_()
                    current_step = pipe.scheduler._step_index

                    # Get the scheduler's prediction of the fully denoised sample
                    step_output = scheduler.step(noise_pred, t, latents)
                    latents_x0 = step_output.pred_original_sample
                    pipe.scheduler._step_index = current_step

                    # Decode the predicted x0 through the VAE to image space
                    latents_x0_vae = latents_x0.half()
                    denoised_images = vae.decode((1 / 0.18215) * latents_x0_vae).sample / 2 + 0.5
                    denoised_images = denoised_images.clamp(0, 1)

                    # Normalize the decoded image the same way as the targets
                    norm_image_tensor = (denoised_images.squeeze() - mean_tensor) / std_tensor
                    norm_image_tensor = norm_image_tensor.unsqueeze(dim=0)

                    # Combine content, style, and latent losses (empirically tuned scales)
                    content_loss_scale = 17.6
                    loss = content_loss(norm_image_tensor) * content_loss_scale
                    style_loss_val = style_loss(norm_image_tensor) * 0.5
                    latent_loss_val = mae_loss(latents_x0, style_latents) * latent_guidance
                    loss += style_loss_val
                    loss += latent_loss_val

                    # Nudge the latents along the negative loss gradient, scaled by sigma^2
                    cond_grad = torch.autograd.grad(loss, latents)[0]
                    latents = latents.detach() - cond_grad * scheduler.sigmas[i].to(DEVICE)**2
                    torch.cuda.empty_cache()

            latents = scheduler.step(noise_pred, t, latents).prev_sample
        # Decode final image
        with torch.no_grad():
            image = pipe.decode_latents(latents)
            image = pipe.numpy_to_pil(image)[0]

        clean_mem()
        return image
    except Exception as e:
        clean_mem()
        raise RuntimeError(f"Error during processing: {e}") from e
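# Minimal local usage sketch (illustrative: file names and all guidance
# weights below are hypothetical placeholders, not tuned values)
if __name__ == "__main__":
    init = Image.open("content.jpg").convert("RGB")    # hypothetical path
    style = Image.open("style.jpg").convert("RGB")     # hypothetical path
    result = process_images(
        init, style,
        prompt="an oil painting of a landscape", negative_prompt="",
        inference_steps=30, strength=0.5,
        style_g1=1, style_g2=1, style_g3=1, style_g4=0, style_g5=0,
        content_g1=1, content_g2=1, content_g3=0, content_g4=0, content_g5=0,
        latent_guidance=10,
    )
    result.save("output.png")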