# PSLD / app.py
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from diffusers import LMSDiscreteScheduler
from share_btn import community_icon_html, loading_icon_html
from tqdm.auto import tqdm
from PIL import Image
# PARAMS
MANUAL_SEED = 42
HEIGHT = 512
WIDTH = 512
ETA = 1e-1  # step size of the measurement-consistency gradient update
NUM_INFERENCE_STEPS = 50  # number of denoising steps (assumed value; scheduler.set_timesteps() is called with it below)
# full pipeline, used only by the commented-out reference predict() below
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(torch_device)
# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
scheduler = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)
generator = torch.manual_seed(MANUAL_SEED)  # seed the generator that creates the initial latent noise
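# NOTE: the VAE, text encoder, UNet, and scheduler are loaded here in addition to the
# components already inside `pipe`. A lower-memory alternative (sketch, untested assumption)
# would be to reuse the pipeline's own modules, e.g.:
#   vae, text_encoder, tokenizer = pipe.vae, pipe.text_encoder, pipe.tokenizer
#   unet, scheduler = pipe.unet, pipe.scheduler
# Keep in mind that `pipe` was loaded in float16, so the reused modules would be half
# precision, unlike the float32 copies loaded above.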
def read_content(file_path: str) -> str:
    """Read the content of the target file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content
# def predict(dict, prompt=""):
# init_image = dict["image"].convert("RGB").resize((512, 512))
# mask = dict["mask"].convert("RGB").resize((512, 512))
# output = pipe(prompt = prompt, image=init_image, mask_image=mask,guidance_scale=7.5)
# return output.images[0], gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
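# The predict() below is a sketch of gradient-guided latent inpainting: at every denoising
# step it estimates the clean latent z_0 from the current latent z_t (Tweedie's formula),
# decodes it with the VAE, measures how well the decoded image matches the known
# (unmasked) pixels,
#     E(z_t) = 1/2 * || (1 - M) * (y - D(z_0(z_t))) ||^2,
# and then corrects the scheduler's proposed next latent with a gradient step,
#     z_{t-1} <- z'_{t-1} - ETA * grad_{z_t} E(z_t),
# where y is the uploaded image, M the user-drawn mask, and D the VAE decoder.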
def predict(dict, prompt=""):
    # tokenize the prompt; note that only the unconditional (empty-prompt) embedding
    # computed below is fed to the UNet, so the prompt does not condition the diffusion here
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
init_image = dict["image"].convert("RGB").resize((512, 512))
mask = dict["mask"].convert("RGB").resize((512, 512))
# convert input image to array in [-1, 1]
init_image = torch.tensor(2 * (np.asarray(init_image) / 255) - 1, device=torch_device)
mask = torch.tensor((np.asarray(mask) / 255), device=torch_device)
# add one dimension for the batch and bring channels first
init_image = init_image.permute(2, 0, 1).unsqueeze(0)
mask = mask.permute(2, 0, 1).unsqueeze(0)
    # set the denoising schedule and draw the initial latent noise
    scheduler.set_timesteps(NUM_INFERENCE_STEPS)
    latents = torch.randn(
        (1, unet.in_channels, HEIGHT // 8, WIDTH // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
    # scale the initial noise to the scheduler's starting sigma (required by LMSDiscreteScheduler)
    latents = latents * scheduler.init_noise_sigma
    for i, t in enumerate(tqdm(scheduler.timesteps)):
        # work on a leaf copy of the current latent so we can take gradients w.r.t. it
        z_t = torch.clone(latents.detach())
        z_t.requires_grad = True
        # scale the latent to the input range expected by the scheduler/UNet
        latent_model_input = scheduler.scale_model_input(z_t, t)
        # predict the noise residual (conditioned on the empty-prompt embedding)
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=uncond_embeddings).sample
        # estimate the clean latent z_0 from z_t via Tweedie's formula; in the sigma
        # parameterization used by LMSDiscreteScheduler (z_t = z_0 + sigma_t * eps),
        # this is z_0 = z_t - sigma_t * eps_theta
        sigma_t = scheduler.sigmas[i]
        z_0 = z_t - sigma_t * noise_pred
        # decode the estimated clean latent to image space (undo the VAE scaling factor)
        z_0 = 1 / 0.18215 * z_0
        image_pred = vae.decode(z_0).sample
        # clamp the decoded image to the valid [-1, 1] range
        image_pred = torch.clamp(image_pred, min=-1.0, max=1.0)
        inpainted_image = (1 - mask) * init_image + mask * image_pred  # composited preview (not used below)
        # measurement-consistency error on the known (unmasked) pixels
        error_measurement = (1 / 2) * torch.linalg.norm((1 - mask) * (init_image - image_pred)) ** 2
        # TODO(giannisdaras): add LPIPS?
        error = error_measurement
        # gradient of the error w.r.t. the current latent
        gradients = torch.autograd.grad(error, inputs=z_t)[0]
        # take the scheduler's denoising step x_t -> x_t-1, then apply the guidance correction
        z_t_next = scheduler.step(noise_pred, t, z_t).prev_sample
        latents = z_t_next - ETA * gradients
    # scale and decode the final latents with the VAE
    latents = 1 / 0.18215 * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    # map from [-1, 1] to [0, 255] uint8
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    # three outputs, matching the click handler below: the image and the two icon visibilities
    return images[0], gr.update(visible=True), gr.update(visible=True)
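# A minimal way to exercise predict() outside the Gradio UI (sketch; the file names
# "example.png" and "example_mask.png" are hypothetical placeholders):
#
#   test_inputs = {
#       "image": Image.open("example.png"),      # image to inpaint
#       "mask": Image.open("example_mask.png"),  # white where pixels should be regenerated
#   }
#   result, _, _ = predict(test_inputs, prompt="")
#   Image.fromarray(result).save("inpainted.png")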
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''
image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    gr.HTML(read_content("header.html"))
    with gr.Group():
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload").style(height=400)
                    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                        prompt = gr.Textbox(placeholder='Your prompt (what you want in place of what is erased)', show_label=False, elem_id="input-text")
                        btn = gr.Button("Inpaint!").style(
                            margin=False,
                            rounded=(False, True, True, False),
                            full_width=False,
                        )
                with gr.Column():
                    image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
                    with gr.Group(elem_id="share-btn-container"):
                        community_icon = gr.HTML(community_icon_html, visible=False)
                        loading_icon = gr.HTML(loading_icon_html, visible=False)
    btn.click(fn=predict, inputs=[image, prompt], outputs=[image_out, community_icon, loading_icon])

image_blocks.launch()