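"""
DiffEdit demo (Gradio app).

A fast approximation of DiffEdit (Couairon et al., 2022):
1. Estimate an edit mask by contrasting denoised predictions for a
   reference prompt vs. a query prompt over several noise seeds.
2. Inpaint the masked region with a Stable Diffusion inpainting pipeline,
   conditioned on the query prompt.
"""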
import torch
from torchvision import transforms as tfms
import numpy as np
import cv2
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr
import os
auth_token = os.environ.get("API_TOKEN")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_artifacts():
    '''
    A function to load all diffusion artifacts: VAE, U-Net, tokenizer, text encoder, and scheduler
    '''
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16,use_auth_token=auth_token).to(device)
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16, use_auth_token=auth_token)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
return vae, unet, tokenizer, text_encoder, scheduler
def load_image(p):
    '''
    Function to load an image from a path, convert it to RGB, and resize it to 512x512
    '''
    return Image.open(p).convert('RGB').resize((512, 512))
def pil_to_latents(image):
    '''
    Function to encode a PIL image into scaled VAE latents
    '''
    # Scaling pixel values from [0, 1] to [-1, 1] as the VAE expects
    init_image = tfms.ToTensor()(image).unsqueeze(0) * 2.0 - 1.0
    init_image = init_image.to(device=device, dtype=torch.float16)
    with torch.no_grad():
        # 0.18215 is the latent scaling factor of the Stable Diffusion v1 VAE
        init_latents = vae.encode(init_image).latent_dist.sample() * 0.18215
    return init_latents
def latents_to_pil(latents):
'''
Function to convert latents to images
'''
    # Undoing the 0.18215 latent scaling before decoding
    latents = (1 / 0.18215) * latents
with torch.no_grad():
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def text_enc(prompts, maxlen=None):
    '''
    A function to take a textual prompt and convert it into embeddings
    '''
if maxlen is None: maxlen = tokenizer.model_max_length
inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
return text_encoder(inp.input_ids.to(device))[0].half()
def prompt_2_img_i2i_fast(prompts, init_img, g=7.5, seed=100, strength=0.5, steps=50, dim=512):
    """
    Fast image-to-image pass: noise the init image's latents to an intermediate
    timestep (set by strength), then predict the fully denoised latents in a
    single guided U-Net step
    """
    # Converting textual prompts to embeddings
    text = text_enc(prompts)
    # Adding an unconditional (empty) prompt, which enables classifier-free guidance
    uncond = text_enc([""], text.shape[1])
    emb = torch.cat([uncond, text])
    # Setting the seed (seed=0 should still seed the generator)
    if seed is not None: torch.manual_seed(seed)
# Setting number of steps in scheduler
scheduler.set_timesteps(steps)
# Convert the seed image to latent
init_latents = pil_to_latents(init_img)
# Figuring initial time step based on strength
init_timestep = int(steps * strength)
timesteps = scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps], device=device)
# Adding noise to the latents
noise = torch.randn(init_latents.shape, generator=None, device=device, dtype=init_latents.dtype)
init_latents = scheduler.add_noise(init_latents, noise, timesteps)
latents = init_latents
    # Scaling the input latents (duplicated for guidance) to the scheduler's expected variance
    inp = scheduler.scale_model_input(torch.cat([latents] * 2), timesteps)
    # Predicting the noise residual with the U-Net for both prompts at once
    with torch.no_grad():
        u, t = unet(inp, timesteps, encoder_hidden_states=emb).sample.chunk(2)
    # Performing classifier-free guidance
    pred = u + g * (t - u)
    # Predicting the fully denoised latents (x0) in a single scheduler step
    latents = scheduler.step(pred, timesteps, latents).pred_original_sample
    # Returning the predicted latents as a 4x64x64 array per image
    return latents.detach().cpu()
def create_mask_fast(init_img, rp, qp, n=20, s=0.5):
## Initialize a dictionary to save n iterations
diff = {}
## Repeating the difference process n times
for idx in range(n):
        ## Creating a denoised prediction conditioned on the reference / original prompt
        orig_noise = prompt_2_img_i2i_fast(prompts=rp, init_img=init_img, strength=s, seed=100 * idx)[0]
        ## Creating a denoised prediction conditioned on the query / target prompt
        query_noise = prompt_2_img_i2i_fast(prompts=qp, init_img=init_img, strength=s, seed=100 * idx)[0]
## Taking the difference
diff[idx] = (np.array(orig_noise)-np.array(query_noise))
## Creating a mask placeholder
mask = np.zeros_like(diff[0])
    ## Summing the absolute differences over all n iterations
for idx in range(n):
## Note np.abs is a key step
mask += np.abs(diff[idx])
## Averaging multiple channels
mask = mask.mean(0)
    ## Normalizing to zero mean and unit variance
    mask = (mask - mask.mean()) / np.std(mask)
    ## Binarizing (keeping pixels whose difference is above the mean) and returning the mask
    return (mask > 0).astype("uint8")
def improve_mask(mask):
    '''
    Function to smooth and slightly dilate a binary mask with a Gaussian blur, then re-binarize it
    '''
    mask = cv2.GaussianBlur(mask * 255, (3, 3), 1) > 0
    return mask.astype('uint8')
vae, unet, tokenizer, text_encoder, scheduler = load_artifacts()
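# Loading a dedicated inpainting checkpoint for the final edit: unlike the base
# model, it is trained to condition on a mask and the surrounding pixels.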
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
revision="fp16",
torch_dtype=torch.float16,
use_auth_token=auth_token
).to(device)
def fastDiffEdit(init_img, reference_prompt, query_prompt, g=7.5, seed=100, strength=0.7, steps=20, dim=512):
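    '''
    Function running the full DiffEdit flow: (1) estimate an edit mask from the
    reference and query prompts, (2) smooth the mask, (3) inpaint the masked
    region conditioned on the query prompt
    '''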
## Step 1: Create mask
mask = create_mask_fast(init_img=init_img, rp=reference_prompt, qp=query_prompt, n=20)
## Improve masking using CV trick
mask = improve_mask(mask)
## Step 2 and 3: Diffusion process using mask
output = pipe(
prompt=query_prompt,
image=init_img,
mask_image=Image.fromarray(mask*255).resize((512,512)),
        generator=torch.Generator(device).manual_seed(seed),
        num_inference_steps=steps
).images
return output[0]
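# A minimal local usage sketch (hypothetical output file name; assumes the
# weights download succeeds and a GPU is available for the fp16 models):
#   img = load_image("horse.jpg")
#   edited = fastDiffEdit(img, "a horse image", "a zebra image")
#   edited.save("zebra_edit.jpg")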
demo = gr.Interface(
fn=fastDiffEdit,
inputs=[
        gr.Image(shape=(512, 512), type="pil", label="Upload your image"),
gr.Textbox(label="Describe your image. Ex: a horse image"),
gr.Textbox(label="Retype the description with target output. Ex: a zebra image")],
outputs="image",
title = "DiffEdit demo",
    description = "Demo of the DiffEdit paper. Upload an image, provide a reference prompt describing it, and a query prompt describing the target object to replace it with.",
examples = [
["fruitbowl.jpg", "a bowl of fruit", "a bowl of grapes"],
["horse.jpg", "a horse image", "a zebra image"]],
enable_queue=True
)
demo.launch()