Spaces:
Running
on
Zero
Running
on
Zero
# %% | |
import argparse, os | |
import torch | |
import requests | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from PIL import Image | |
from io import BytesIO | |
from tqdm.auto import tqdm | |
from matplotlib import pyplot as plt | |
from torchvision import transforms as tfms | |
from diffusers import ( | |
StableDiffusionPipeline, | |
DDIMScheduler, | |
DiffusionPipeline, | |
StableDiffusionXLPipeline, | |
) | |
from diffusers.image_processor import VaeImageProcessor | |
import torch | |
import torch.nn as nn | |
import torchvision | |
import torchvision.transforms as transforms | |
from torchvision.utils import save_image | |
import argparse | |
import PIL.Image as Image | |
from torchvision.utils import make_grid | |
import numpy | |
from diffusers.schedulers import DDIMScheduler | |
import torch.nn.functional as F | |
from models import attn_injection | |
from omegaconf import OmegaConf | |
from typing import List, Tuple | |
import omegaconf | |
import utils.exp_utils | |
import json | |
# Default device for the pipeline, latents and prompt encoding used throughout this script.
device = torch.device(type="cuda")
def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device): | |
# Tokenize text and get embeddings | |
text_inputs = tokenizer( | |
prompt, | |
padding="max_length", | |
max_length=tokenizer.model_max_length, | |
truncation=True, | |
return_tensors="pt", | |
) | |
text_input_ids = text_inputs.input_ids | |
with torch.no_grad(): | |
prompt_embeds = text_encoder( | |
text_input_ids.to(device), | |
output_hidden_states=True, | |
) | |
pooled_prompt_embeds = prompt_embeds[0] | |
prompt_embeds = prompt_embeds.hidden_states[-2] | |
if prompt == "": | |
negative_prompt_embeds = torch.zeros_like(prompt_embeds) | |
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) | |
return negative_prompt_embeds, negative_pooled_prompt_embeds | |
return prompt_embeds, pooled_prompt_embeds | |
def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: "str | List[str]"):
    """Encode ``prompt`` with both SDXL text encoders.

    Args:
        model: SDXL pipeline providing ``tokenizer``/``text_encoder`` and
            ``tokenizer_2``/``text_encoder_2``.
        prompt: a single prompt string or a list of prompt strings.

    Returns:
        ``(added_cond_kwargs, prompt_embeds)``:
        - ``prompt_embeds``: penultimate hidden states of both encoders concatenated
          along the last dimension.
        - ``added_cond_kwargs``: ``{"text_embeds": <element [0] of text_encoder_2's
          output>, "time_ids": <size/crop conditioning, one row per prompt>}``.
    """
    device = model._execution_device
    # Only the second encoder's pooled output is used by SDXL's added conditioning.
    prompt_embeds, _ = _get_text_embeddings(
        prompt, model.tokenizer, model.text_encoder, device
    )
    prompt_embeds_2, pooled_prompt_embeds_2 = _get_text_embeddings(
        prompt, model.tokenizer_2, model.text_encoder_2, device
    )
    prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1)

    text_encoder_projection_dim = model.text_encoder_2.config.projection_dim
    # original_size, crops_coords_top_left, target_size are fixed to 1024x1024 / (0, 0).
    add_time_ids = model._get_add_time_ids(
        (1024, 1024), (0, 0), (1024, 1024), torch.float16, text_encoder_projection_dim
    ).to(device)

    # One time-id row per prompt. A bare string is a single prompt; len(str) would
    # count characters and produce a batch that no longer matches prompt_embeds.
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    add_time_ids = add_time_ids.repeat(batch_size, 1)

    added_cond_kwargs = {
        "text_embeds": pooled_prompt_embeds_2,
        "time_ids": add_time_ids,
    }
    return added_cond_kwargs, prompt_embeds
def _encode_text_sdxl_with_negative(
    model: StableDiffusionXLPipeline, prompt: List[str]
):
    """Encode ``prompt`` plus an equally sized batch of empty prompts for CFG.

    Every returned tensor is stacked along dim 0 with the unconditional (empty
    prompt) half first and the conditional half second.

    Returns:
        ``(added_cond_kwargs, prompt_embeds)`` in the same layout as
        ``_encode_text_sdxl``, but with a doubled batch dimension.
    """
    batch_size = len(prompt)
    cond_kwargs, cond_embeds = _encode_text_sdxl(model, prompt)
    uncond_kwargs, uncond_embeds = _encode_text_sdxl(model, [""] * batch_size)

    stacked_embeds = torch.cat((uncond_embeds, cond_embeds))
    stacked_kwargs = {
        key: torch.cat((uncond_kwargs[key], cond_kwargs[key]))
        for key in ("text_embeds", "time_ids")
    }
    return stacked_kwargs, stacked_embeds
