ukiyo-e-postal / app.py
alkzar90's picture
Decrease the number of examples
4641216
raw
history blame
7.4 kB
import open_clip
import gradio as gr
import numpy as np
import torch
import torchvision
from tqdm.auto import tqdm
from PIL import Image, ImageColor
from torchvision import transforms
from diffusers import DDIMScheduler, DDPMPipeline
# Choose the compute device once at import time, in order of preference:
# Apple Metal (MPS), then CUDA, and finally the CPU as the fallback.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Load the pretrained DDPM pipeline identified by `pipeline_name` and move it
# to the device selected above. Only its `unet` is used during sampling.
pipeline_name = "alkzar90/sd-class-ukiyo-e-256"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
# Sampling is driven by a DDIM scheduler created from the same pipeline name.
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
# Initial step count only: `generate` calls `set_timesteps` again with the
# requested number of inference steps on every invocation.
scheduler.set_timesteps(num_inference_steps=40)
# Color guidance
#-------------------------------------------------------------------------------
# Color guidance function
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Return how far, on average, the pixels of ``images`` are from a color.

    Args:
        images: Image tensor laid out as ``(b, c, h, w)``. The target color
            is mapped from (0, 1) to (-1, 1) before the comparison, so the
            images are expected to use that same (-1, 1) scale.
        target_color: ``(R, G, B)`` components, each in (0, 1). Defaults to
            ``(0.1, 0.9, 0.5)``.

    Returns:
        A scalar tensor holding the mean absolute difference between every
        element of ``images`` and the (broadcast) target color.
    """
    # Create the target directly on the images' device rather than building
    # it on the host and copying it over; this runs once per sampling step.
    target = torch.tensor(target_color, device=images.device) * 2 - 1  # (0, 1) -> (-1, 1)
    # Shape (1, c, 1, 1) so it broadcasts over batch, height and width.
    target = target[None, :, None, None]
    return torch.abs(images - target).mean()
# CLIP guidance
#-------------------------------------------------------------------------------
# Load the ViT-B-32 CLIP model with the "openai" pretrained weights.
# `create_model_and_transforms` returns three values; only the model is used
# by the guidance code below (the second value is discarded and `preprocess`
# is not referenced elsewhere in the visible code).
clip_model, _, preprocess = open_clip.create_model_and_transforms(
"ViT-B-32", pretrained="openai"
)
clip_model.to(device)
# Transforms applied to generated images before they are passed to CLIP:
# random augmentations followed by normalization to match CLIP's training data
tfms = transforms.Compose(
[
transforms.RandomResizedCrop(224), # Random crop resized to 224x224, different on each call
transforms.RandomAffine(
5
), # Random affine augmentation; 5 is the `degrees` argument (a small random rotation)
transforms.RandomHorizontalFlip(), # Random left-right flip; further augmentations could be added here
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
# CLIP guidance function
def clip_loss(image, text_features):
    """Mean squared great-circle distance between image and text embeddings.

    ``image`` is first passed through the module-level ``tfms`` transforms
    (random augmentations plus normalization) and then encoded with
    ``clip_model``. Both embedding sets are L2-normalized before the distance
    is measured.
    """
    augmented = tfms(image)  # random, so each call sees a different view
    image_features = clip_model.encode_image(augmented)

    normalize = torch.nn.functional.normalize
    image_unit = normalize(image_features.unsqueeze(1), dim=2)
    text_unit = normalize(text_features.unsqueeze(0), dim=2)

    # Chord length between the unit vectors, converted to a squared
    # great-circle distance.
    chord = (image_unit - text_unit).norm(dim=2)
    dists = torch.arcsin(chord / 2) ** 2 * 2
    return dists.mean()
# Sample generator loop
#-------------------------------------------------------------------------------
def generate(color,
             color_loss_scale,
             num_examples=4,
             seed=None,
             prompt=None,
             prompt_loss_scale=None,
             prompt_n_cuts=None,
             inference_steps=50,
             ):
    """Sample a grid of images with color guidance and optional prompt guidance.

    Args:
        color: Color string accepted by ``PIL.ImageColor.getcolor`` (the UI
            sends a hex string such as ``"#DF5C16"``).
        color_loss_scale: Multiplier applied to the color loss; larger values
            push the samples harder towards ``color``.
        num_examples: Number of images to sample. Coerced to ``int`` because
            UI sliders may deliver floats.
        seed: When not ``None``, passed to ``torch.manual_seed`` before
            sampling. ``0`` is a valid seed.
        prompt: Optional text prompt. When empty or ``None`` no CLIP guidance
            is applied.
        prompt_loss_scale: Multiplier applied to the CLIP loss.
        prompt_n_cuts: Number of randomly augmented views per step whose CLIP
            gradients are averaged.
        inference_steps: Number of scheduler steps. Coerced to ``int``.

    Prompt guidance runs only when ``prompt`` is non-empty,
    ``prompt_loss_scale`` is not ``None`` and ``prompt_n_cuts`` is at least 1;
    otherwise only color guidance is used.

    Returns:
        A ``PIL.Image`` containing the samples arranged four per row. The
        same image is also written to ``test.jpeg`` in the working directory.
    """
    # These values are used as counts; sliders / number boxes may hand us floats.
    num_examples = int(num_examples)
    inference_steps = int(inference_steps)
    n_cuts = int(prompt_n_cuts) if prompt_n_cuts is not None else 0
    use_prompt = bool(prompt) and prompt_loss_scale is not None and n_cuts > 0

    scheduler.set_timesteps(num_inference_steps=inference_steps)

    # `is not None` (rather than truthiness) so that a seed of 0 is honored.
    if seed is not None:
        torch.manual_seed(int(seed))

    if use_prompt:
        text = open_clip.tokenize([prompt]).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = clip_model.encode_text(text)

    target_color = ImageColor.getcolor(color, "RGB")  # Target color as RGB
    target_color = [a / 255 for a in target_color]  # Rescale from (0, 255) to (0, 1)

    x = torch.randn(num_examples, 3, 256, 256).to(device)

    # Wrapping the timesteps (not the enumerate) lets tqdm show a total.
    for i, t in enumerate(tqdm(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        # Color guidance: differentiate the color loss of the predicted x0
        # with respect to the current sample x.
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        loss = color_loss(x0, target_color) * color_loss_scale
        cond_color_grad = -torch.autograd.grad(loss, x)[0]
        # Modify x based solely on the color gradient -> x_cond
        x_cond = x.detach() + cond_color_grad

        # Prompt guidance: the gradient is computed from the original x (not
        # the one already modified by the color gradient) and then added to
        # x_cond.
        if use_prompt:
            cond_prompt_grad = torch.zeros_like(x_cond)
            for _ in range(n_cuts):
                x = x.detach().requires_grad_()
                x0 = scheduler.step(noise_pred, t, x).pred_original_sample
                prompt_loss = clip_loss(x0, text_features) * prompt_loss_scale
                # Divide by n_cuts so the accumulated value is the average.
                # The graph is rebuilt for every cut, so it need not be retained.
                cond_prompt_grad -= torch.autograd.grad(prompt_loss, x)[0] / n_cuts
            # NOTE(review): alphas_cumprod is indexed by the loop position `i`,
            # not by the timestep value `t` - confirm this is intended before
            # changing it, since the guidance scales may be tuned around it.
            alpha_bar = scheduler.alphas_cumprod[i]
            x_cond = x_cond + cond_prompt_grad * alpha_bar.sqrt()

        x = scheduler.step(noise_pred, t, x_cond).prev_sample

    # Arrange the batch in a grid and map (-1, 1) to 8-bit pixel values.
    grid = torchvision.utils.make_grid(x, nrow=4)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))
    im.save("test.jpeg")
    return im
# GRADIO Interface
#-------------------------------------------------------------------------------
# Static text and styling for the Gradio page.
TITLE="Ukiyo-e postal generator service 🎴!"
DESCRIPTION="This model is a diffusion model for unconditional image generation of Ukiyo-e images ✍ 🎨. \nThe model was train using fine-tuning with the google/ddpm-celebahq-256 pretrain-model and the dataset: https://huggingface.co/datasets/huggan/ukiyoe2photo"
CSS = ".output-image, .input-image, .image-preview {height: 250px !important}"
# Input widgets, listed in the same order as the parameters of `generate`:
# color, color_loss_scale, num_examples, seed, prompt, prompt_loss_scale,
# prompt_n_cuts, inference_steps.
inputs = [
gr.ColorPicker(label="color (click on the square to pick the color)", value="#DF5C16"), # Target color for the color guidance
gr.Slider(label="color_guidance_scale (how strong to blend the color)", minimum=0, maximum=30, value=6.7),
gr.Slider(label="num_examples (# images generated)", minimum=4, maximum=12, value=8, step=4),
gr.Number(label="seed (reproducibility and experimentation)", value=666),
gr.Text(label="Text prompt (optional)", value=None),
gr.Slider(label="prompt_guidance_scale (...)", minimum=0, maximum=1000, value=10),
gr.Slider(label="prompt_n_cuts", minimum=4, maximum=12, step=4),
gr.Slider(label="Number of inference steps (+ steps -> + guidance effect)", minimum=40, maximum=60, value=40, step=1),
]
# Single image output: the grid returned by `generate`.
outputs = gr.Image(label="result")
# Interface wiring `generate` to the widgets above. Each example row supplies
# one value per input, in the same order as `inputs`; the commented-out rows
# are further examples that are currently disabled.
demo = gr.Interface(
fn=generate,
inputs=inputs,
outputs=outputs,
css=CSS,
examples=[
#["#DF5C16", 6.7, 12, 666, None, None, None, 40],
#["#C01660", 13.5, 12, 1990, None, None, None, 40],
#["#44CCAA", 8.9, 12, 1512, None, None, None, 40],
["#39A291", 5.0, 8, 666, "A sakura tree", 60, 4, 52],
#["#0E0907", 0.0, 12, 666, "A big whale in the ocean", 60, 8, 52],
#["#19A617", 4.6, 12, 666, "An island with sunset at background", 140, 4, 47],
],
title=TITLE,
description=DESCRIPTION,
)
if __name__ == "__main__":
    # Enable request queuing through `queue()`: the `enable_queue` argument
    # of `launch()` is deprecated in Gradio 3.x and no longer accepted by
    # later releases.
    demo.queue()
    demo.launch()