# Page header captured together with the file (not part of the program):
#   fffiloni's picture / Update app.py / dbfb439
import gradio as gr
from PIL import Image
from io import BytesIO
import torch
import os
#os.system("pip install git+https://github.com/fffiloni/diffusers")
from diffusers import DiffusionPipeline, DDIMScheduler
from imagic import ImagicStableDiffusionPipeline
# Choose the compute device once. Falling back to CPU lets the app start on
# machines without a CUDA-capable GPU instead of failing at import time.
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"

# Load the Stable Diffusion v1.4 weights into the Imagic pipeline with a DDIM
# scheduler; the safety checker is disabled (safety_checker=None).
pipe = ImagicStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
    scheduler=DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    ),
).to(device)

# Random generator seeded with 0, created on the same device as the pipeline.
# It is shared by every call to train() and generate().
generator = torch.Generator(device).manual_seed(0)
def train(prompt, init_image, trn_text, trn_steps):
    """Fine-tune the shared Imagic pipeline on one image/prompt pair.

    Args:
        prompt: Target text passed to ``pipe.train``.
        init_image: Filesystem path of the uploaded image (the Gradio image
            input is declared with ``type="filepath"``), or ``None`` when
            nothing was uploaded.
        trn_text: Forwarded as ``text_embedding_optimization_steps``.
        trn_steps: Forwarded as ``model_fine_tuning_optimization_steps``.

    Returns:
        A status message followed by two ``gr.update(value=0)`` objects; the
        click handler routes them to the status box and the two step sliders.

    Raises:
        gr.Error: If no input image was provided.
    """
    if init_image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # from Image.open(None).
        raise gr.Error("Please upload an input image before training.")

    # convert() returns an independent, fully loaded image, so the underlying
    # file can be closed as soon as the context manager exits.
    with Image.open(init_image) as source:
        rgb_image = source.convert("RGB").resize((256, 256))

    pipe.train(
        prompt,
        rgb_image,
        guidance_scale=7.5,
        num_inference_steps=50,
        generator=generator,
        text_embedding_optimization_steps=trn_text,
        model_fine_tuning_optimization_steps=trn_steps,
    )

    # Hand cached GPU memory back to the allocator once training is done.
    torch.cuda.empty_cache()

    return "Training is finished !", gr.update(value=0), gr.update(value=0)
def generate(prompt, init_image, trn_text, trn_steps):
    """Run the Imagic pipeline and return the edited image.

    Args:
        prompt: Target text passed to ``pipe.train``.
        init_image: Filesystem path of the uploaded image (the Gradio image
            input is declared with ``type="filepath"``), or ``None`` when
            nothing was uploaded.
        trn_text: Forwarded as ``text_embedding_optimization_steps``.
        trn_steps: Forwarded as ``model_fine_tuning_optimization_steps``.

    Returns:
        The first image of the pipeline result (``res.images[0]``).

    Raises:
        gr.Error: If no input image was provided.
    """
    if init_image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # from Image.open(None).
        raise gr.Error("Please upload an input image before generating.")

    # convert() returns an independent, fully loaded image, so the underlying
    # file can be closed as soon as the context manager exits.
    with Image.open(init_image) as source:
        rgb_image = source.convert("RGB").resize((256, 256))

    # NOTE(review): this repeats the training call made by train(). The Train
    # button resets both sliders to 0, so after a normal "1.Train" click this
    # runs with zero optimization steps; presumably it only re-prepares the
    # pipeline state for the current prompt -- confirm against imagic.py.
    # NOTE(review): guidance_scale / num_inference_steps are handed to
    # pipe.train() rather than to the pipe(...) call below -- verify intent.
    pipe.train(
        prompt,
        rgb_image,
        guidance_scale=7.5,
        num_inference_steps=50,
        generator=generator,
        text_embedding_optimization_steps=trn_text,
        model_fine_tuning_optimization_steps=trn_steps,
    )

    # Hand cached GPU memory back to the allocator before sampling.
    torch.cuda.empty_cache()

    res = pipe(alpha=1)
    return res.images[0]
