import asyncio

try:
    import spaces
    from spaces import GPU
except ImportError:
    # Fallback no-op decorator for environments without the `spaces` package
    def GPU(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            # Used as @GPU without parameters
            return args[0]
        # Used as @GPU() with parameters
        def decorator(func):
            # Preserve sync/async call semantics instead of always returning
            # a coroutine
            if asyncio.iscoroutinefunction(func):
                async def wrapper(*func_args, **func_kwargs):
                    return await func(*func_args, **func_kwargs)
            else:
                def wrapper(*func_args, **func_kwargs):
                    return func(*func_args, **func_kwargs)
            return wrapper
        return decorator
import torch
import timm
from torch import nn, tensor
from torchvision import transforms
from functools import partial
from PIL import Image
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler, AutoencoderKL
import gc
import sys
import traceback
from tqdm.auto import tqdm
# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DIMENSION = 512
MODEL_ID = "stabilityai/stable-diffusion-2-1"
# Helper Classes
class Hook():
def __init__(self, m, f): self.hook = m.register_forward_hook(partial(f, self))
def remove(self): self.hook.remove()
def __del__(self): self.remove()
class Hooks(list):
def __init__(self, ms, f): super().__init__([Hook(m, f) for m in ms])
def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()
def __del__(self): self.remove()
def __delitem__(self, i):
self[i].remove()
super().__delitem__(i)
def remove(self):
for h in self: h.remove()
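# Hook/Hooks capture intermediate activations via PyTorch forward hooks.
# Typical usage (mirroring the loss classes below):
#     with Hooks(layer_list, get_features) as hooks:
#         model(x)
#         features = [h.features for h in hooks]
# The context manager guarantees the hooks are removed afterwards.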
# Helper Functions
def get_features(hook, mod, inp, outp):
hook.features = outp.clone()
def normalize(im):
imagenet_mean = tensor([0.485, 0.456, 0.406])[:,None,None].to(im.device)
imagenet_std = tensor([0.229, 0.224, 0.225])[:,None,None].to(im.device)
return (im - imagenet_mean) / imagenet_std
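# These are the standard ImageNet statistics; the pretrained VGG16 used for
# feature extraction expects its inputs normalized this way.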
def pil_to_latent(input_im, vae):
with torch.no_grad():
latent = vae.encode(transforms.ToTensor()(input_im).unsqueeze(0).to(DEVICE).half()*2-1)
return 0.18215 * latent.latent_dist.sample()
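# 0.18215 is Stable Diffusion's VAE latent scaling factor
# (vae.config.scaling_factor); latents_to_pil below applies its inverse.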
def latents_to_pil(latents, vae):
latents = (1 / 0.18215) * latents
with torch.no_grad():
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(im) for im in images]
return pil_images
def calc_grams(img):
return torch.einsum('chw, dhw -> cd', img, img) / (img.shape[-2]*img.shape[-1])
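# calc_grams builds a C x C Gram matrix of channel co-occurrences, averaged
# over spatial positions. Matching Gram matrices instead of raw activations
# makes the style loss sensitive to texture statistics but not to where
# features appear in the image.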
def clean_mem():
    if hasattr(sys, 'last_traceback'):
        traceback.clear_frames(sys.last_traceback)
    gc.collect()
    # Only touch the CUDA cache when a GPU is actually present
    if torch.cuda.is_available():
        with torch.cuda.device(DEVICE):
            torch.cuda.empty_cache()
# Model Setup Functions
def init_models():
model_id = MODEL_ID
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
torch_dtype=torch.float16
).to(DEVICE)
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
scheduler=scheduler,
revision="fp16",
torch_dtype=torch.float16,
safety_checker=None
).to(DEVICE)
return pipe, vae, scheduler
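# Note: revision="fp16" points at the half-precision weights branch on the
# Hub; newer diffusers releases prefer variant="fp16" for this, so the call
# above may emit a deprecation warning depending on the installed version.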
def setup_vgg():
vgg16 = timm.create_model('vgg16', pretrained=True).to(DEVICE).features
layers = [i-1 for i,m in enumerate(vgg16.children()) if isinstance(m,nn.MaxPool2d)]
vgg16_layers = [m for i,m in enumerate(vgg16) if i in layers]
return vgg16, vgg16_layers
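# The selected modules are the activations immediately before each MaxPool2d,
# i.e. the deepest feature map at each of VGG16's five spatial resolutions;
# these are the five entries that layer_weights index into. Leaving the model
# in train mode is harmless here, as this VGG16 feature stack contains only
# conv/ReLU/pooling layers.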
# Loss Classes
class ContentLossToTarget():
def __init__(self, target_im, vgg16, vgg16_layers, layer_weights=(1, 1, 0, 0, 0)):
self.vgg16 = vgg16
self.vgg16_layers = vgg16_layers
self.layer_weights = layer_weights
with torch.no_grad():
x = normalize(target_im.squeeze())
            with Hooks(vgg16_layers, get_features) as hooks:
vgg16(x)
self.target_features = [h.features for h in hooks]
def __call__(self, input_im):
        with Hooks(self.vgg16_layers, get_features) as hooks:
x = normalize(input_im.squeeze())
self.vgg16(x)
image_features = [h.features for h in hooks]
return sum(abs(f1-f2).mean()*w for f1, f2, w in
zip(image_features, self.target_features, self.layer_weights))
class StyleLossToTarget():
def __init__(self, target_im, vgg16, vgg16_layers, layer_weights=(1, 1, 1, 1, 1)):
self.vgg16 = vgg16
self.vgg16_layers = vgg16_layers
self.layer_weights = layer_weights
with torch.no_grad():
x = normalize(target_im.squeeze())
            with Hooks(vgg16_layers, get_features) as hooks:
vgg16(x)
self.target_features = [h.features for h in hooks]
def __call__(self, input_im):
        with Hooks(self.vgg16_layers, get_features) as hooks:
x = normalize(input_im.squeeze())
self.vgg16(x)
image_features = [h.features for h in hooks]
return sum(abs(calc_grams(f1)-calc_grams(f2)).mean()*w for f1, f2, w in
zip(image_features, self.target_features, self.layer_weights))
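# ContentLossToTarget compares raw VGG activations (sensitive to spatial
# layout), while StyleLossToTarget compares Gram matrices (sensitive to
# texture). Both expect [0, 1] inputs and apply ImageNet normalization
# internally via normalize().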
# Main Processing Function
@GPU
def process_images(init_image, style_image, prompt, negative_prompt, inference_steps, strength,
style_g1, style_g2, style_g3, style_g4, style_g5,
content_g1, content_g2, content_g3, content_g4, content_g5,
latent_guidance):
try:
# Initialize models
pipe, vae, scheduler = init_models()
vgg16, vgg16_layers = setup_vgg()
# Process images
init_image = init_image.resize((DIMENSION, DIMENSION))
style_image = style_image.resize((DIMENSION, DIMENSION))
        # Transform the style image to a [0, 1] tensor (the init image is used
        # only through its latents below)
        style_tensor = transforms.ToTensor()(style_image)
# Initialize latents
style_latents = pil_to_latent(style_image, vae)
init_image_latents = pil_to_latent(init_image, vae)
        # The loss classes apply ImageNet normalization internally via
        # normalize(), so the style target is passed as a raw [0, 1] tensor;
        # normalizing it here as well would apply the statistics twice
        style_target = style_tensor.to(DEVICE).unsqueeze(dim=0)
# Setup losses with individual layer weights
        style_loss = StyleLossToTarget(
            style_target,
            vgg16,
            vgg16_layers,
            layer_weights=(
                (style_g1 * 5)**2,
                (style_g2 * 5)**2,
                (style_g3 * 5)**2,
                (style_g4 * 5)**2,
                (style_g5 * 5)**2
            )
        )
        content_loss = ContentLossToTarget(
            style_target,
            vgg16,
            vgg16_layers,
            layer_weights=(
                content_g1,
                content_g2,
                content_g3,
                content_g4,
                content_g5
            )
        )
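        # Note: both losses target the style image (the init image enters only
        # through its latents and the MAE term in the loop below). The style
        # weights are boosted x5 and squared, so the style sliders act on a
        # much steeper curve than the linear content weights.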
# Prepare for inference
scheduler.set_timesteps(inference_steps)
offset = scheduler.config.get("steps_offset", 0)
        # Clamp so the timesteps[start_step] lookup below stays in range when
        # strength approaches 1.0
        start_step = min(int(inference_steps * strength) + offset, inference_steps - 1)
# Generate initial noise
generator = torch.Generator(device=DEVICE)
generator.manual_seed(42)
noise = torch.randn(
init_image_latents.shape,
generator=generator,
device=DEVICE,
dtype=torch.float16
)
# Add noise to input image
latents = scheduler.add_noise(
init_image_latents,
noise,
timesteps=torch.tensor([scheduler.timesteps[start_step]])
)
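        # The init latents are noised to the sigma of timesteps[start_step];
        # the loss-guided updates in the loop below only kick in after that
        # step, so `strength` controls where along the schedule they begin.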
# Encode text embeddings
text_embeddings = pipe._encode_prompt(
prompt,
DEVICE,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt
)
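        # _encode_prompt is a private pipeline helper (superseded by
        # encode_prompt in newer diffusers releases); with classifier-free
        # guidance enabled it returns the negative and positive embeddings
        # concatenated along the batch dimension, matching the chunk(2) below.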
# Initialize loss function
mae_loss = torch.nn.L1Loss()
        # Denoising loop
        for i, t in enumerate(tqdm(scheduler.timesteps)):
# Expand latents for classifier free guidance
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# Predict noise
with torch.no_grad():
noise_pred = pipe.unet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings
).sample
# Perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred/noise_pred.norm()*noise_pred_uncond.norm()
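            # Classifier-free guidance with a fixed scale of 7.5, then rescale
            # the guided prediction back to the unconditional prediction's norm
            # to keep its magnitude from drifting.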
            # Keep the scheduler's internal step index in sync with the loop
            pipe.scheduler._step_index = i
if i > start_step:
if i < int(0.8 * inference_steps):
latents = latents.detach().requires_grad_()
                    # scheduler.step() below is used only to obtain
                    # pred_original_sample, but it advances the scheduler's
                    # internal step index, so save it and restore it afterwards
                    current_step = pipe.scheduler._step_index
                    step_output = scheduler.step(noise_pred, t, latents)
                    latents_x0 = step_output.pred_original_sample
                    pipe.scheduler._step_index = current_step
# Process through VAE
latents_x0_vae = latents_x0.half()
denoised_images = vae.decode((1 / 0.18215) * latents_x0_vae).sample / 2 + 0.5
denoised_images = denoised_images.clamp(0, 1)
                    # Calculate guidance losses; the loss classes apply ImageNet
                    # normalization internally, so the decoded image is passed
                    # in [0, 1] as-is
                    content_loss_scale = 17.6
                    loss = content_loss(denoised_images) * content_loss_scale
                    style_loss_val = style_loss(denoised_images) * 0.5
                    latent_loss_val = mae_loss(latents_x0, style_latents) * latent_guidance
                    loss = loss + style_loss_val + latent_loss_val
# Calculate and apply gradients
cond_grad = torch.autograd.grad(loss, latents)[0]
# print(f"Gradient stats - Mean: {cond_grad.mean():.4f}, Std: {cond_grad.std():.4f}")
latents = latents.detach() - cond_grad * scheduler.sigmas[i].to(DEVICE)**2
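                    # (sigma**2 scaling makes the guidance step shrink as the
                    # sample becomes less noisy)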
torch.cuda.empty_cache()
latents = scheduler.step(noise_pred, t, latents).prev_sample
# Decode final image
with torch.no_grad():
image = pipe.decode_latents(latents)
image = pipe.numpy_to_pil(image)[0]
clean_mem()
        return image
    except Exception as e:
        clean_mem()
        # Chain the original exception so the full traceback is preserved
        raise RuntimeError(f"Error during processing: {e}") from e
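# Hypothetical usage sketch (illustrative only; the actual Gradio wiring for
# this Space lives outside this file):
#     from PIL import Image
#     content = Image.open("content.jpg").convert("RGB")
#     style = Image.open("style.jpg").convert("RGB")
#     result = process_images(
#         content, style,
#         prompt="an oil painting", negative_prompt="",
#         inference_steps=30, strength=0.5,
#         style_g1=1, style_g2=1, style_g3=1, style_g4=1, style_g5=1,
#         content_g1=1, content_g2=1, content_g3=0, content_g4=0, content_g5=0,
#         latent_guidance=0.1,
#     )
#     result.save("output.png")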