# Sample function (regular DDIM)
@torch.no_grad()
def sample(
    pipe,
    prompt,
    start_step=0,
    start_latents=None,
    intermediate_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
    """Run DDIM sampling for a batch of prompts that share one starting latent.

    ``start_latents`` is repeated once per prompt. When ``intermediate_latents`` is
    given, batch entry 0 is overwritten at every step with an inverted latent
    (reconstruction branch) while the other entries are denoised freely.

    Args:
        pipe: a ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        prompt: list of prompts (a bare string is treated as a single prompt).
        start_step: first scheduler step to run.
        start_latents: starting latent with batch dim 1; random (1, 4, 64, 64) noise
            if omitted.
        intermediate_latents: inverted latents (e.g. from ``invert``); optional.
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of scheduler timesteps.
        num_images_per_prompt: forwarded to ``_encode_prompt`` (SD only).
        do_classifier_free_guidance: evaluate unconditional + conditional per step.
        negative_prompt: string broadcast to every prompt, or a list with one entry
            per prompt (SD only; the SDXL path always uses empty negatives).
        device: device for timesteps and latents.

    Returns:
        List of PIL images, one per prompt.

    Raises:
        TypeError: if ``pipe`` is not a supported pipeline type.
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * len(prompt)

    # Encode prompt
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, prompt
        )
    else:
        raise TypeError(f"Unsupported pipeline type: {type(pipe).__name__}")

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Create a random starting point if we don't have one already
    if start_latents is None:
        start_latents = torch.randn(1, 4, 64, 64, device=device)
        start_latents *= pipe.scheduler.init_noise_sigma

    latents = start_latents.clone().repeat(len(prompt), 1, 1, 1)

    for i in tqdm(range(start_step, num_inference_steps)):
        # Batch entry 0 follows the inverted trajectory (reconstruction branch).
        # NOTE(review): the index 1 - i reads [1] and [0] for i = 0, 1 before
        # jumping to [-1]; confirm this offset is intended before changing it.
        if intermediate_latents is not None:
            latents[0] = intermediate_latents[(-i + 1)]
        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        # Perform guidance (unconditional half comes first)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    # Post-processing
    images = pipe.decode_latents(latents)
    return pipe.numpy_to_pil(images)
# Sample function (regular DDIM), but disentangle the content and style
@torch.no_grad()
def sample_disentangled(
    pipe,
    prompt,
    start_step=0,
    start_latents=None,
    intermediate_latents=None,
    guidance_scale=3.5,
    num_inference_steps=30,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    use_content_anchor=True,
    negative_prompt="",
    device=device,
):
    """Jointly denoise a reconstruction branch and generated branches with DDIM.

    ``start_latents`` is repeated once per prompt.
    - Batch entry 0 is the reconstruction branch. When ``use_content_anchor`` is set,
      it is overwritten at every step with an inverted latent.
    - Batch entry 1 is re-initialised with fresh Gaussian noise.
    - Any further entries start from ``start_latents``.

    Args:
        pipe: a ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        prompt: list of at least two prompts, one per batch entry.
        start_step: first scheduler step to run.
        start_latents: starting latent with a leading batch dim of 1 (required).
        intermediate_latents: inverted latents (e.g. from ``invert``).
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of scheduler timesteps.
        num_images_per_prompt: forwarded to ``_encode_prompt`` (SD only).
        do_classifier_free_guidance: evaluate unconditional + conditional per step.
        use_content_anchor: pin batch entry 0 to ``intermediate_latents``.
        negative_prompt: string broadcast to every prompt, or a list with one entry
            per prompt (SD only; the SDXL path always uses empty negatives).
        device: device for timesteps and noise.

    Returns:
        List of PIL images, one per prompt.

    Raises:
        ValueError: missing ``start_latents``, fewer than two prompts, or
            ``use_content_anchor`` without ``intermediate_latents``.
        TypeError: if ``pipe`` is not a supported pipeline type.
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    if start_latents is None:
        raise ValueError("start_latents is required (e.g. a latent produced by invert()).")
    if len(prompt) < 2:
        raise ValueError(
            "sample_disentangled needs at least two prompts (reconstruction + generation)."
        )
    if use_content_anchor and intermediate_latents is None:
        raise ValueError("use_content_anchor=True requires intermediate_latents.")
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * len(prompt)

    # Encode prompt
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, prompt
        )
    else:
        raise TypeError(f"Unsupported pipeline type: {type(pipe).__name__}")

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Fresh noise for the generated branch, shaped like a single start latent.
    generative_latent = torch.randn((1, *start_latents.shape[1:]), device=device)
    generative_latent *= pipe.scheduler.init_noise_sigma

    latents = start_latents.clone().repeat(len(prompt), 1, 1, 1)
    # Randomly initialise batch entry 1 for generation.
    latents[1] = generative_latent

    for i in tqdm(range(start_step, num_inference_steps), desc="Stylizing"):
        # NOTE(review): the index 1 - i reads [1] and [0] for i = 0, 1 before
        # jumping to [-1]; confirm this offset is intended before changing it.
        if use_content_anchor:
            latents[0] = intermediate_latents[(-i + 1)]
        t = pipe.scheduler.timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        # Perform guidance (unconditional half comes first)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    # Post-processing: decode with the VAE temporarily in float32.
    pipe.vae.to(dtype=torch.float32)
    latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
    latents = 1 / pipe.vae.config.scaling_factor * latents
    images = pipe.vae.decode(latents, return_dict=False)[0]
    images = (images / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    images = pipe.numpy_to_pil(images)
    # The SDXL path keeps its VAE in float16 outside of decoding.
    if isinstance(pipe, StableDiffusionXLPipeline):
        pipe.vae.to(dtype=torch.float16)
    return images
## Inversion
@torch.no_grad()
def invert(
    pipe,
    start_latents,
    prompt,
    guidance_scale=3.5,
    num_inference_steps=50,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
    device=device,
):
    """DDIM-invert ``start_latents`` towards noise, conditioned on ``prompt``.

    The scheduler's timesteps are walked in reverse order. Reversed indices
    1 .. num_inference_steps - 2 are visited. The first index and the last one are
    not visited.

    Args:
        pipe: a ``StableDiffusionPipeline`` or ``StableDiffusionXLPipeline``.
        start_latents: clean latent to invert (cast to fp16 on the SDXL path).
        prompt: single prompt string used as the conditioning text.
        guidance_scale: classifier-free guidance weight.
        num_inference_steps: number of scheduler timesteps (must be >= 3).
        num_images_per_prompt: forwarded to ``_encode_prompt`` (SD only).
        do_classifier_free_guidance: evaluate unconditional + conditional per step.
        negative_prompt: forwarded to ``_encode_prompt`` (SD only).
        device: device for the scheduler timesteps.

    Returns:
        Every intermediate latent, concatenated along dim 0 in the order produced.
        Later entries correspond to larger timesteps, assuming the scheduler's
        timesteps are stored in descending order, as with DDIM.

    Raises:
        ValueError: if ``num_inference_steps < 3`` (no step would be stored).
        TypeError: if ``pipe`` is not a supported pipeline type.
    """
    if num_inference_steps < 3:
        raise ValueError(
            f"num_inference_steps must be at least 3, got {num_inference_steps}"
        )

    # Encode prompt
    if isinstance(pipe, StableDiffusionPipeline):
        text_embeddings = pipe._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )
        added_cond_kwargs = None
        latents = start_latents.clone().detach()
    elif isinstance(pipe, StableDiffusionXLPipeline):
        added_cond_kwargs, text_embeddings = _encode_text_sdxl_with_negative(
            pipe, [prompt]
        )
        latents = start_latents.clone().detach().half()
    else:
        raise TypeError(f"Unsupported pipeline type: {type(pipe).__name__}")

    # We'll keep a list of the inverted latents as the process goes on
    intermediate_latents = []

    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = reversed(pipe.scheduler.timesteps)
    # Distance between neighbouring inference timesteps (1000 // steps for the default config).
    step_size = pipe.scheduler.config.num_train_timesteps // num_inference_steps

    for i in tqdm(range(1, num_inference_steps - 1), desc="DDIM Inversion"):
        t = timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
            added_cond_kwargs=added_cond_kwargs,
        ).sample

        # Perform guidance (unconditional half comes first)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        current_t = max(0, t.item() - step_size)  # t
        next_t = t  # t+1
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted update step: re-arrange the DDIM update to get x(t) (new latents)
        # as a function of x(t-1) (current latents).
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (
            alpha_t_next.sqrt() / alpha_t.sqrt()
        ) + (1 - alpha_t_next).sqrt() * noise_pred

        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)
