fffiloni's picture
Update app.py
bdca464
raw
history blame
2.33 kB
import gradio as gr
from PIL import Image
from io import BytesIO
import torch
import os
from diffusers import DiffusionPipeline, DDIMScheduler
# Hugging Face access token, read from the HF_TOKEN_SD environment variable
# (None when the variable is not set).
MY_SECRET_TOKEN = os.environ.get('HF_TOKEN_SD')

# Use the GPU when torch reports one, otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')

# DDIM scheduler built explicitly and handed to the pipeline below.
_ddim_scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# Stable Diffusion v1-4 loaded through the "imagic_stable_diffusion" custom
# pipeline, with the safety checker disabled, then moved to `device`.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
    use_auth_token=MY_SECRET_TOKEN,
    custom_pipeline="imagic_stable_diffusion",
    scheduler=_ddim_scheduler,
)
pipe = pipe.to(device)
#generator = torch.Generator("cuda").manual_seed(0)
def infer(prompt, init_image):
    """Edit ``init_image`` so that it matches ``prompt`` with the Imagic pipeline.

    Args:
        prompt: Text describing the desired edit.
        init_image: Path of the uploaded image (the UI image inputs use
            ``type="filepath"``), or an already loaded ``PIL.Image.Image``.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If no image was provided.
    """
    if init_image is None:
        raise gr.Error("Please upload an image to edit.")

    # The Gradio inputs hand over a file path, while the pipeline's train()
    # works on image data, so load the file first. convert() reads the pixel
    # data, which makes it safe to close the file right afterwards.
    if isinstance(init_image, Image.Image):
        init_image = init_image.convert("RGB")
    else:
        with Image.open(init_image) as uploaded:
            init_image = uploaded.convert("RGB")

    # Same preparation as the Imagic community-pipeline example: a 512x512 RGB
    # image, which is also the pipeline's default generation size.
    init_image = init_image.resize((512, 512))

    # NOTE(review): train() appears to fine-tune the shared `pipe` in place, so
    # consecutive requests presumably build on each other's weights - confirm
    # whether the pipeline should be reset or reloaded between requests.
    pipe.train(prompt, init_image)

    # guidance_scale / num_inference_steps are sampling options, so they belong
    # on the generating call rather than on train().
    res = pipe(alpha=1, guidance_scale=7.5, num_inference_steps=50)
    return res.images[0]
# HTML header rendered at the top of the page.
title = """
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px;">
Imagic Stable Diffusion • Community Pipeline
</h1>
</div>
<p style="margin-bottom: 10px; font-size: 94%">
Text-Based Real Image Editing with Diffusion Models
</p>
</div>
"""

# Placeholder for a footer/article section; currently just a newline.
article = "\n"

# Page styling: centres the main column and underlines links.
css = (
    "\n"
    "#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}\n"
    "a {text-decoration-line: underline; font-weight: 600;}\n"
)
# Single Blocks UI for the demo.
#
# An earlier gr.Interface was built and launched before this layout. Its
# launch() call blocks the script, so the Blocks app below (the only one that
# applies `css` and enables the request queue) was never served. Only the
# Blocks app is created and launched now.
with gr.Blocks(css=css) as block:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        prompt_input = gr.Textbox()
        # type="filepath": infer() receives the path of the uploaded file.
        image_init = gr.Image(source="upload", type="filepath")
        submit_btn = gr.Button("Submit")
        image_output = gr.Image()
        #gr.HTML(article)
    submit_btn.click(fn=infer, inputs=[prompt_input, image_init], outputs=[image_output])

# Keep the module-level name `demo` bound for anything that looks it up.
demo = block

# infer() trains and samples on one shared pipeline object, so requests are
# processed one at a time; up to 32 further requests wait in the queue.
block.queue(max_size=32, concurrency_count=1).launch(show_api=False)