import os
import gradio as gr
import torch
# Report GPU availability up front so a missing/misconfigured CUDA setup is
# obvious in the startup log. Only query the device name when CUDA is actually
# available: torch.cuda.current_device() raises on a CUDA-less host.
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline
# NOTE(review): this is set after `import gradio`; confirm gradio reads
# GRADIO_THEME at runtime rather than at import time, otherwise move it above
# the gradio import.
os.environ['GRADIO_THEME'] = 'default'

# Select the compute device once. Half-precision weights are used on CUDA; on a
# CPU-only host fall back to float32, because float16 inference is not reliably
# supported on CPU and a hard-coded "cuda" device cannot start there at all.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Image-to-image pipeline (Stable Diffusion v1.5 weights).
model_id_img2img = "runwayml/stable-diffusion-v1-5"
img2img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_img2img, torch_dtype=torch_dtype)
img2img_pipe = img2img_pipe.to(device)

# Inpainting pipeline (dedicated inpainting checkpoint).
model_id_inpaint = "runwayml/stable-diffusion-inpainting"
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id_inpaint, torch_dtype=torch_dtype)
inpaint_pipe = inpaint_pipe.to(device)
def img2img_diff(prompt, pil_img, size=(768, 512), strength=0.75, guidance_scale=7.5):
    """Transform a draft image toward ``prompt`` with the img2img pipeline.

    Args:
        prompt: Text prompt steering the generation.
        pil_img: Input PIL image, as delivered by a ``gr.Image(type='pil')``
            component; ``None`` when the user has not uploaded anything.
        size: ``(width, height)`` the input is resized to before diffusion.
            Stable Diffusion requires both to be multiples of 8.
        strength: How far the pipeline may move away from the input image.
        guidance_scale: Classifier-free guidance weight for the prompt.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no input image was provided (shown to the user in the UI).
    """
    if pil_img is None:
        raise gr.Error("Please upload a draft image first.")
    # Normalise the mode so RGBA / greyscale uploads reach the pipeline with
    # three channels, then resize to the working resolution.
    img = pil_img.convert("RGB").resize(size)
    result = img2img_pipe(prompt=prompt, image=img, strength=strength, guidance_scale=guidance_scale)
    return result.images[0]
def imginpaint_diff(prompt, pil_img, mask_pil_img, size=(512, 512)):
    """Repaint the masked region of ``pil_img`` according to ``prompt``.

    Args:
        prompt: Text prompt describing what to paint into the masked area.
        pil_img: Origin PIL image; ``None`` when nothing was uploaded.
        mask_pil_img: Mask PIL image marking the region to repaint; ``None``
            when nothing was uploaded.
        size: ``(width, height)`` both image and mask are resized to. They must
            match each other and be multiples of 8; the default matches the
            inpainting checkpoint's 512x512 working resolution.

    Returns:
        The first image produced by the inpainting pipeline.

    Raises:
        gr.Error: If the image or the mask is missing (shown to the user).
    """
    if pil_img is None or mask_pil_img is None:
        raise gr.Error("Please upload both an origin image and a mask image.")
    # Uploads arrive at arbitrary, possibly different sizes; the pipeline needs
    # image and mask aligned at the same resolution.
    img = pil_img.convert("RGB").resize(size)
    mask = mask_pil_img.convert("L").resize(size)
    return inpaint_pipe(prompt=prompt, image=img, mask_image=mask).images[0]
def header_html(title):
    """Return a centred HTML heading for a UI section.

    The original template ignored ``title`` and rendered an empty string, so
    no section header ever appeared. ``title`` is HTML-escaped so arbitrary
    text cannot inject markup.

    Args:
        title: Plain-text heading to display.

    Returns:
        An HTML fragment suitable for ``gr.HTML``.
    """
    from html import escape

    return f'<div style="text-align: center;"><h2>{escape(title)}</h2></div>'
# Build the two-section UI. Each section gets its own component variables so
# the event wiring is unambiguous (the earlier version reused one set of names
# for both sections and misspelled `label`, leaving the prompt boxes unlabeled).
with gr.Blocks() as block:
    with gr.Group():
        # --- Section 1: image-to-image ---
        with gr.Box():
            gr.HTML(header_html("diffusion image to image transform"))
            with gr.Row():
                with gr.Column():
                    img2img_input = gr.Image(type='pil', label='draft image')
                    with gr.Row():
                        img2img_prompt = gr.Text(label="prompt text")
                        img2img_button = gr.Button("Generate image").style(
                            margin=False,
                            rounded=(False, True, True, False),
                            full_width=False,
                        )
                img2img_output = gr.Image(type="pil", label="generated image")
            img2img_button.click(
                img2img_diff,
                inputs=[img2img_prompt, img2img_input],
                outputs=[img2img_output],
            )

        # --- Section 2: inpainting ---
        with gr.Box():
            gr.HTML(header_html("diffusion image inpaint"))
            with gr.Row():
                with gr.Column():
                    inpaint_input = gr.Image(type='pil', label='origin image')
                    inpaint_mask = gr.Image(type='pil', label='mask image')
                    with gr.Row():
                        inpaint_prompt = gr.Text(label="prompt text")
                        inpaint_button = gr.Button("Generate image").style(
                            margin=False,
                            rounded=(False, True, True, False),
                            full_width=False,
                        )
                inpaint_output = gr.Image(type="pil", label="generated image")
            inpaint_button.click(
                imginpaint_diff,
                inputs=[inpaint_prompt, inpaint_input, inpaint_mask],
                outputs=[inpaint_output],
            )
# Run one inference at a time. Every request shares the same two pipeline
# objects, and a diffusers pipeline's scheduler holds per-call state, so
# concurrent calls can interfere with each other. Parallel Stable Diffusion
# runs on one GPU are also likely to exhaust memory. The queue (up to 20
# waiting requests) absorbs bursts instead.
block.queue(concurrency_count=1, max_size=20).launch(max_threads=150)