def style_image_with_inversion(
    pipe,
    input_image,
    input_image_prompt,
    style_prompt,
    num_steps=100,
    start_step=30,
    guidance_scale=3.5,
    disentangle=False,
    share_attn=False,
    share_cross_attn=False,
    share_resnet_layers=[0, 1],
    share_attn_layers=[],
    c2s_layers=[0, 1],
    share_key=True,
    share_query=True,
    share_value=False,
    use_adain=True,
    use_content_anchor=True,
    output_dir: str = None,
    resnet_mode: str = None,
    return_intermediate=False,
    intermediate_latents=None,
):
    """Stylize ``input_image`` by DDIM inversion followed by attention-injected sampling.

    Pipeline:
    1. Encode the image with the VAE (in float32).
    2. Invert it with ``input_image_prompt``, unless ``intermediate_latents`` is given.
    3. Register the attention-injection processors.
    4. Sample with ``style_prompt``.
    5. Unset the processors again, even if sampling raises, so ``pipe`` is left
       clean for the next call.

    Args:
        pipe: diffusion pipeline (SD or SDXL).
        input_image: image tensor; mapped from what is presumably [0, 1] to [-1, 1]
            via ``* 2 - 1`` (range assumed, confirm with the image loader).
        input_image_prompt: prompt used for inversion.
        style_prompt: list of prompts, one per output image.
        num_steps / start_step / guidance_scale: sampling schedule and CFG weight.
        disentangle: use ``sample_disentangled`` and the "artist" attention mode
            instead of ``sample`` with "pnp".
        share_* / c2s_layers / use_adain / resnet_mode / output_dir: forwarded to
            ``attn_injection.register_attention_processors``.
        use_content_anchor: forwarded to ``sample_disentangled``.
        return_intermediate: also return the inverted latents.
        intermediate_latents: precomputed inverted latents to skip inversion.

    Returns:
        List of PIL images, or ``(images, inverted_latents)`` if
        ``return_intermediate`` is set.
    """
    # The list defaults above are shared between calls; hand the processors fresh
    # copies so anything done to them cannot leak into later calls.
    if isinstance(share_resnet_layers, list):
        share_resnet_layers = list(share_resnet_layers)
    if isinstance(share_attn_layers, list):
        share_attn_layers = list(share_attn_layers)
    if isinstance(c2s_layers, list):
        c2s_layers = list(c2s_layers)

    # Pure inference: no autograd graph is needed anywhere below.
    with torch.no_grad():
        pipe.vae.to(dtype=torch.float32)
        posterior = pipe.vae.encode(input_image.to(device) * 2 - 1)
        # NOTE: encoding runs even when intermediate_latents is supplied; presumably
        # latent_dist.sample() draws from the global RNG, so skipping it could shift
        # later random draws.
        init_latent = pipe.vae.config.scaling_factor * posterior.latent_dist.sample()
        if isinstance(pipe, StableDiffusionXLPipeline):
            pipe.vae.to(dtype=torch.float16)

        if intermediate_latents is None:
            inverted_latents = invert(
                pipe, init_latent, input_image_prompt, num_inference_steps=num_steps
            )
        else:
            inverted_latents = intermediate_latents

        attn_injection.register_attention_processors(
            pipe,
            base_dir=output_dir,
            resnet_mode=resnet_mode,
            attn_mode="artist" if disentangle else "pnp",
            disentangle=disentangle,
            share_resblock=True,
            share_attn=share_attn,
            share_cross_attn=share_cross_attn,
            share_resnet_layers=share_resnet_layers,
            share_attn_layers=share_attn_layers,
            share_key=share_key,
            share_query=share_query,
            share_value=share_value,
            use_adain=use_adain,
            c2s_layers=c2s_layers,
        )
        try:
            common_kwargs = dict(
                start_latents=inverted_latents[-(start_step + 1)][None],
                intermediate_latents=inverted_latents,
                start_step=start_step,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
            )
            if disentangle:
                final_im = sample_disentangled(
                    pipe,
                    style_prompt,
                    use_content_anchor=use_content_anchor,
                    **common_kwargs,
                )
            else:
                final_im = sample(pipe, style_prompt, **common_kwargs)
        finally:
            # Always restore the pipeline's processors, also when sampling fails.
            attn_injection.unset_attention_processors(
                pipe,
                unset_share_attn=True,
                unset_share_resblock=True,
            )

    if return_intermediate:
        return final_im, inverted_latents
    return final_im