# HTML header rendered at the top of the page via gr.HTML(title): demo name,
# a short description linking to the paper, a teaser image and a
# "Duplicate Space" badge.
title = """
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-top: 7px;">
Imagic Stable Diffusion • Community Pipeline
</h1>
</div>
<p style="margin-top: 10px; font-size: 94%">
Text-Based Real Image Editing with Diffusion Models
<br />This pipeline aims to implement <a href="https://arxiv.org/abs/2210.09276" target="_blank">this paper</a> to Stable Diffusion, allowing for real-world image editing.
</p>
<br /><img src="https://user-images.githubusercontent.com/788417/196388568-4ee45edd-e990-452c-899f-c25af32939be.png" style="margin:7px 0 20px;"/>
<p style="font-size: 94%">
You can skip the queue by duplicating this space or run the Colab version:
<span style="display: flex;align-items: center;justify-content: center;height: 30px;">
<a href="https://huggingface.co/spaces/fffiloni/imagic-stable-diffusion?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
</span>
</p>
</div>
"""
# HTML footer rendered at the bottom of the page via gr.HTML(article):
# credits for the community pipeline and the Gradio demo. Styled by the
# ".footer" rules in `css`.
article = """
<div class="footer">
<p><a href="https://github.com/huggingface/diffusers/tree/main/examples/community#imagic-stable-diffusion" target="_blank">Community pipeline</a>
baked by <a href="https://github.com/MarkRich" style="text-decoration: underline;" target="_blank">Mark Rich</a> -
Gradio Demo by 🤗 <a href="https://twitter.com/fffiloni" target="_blank">Sylvain Filoni</a>
</p>
</div>
"""
# Custom stylesheet passed to gr.Blocks(css=css): constrains the width of the
# "#col-container" column, styles links, and styles the ".footer" block
# (including its ".dark" theme variants).
css = '''
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.footer {
margin-bottom: 45px;
margin-top: 35px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
'''
# Assemble the demo page: a single column (styled by "#col-container" in
# `css`) holding the header, the inputs, the two action buttons, the outputs
# and the footer.
with gr.Blocks(css=css) as block:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)

        # --- inputs ---------------------------------------------------------
        prompt_input = gr.Textbox(
            label="Target text",
            placeholder="Describe the image with what you want to change about the subject",
        )
        image_init = gr.Image(
            source="upload",
            type="filepath",
            label="Input Image",
        )

        with gr.Row():
            trn_text = gr.Slider(
                minimum=0,
                maximum=500,
                step=50,
                value=250,
                label="text embedding",
            )
            trn_steps = gr.Slider(
                minimum=0,
                maximum=1000,
                step=50,
                value=500,
                label="finetuning steps",
            )

        # --- actions --------------------------------------------------------
        with gr.Row():
            train_btn = gr.Button("1.Train")
            gen_btn = gr.Button("2.Generate")

        # --- outputs --------------------------------------------------------
        training_status = gr.Textbox(label="training status")
        image_output = gr.Image(label="Edited image")

        # TODO: examples are disabled. The previous commented-out gr.Examples
        # call used ['a sitting dog', 'imagic-dog.png', 250] and
        # ['a photo of a bird spreading wings', 'imagic-bird.png', 250] with
        # a callback named `infer`, which is not defined in this file.

        gr.HTML(article)

    # Training writes a status message and resets both step sliders.
    train_btn.click(
        fn=train,
        inputs=[prompt_input, image_init, trn_text, trn_steps],
        outputs=[training_status, trn_text, trn_steps],
    )
    # Generation takes the same inputs and fills the output image.
    gen_btn.click(
        fn=generate,
        inputs=[prompt_input, image_init, trn_text, trn_steps],
        outputs=[image_output],
    )

# Enable request queuing (at most 12 waiting requests) and start the server
# without the API page.
block.queue(max_size=12).launch(show_api=False)