import os

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
from PIL import Image
from torchvision import transforms as tfms
from transformers import CLIPTextModel, CLIPTokenizer

auth_token = os.environ.get("API_TOKEN")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_artifacts():
    '''
    Load all Stable Diffusion artifacts: VAE, U-Net, tokenizer, text encoder and scheduler.
    '''
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", use_auth_token=auth_token)
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16, use_auth_token=auth_token).to(device)
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
    return vae, unet, tokenizer, text_encoder, scheduler
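
# Note: the helper functions below read vae, unet, tokenizer, text_encoder and
# scheduler as module-level globals; they are assigned from load_artifacts() further
# down, just before the inpainting pipeline and the Gradio app are built.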
def load_image(p):
    '''
    Load an image from a path, convert it to RGB and resize it to 512x512.
    '''
    return Image.open(p).convert('RGB').resize((512, 512))
def pil_to_latents(image):
    '''
    Convert a PIL image to scaled VAE latents.
    '''
    init_image = tfms.ToTensor()(image).unsqueeze(0) * 2.0 - 1.0
    init_image = init_image.to(device=device, dtype=torch.float16)
    # Encode into the latent space and apply the Stable Diffusion scaling factor
    with torch.no_grad():
        init_latents = vae.encode(init_image).latent_dist.sample() * 0.18215
    return init_latents
def latents_to_pil(latents):
    '''
    Convert a batch of latents back into a list of PIL images.
    '''
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
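
# Optional sanity check for the two converters above (a sketch: "sample.jpg" is a
# hypothetical local test image, not an asset of this Space). Encoding an image to
# latents and decoding it back should give a near-identical reconstruction:
# img = load_image("sample.jpg")
# latents = pil_to_latents(img)               # shape: (1, 4, 64, 64)
# latents_to_pil(latents)[0].save("sample_recon.jpg")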
def text_enc(prompts, maxlen=None):
    '''
    Convert a textual prompt into CLIP text embeddings.
    '''
    if maxlen is None:
        maxlen = tokenizer.model_max_length
    inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
    return text_encoder(inp.input_ids.to(device))[0].half()
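
# For reference, CLIP ViT-L/14 pads every prompt to 77 tokens and produces 768-dim
# token embeddings, so a single prompt encodes to a (1, 77, 768) tensor:
# emb = text_enc(["a horse image"])
# print(emb.shape)  # torch.Size([1, 77, 768])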
def prompt_2_img_i2i_fast(prompts, init_img, g=7.5, seed=100, strength=0.5, steps=50, dim=512):
    """
    Fast single-step image-to-image denoising: noise the init image latents to a
    strength-dependent timestep, run one U-Net pass with classifier-free guidance,
    and return the predicted denoised latents.
    """
    # Converting textual prompts to embeddings
    text = text_enc(prompts)
    # Adding an unconditional (empty) prompt for classifier-free guidance
    uncond = text_enc([""], text.shape[1])
    emb = torch.cat([uncond, text])
    # Setting the seed
    if seed:
        torch.manual_seed(seed)
    # Setting the number of steps in the scheduler
    scheduler.set_timesteps(steps)
    # Converting the seed image to latents
    init_latents = pil_to_latents(init_img)
    # Figuring the initial timestep based on strength
    init_timestep = int(steps * strength)
    timesteps = scheduler.timesteps[-init_timestep]
    timesteps = torch.tensor([timesteps], device=device)
    # Adding noise to the latents
    noise = torch.randn(init_latents.shape, generator=None, device=device, dtype=init_latents.dtype)
    init_latents = scheduler.add_noise(init_latents, noise, timesteps)
    latents = init_latents
    # Duplicating and scaling the input latents to match the scheduler's expected variance
    inp = scheduler.scale_model_input(torch.cat([latents] * 2), timesteps)
    # Predicting the noise residual with the U-Net, one pass for both prompts
    with torch.no_grad():
        u, t = unet(inp, timesteps, encoder_hidden_states=emb).sample.chunk(2)
    # Performing classifier-free guidance
    pred = u + g * (t - u)
    # One-shot prediction of the denoised original sample
    latents = scheduler.step(pred, timesteps, latents).pred_original_sample
    # Returning the predicted 1x4x64x64 latent representation
    return latents.detach().cpu()
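
# Example use of the one-step approximation (a sketch; assumes `init_img` is a
# 512x512 PIL image, e.g. from load_image). The result is a rough denoised latent,
# sufficient for mask estimation though not for final image quality:
# fast_latents = prompt_2_img_i2i_fast("a horse image", init_img, strength=0.5)
# print(fast_latents.shape)  # torch.Size([1, 4, 64, 64])
# latents_to_pil(fast_latents.to(device))[0]  # decode to inspect the approximation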
def create_mask_fast(init_img, rp, qp, n=20, s=0.5):
    ## Initialize a dictionary to save the n difference maps
    diff = {}
    ## Repeating the difference process n times
    for idx in range(n):
        ## Creating a denoised sample using the reference / original prompt
        orig_noise = prompt_2_img_i2i_fast(prompts=rp, init_img=init_img, strength=s, seed=100 * idx)[0]
        ## Creating a denoised sample using the query / target prompt
        query_noise = prompt_2_img_i2i_fast(prompts=qp, init_img=init_img, strength=s, seed=100 * idx)[0]
        ## Taking the difference
        diff[idx] = (np.array(orig_noise) - np.array(query_noise))
    ## Creating a mask placeholder
    mask = np.zeros_like(diff[0])
    ## Summing the absolute differences over the n iterations
    for idx in range(n):
        ## Note: np.abs is a key step
        mask += np.abs(diff[idx])
    ## Averaging over the latent channels
    mask = mask.mean(0)
    ## Normalizing to zero mean and unit variance
    mask = (mask - mask.mean()) / np.std(mask)
    ## Binarizing and returning the mask
    return (mask > 0).astype("uint8")
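
# Hedged usage sketch (assumes `init_img` is a loaded PIL image): the mask comes
# back as a 64x64 binary array in latent resolution, so upscaling it to 512x512 is
# a quick way to preview which region the edit will touch:
# m = create_mask_fast(init_img, rp="a horse image", qp="a zebra image", n=10)
# Image.fromarray(m * 255).resize((512, 512)).save("mask_preview.png")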
def improve_mask(mask):
    ## Blur and re-binarize to smooth edges and slightly dilate the mask
    mask = cv2.GaussianBlur(mask * 255, (3, 3), 1) > 0
    return mask.astype('uint8')
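
# Why the blur helps: the (3, 3) Gaussian kernel spreads nonzero mask pixels into
# their neighbors before the `> 0` re-binarization, so the net effect is a slight
# dilation that closes pinholes and smooths ragged mask borders. If edits bleed at
# the boundary, a larger kernel such as (5, 5) would dilate more aggressively.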
vae, unet, tokenizer, text_encoder, scheduler = load_artifacts()
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=auth_token
).to(device)
def fastDiffEdit(init_img, reference_prompt, query_prompt, g=7.5, seed=100, strength=0.7, steps=20, dim=512):
    ## Step 1: Create the edit mask from the two prompts
    mask = create_mask_fast(init_img=init_img, rp=reference_prompt, qp=query_prompt, n=20)
    ## Improve the mask with the blur-and-binarize CV trick
    mask = improve_mask(mask)
    ## Steps 2 and 3: Masked inpainting diffusion using the query prompt
    output = pipe(
        prompt=query_prompt,
        image=init_img,
        mask_image=Image.fromarray(mask * 255).resize((512, 512)),
        generator=torch.Generator(device).manual_seed(seed),
        num_inference_steps=steps
    ).images
    return output[0]
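
# Example of running the full edit outside Gradio, using one of the bundled example
# assets listed below (a sketch; the output filename is arbitrary):
# out = fastDiffEdit(load_image("horse.jpg"), "a horse image", "a zebra image")
# out.save("horse_to_zebra.png")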
demo = gr.Interface(
    fn=fastDiffEdit,
    inputs=[
        gr.Image(shape=(512, 512), type="pil", label="Upload your image"),
        gr.Textbox(label="Describe your image. Ex: a horse image"),
        gr.Textbox(label="Retype the description, replacing the object with the target. Ex: a zebra image")],
    outputs="image",
    title="DiffEdit demo",
    description="DiffEdit paper demo: upload an image, pass a reference prompt describing it, then pass a query prompt that replaces the object with the target object.",
    examples=[
        ["fruitbowl.jpg", "a bowl of fruit", "a bowl of grapes"],
        ["horse.jpg", "a horse image", "a zebra image"]]
)
demo.queue()
demo.launch()