# File-name suffixes for the three images produced per input, in prompt order:
# reconstruction, uncontrolled style, controlled style.
_OUT_NAMES = ("content_delegation", "style_delegation", "style_out")


def _parse_args():
    """Parse the command-line options of the script entry point."""
    parser = argparse.ArgumentParser(description="Stable Diffusion with OmegaConf")
    parser.add_argument(
        "--config", type=str, default="config.yaml", help="Path to the config file"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="dataset",
        choices=["dataset", "cli", "app"],
        help="Run mode: 'dataset' (batch over cfg.annotation), 'cli' (single image) or 'app' (Gradio demo)",
    )
    parser.add_argument(
        "--image_dir", type=str, default="test.png", help="Path to the image"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="an impressionist painting",
        help="Stylization prompt",
    )
    return parser.parse_args()


def _load_pipeline():
    """Load Stable Diffusion 2.1-base on ``device`` and switch it to a DDIM scheduler."""
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base"
    ).to(device)
    # An SDXL pipeline could be loaded here instead; the sampling helpers above also
    # branch on StableDiffusionXLPipeline.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    return pipe


def _stylize_from_cfg(pipe, image, src_prompt, prompt_in, cfg, output_dir):
    """Call style_image_with_inversion with the hyper-parameters stored in ``cfg``."""
    return style_image_with_inversion(
        pipe,
        image,
        src_prompt,
        style_prompt=prompt_in,
        num_steps=cfg.num_steps,
        start_step=cfg.start_step,
        guidance_scale=cfg.style_cfg_scale,
        disentangle=cfg.disentangle,
        resnet_mode=cfg.resnet_mode,
        share_attn=cfg.share_attn,
        share_cross_attn=cfg.share_cross_attn,
        share_resnet_layers=cfg.share_resnet_layers,
        share_attn_layers=cfg.share_attn_layers,
        share_key=cfg.share_key,
        share_query=cfg.share_query,
        share_value=cfg.share_value,
        use_content_anchor=cfg.use_content_anchor,
        use_adain=cfg.use_adain,
        output_dir=output_dir,
    )


def _run_dataset(pipe, cfg):
    """Stylize every entry of ``cfg.annotation`` into a fresh experiment directory."""
    os.makedirs(cfg.out_path, exist_ok=True)
    base_output_path = os.path.join(cfg.out_path, cfg.exp_name)
    experiment_output_path = utils.exp_utils.make_unique_experiment_path(
        base_output_path
    )
    # Save the experiment configuration and annotation next to the outputs.
    omegaconf.OmegaConf.save(cfg, os.path.join(experiment_output_path, "config.yaml"))
    with open(cfg.annotation) as f:
        annotation = json.load(f)
    with open(os.path.join(experiment_output_path, "annotation.json"), "w") as f:
        json.dump(annotation, f)

    resolution = 512
    for i, entry in enumerate(annotation):
        # Re-seed per entry so each result is reproducible on its own.
        utils.exp_utils.seed_all(cfg.seed)
        src_prompt = entry["source_prompt"]
        tgt_prompt = entry["target_prompt"]
        input_image = utils.exp_utils.get_processed_image(
            entry["image_path"], device, resolution
        )
        prompt_in = [
            src_prompt,  # reconstruction
            tgt_prompt,  # uncontrolled style
            "",  # controlled style
        ]
        imgs = _stylize_from_cfg(
            pipe, input_image, src_prompt, prompt_in, cfg, experiment_output_path
        )
        for name, img in zip(_OUT_NAMES, imgs):
            out_path = f"{experiment_output_path}/out_{i}_{name}.png"
            img.save(out_path)
            print(f"Image saved as {out_path}")


def _run_cli(pipe, cfg, image_path, tgt_prompt):
    """Stylize a single image with ``tgt_prompt`` and write the results to ./out."""
    utils.exp_utils.seed_all(cfg.seed)
    image = utils.exp_utils.get_processed_image(image_path, device, 512)
    prompt_in = [
        "",  # reconstruction
        tgt_prompt,  # uncontrolled style
        "",  # controlled style
    ]
    out_dir = "./out"
    os.makedirs(out_dir, exist_ok=True)
    imgs = _stylize_from_cfg(pipe, image, "", prompt_in, cfg, out_dir)

    # splitext keeps dotted stems intact ("my.photo.png" -> "my.photo").
    image_base_name = os.path.splitext(os.path.basename(image_path))[0]
    for name, img in zip(_OUT_NAMES, imgs):
        out_path = f"{out_dir}/{image_base_name}_out_{name}.png"
        img.save(out_path)
        print(f"Image saved as {out_path}")


def _run_app(pipe):
    """Launch the interactive Gradio demo."""
    import gradio as gr

    def style_transfer_app(
        prompt,
        image,
        cfg_scale=7.5,
        num_content_layers=4,
        num_style_layers=9,
        seed=0,
        progress=gr.Progress(track_tqdm=True),
    ):
        if image is None:
            raise gr.Error("Please upload a content image.")
        utils.exp_utils.seed_all(int(seed))
        image = utils.exp_utils.process_image(image, device, 512)
        prompt_in = [
            "",  # reconstruction
            prompt,  # uncontrolled style
            "",  # controlled style
        ]
        num_content_layers = int(num_content_layers)
        num_style_layers = int(num_style_layers)
        # A slider value of 0 is passed on as None rather than an empty layer list.
        share_resnet_layers = (
            list(range(num_content_layers)) if num_content_layers != 0 else None
        )
        share_attn_layers = (
            list(range(num_style_layers)) if num_style_layers != 0 else None
        )
        imgs = style_image_with_inversion(
            pipe,
            image,
            "",
            style_prompt=prompt_in,
            num_steps=50,
            start_step=0,
            guidance_scale=cfg_scale,
            disentangle=True,
            resnet_mode="hidden",
            share_attn=True,
            share_cross_attn=True,
            share_resnet_layers=share_resnet_layers,
            share_attn_layers=share_attn_layers,
            share_key=True,
            share_query=True,
            share_value=False,
            use_content_anchor=True,
            use_adain=True,
            output_dir="./",
        )
        # Only the controlled-style output is shown in the demo.
        return imgs[2]

    # load examples
    with open("data/example/annotation.json") as f:
        annotation = json.load(f)
    examples = []
    for entry in annotation:
        example_image = utils.exp_utils.get_processed_image(
            entry["image_path"], device, 512
        )
        example_image = transforms.ToPILImage()(example_image[0])
        examples.append([entry["target_prompt"], example_image, None, None, None])

    text_input = gr.Textbox(
        value="An impressionist painting",
        label="Text Prompt",
        info="Describe the style you want to apply to the image, do not include the description of the image content itself",
        lines=2,
        placeholder="Enter a text prompt",
    )
    image_input = gr.Image(
        height="80%",
        width="80%",
        label="Content image (will be resized to 512x512)",
        interactive=True,
    )
    cfg_slider = gr.Slider(
        0,
        15,
        value=7.5,
        label="Classifier Free Guidance (CFG) Scale",
        info="higher values give more style, 7.5 should be good for most cases",
    )
    content_slider = gr.Slider(
        0,
        9,
        value=4,
        step=1,
        label="Number of content control layer",
        info="higher values make it more similar to original image. Default to control first 4 layers",
    )
    style_slider = gr.Slider(
        0,
        9,
        value=9,
        step=1,
        label="Number of style control layer",
        info="higher values make it more similar to target style. Default to control first 9 layers, usually not necessary to change.",
    )
    seed_slider = gr.Slider(
        0,
        100,
        value=0,
        step=1,
        label="Seed",
        info="Random seed for the model",
    )
    app = gr.Interface(
        fn=style_transfer_app,
        inputs=[
            text_input,
            image_input,
            cfg_slider,
            content_slider,
            style_slider,
            seed_slider,
        ],
        outputs=["image"],
        title="Artist Interactive Demo",
        examples=examples,
    )
    app.launch()


if __name__ == "__main__":
    # Parse first so --help and bad arguments fail before the model download.
    args = _parse_args()
    pipe = _load_pipeline()
    if args.mode == "dataset":
        _run_dataset(pipe, OmegaConf.load(args.config))
    elif args.mode == "cli":
        _run_cli(pipe, OmegaConf.load(args.config), args.image_dir, args.prompt)
    elif args.mode == "app":
        _run_app(pipe)