import os
import random

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from load_model import sample
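
# Prefer CUDA if available, then Apple Silicon (MPS), otherwise fall back to CPU.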
device = "cuda" if torch.cuda.is_available() else "cpu" | |
device = "mps" if torch.backends.mps.is_available() else device | |
image_size = 128 | |
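

# Pre-fill the input panels with the bundled example sketch and a randomly
# chosen scribble from examples/scribbles/.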
def show_example_fn():
    sketch = Image.open("examples/sketch.png")
    scribble_folder = "./examples/scribbles/"
    png_files = [f for f in os.listdir(scribble_folder) if f.lower().endswith(".png")]
    # get random scribble
    random_scribble = Image.open(
        os.path.join(scribble_folder, random.choice(png_files))
    )
    return [sketch, random_scribble]
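

# Resize to the model's native resolution and rescale pixel values from
# [0, 1] to [-1, 1], the input range the diffusion model expects.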
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1),
    ]
)
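

# Preprocess both inputs, run diffusion sampling, and optionally stretch the
# result back to the sketch's original size. `sample` (from load_model) is
# assumed here to return an image that torchvision's Resize can handle.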
def process_images(
    sketch,
    scribbles,
    sampling_steps,
    seed_nr,
    upscale,
    progress=gr.Progress(),
):
    w, h = sketch.size
    sketch = transform(sketch.convert("RGB"))
    scribbles = transform(scribbles.convert("RGB"))
    result = sample(sketch, scribbles, sampling_steps, seed_nr, progress)
    if upscale:
        return transforms.Resize((h, w))(result)
    return result
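

# UI: three image panels (sketch, scribbles, output), a seed box, a stretch
# checkbox, a sampling-steps slider, and buttons to load an example or paint.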
theme = gr.themes.Monochrome()

with gr.Blocks(theme=theme) as demo:
    with gr.Row():
        gr.Markdown(
            "<h1 style='text-align: center; font-size: 30px;'>Image Inpainting with Conditional Diffusion by MedicAI</h1>"
        )
    with gr.Row():
        with gr.Column():
            sketch_input = gr.Image(type="pil", label="Sketch", height=500)
        with gr.Column():
            scribbles_input = gr.Image(type="pil", label="Scribbles", height=500)
        with gr.Column():
            output = gr.Image(type="pil", label="Output")
    with gr.Row():
        with gr.Column():
            seed_slider = gr.Number(
                label="Random Seed 🎲 (if the generated image is not to your liking, simply try another seed)",
                value=5,
            )
            upscale_button = gr.Checkbox(
                label=f"Stretch (check this box to stretch the downloadable output to the input size; the model's native output is {image_size}x{image_size})",
                value=False,
            )
        with gr.Column():
            sampling_slider = gr.Slider(
                minimum=1,
                maximum=250,
                step=1,
                label="DDPM Sampling Steps (more steps generally yield higher-quality images)",
                value=50,
            )
            show_example = gr.Button(value="Show Example Input")
    with gr.Row():
        generate_button = gr.Button(value="Paint 🎨")
    with gr.Row():
        generate_info = gr.Markdown(
            "<p style='text-align: center; font-size: 16px;'>"
            "Notes: Depending on where you run this demo, generating the output may take a while; on the HF Space it can take up to 20 minutes for 100 sampling steps, so we recommend lowering the sampling steps to 10 there. The model was trained on this <a href='https://huggingface.co/datasets/pawlo2013/anime_diffusion_full'>dataset</a>."
            "</p>"
        )
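
    # Wire up the buttons: concurrency_limit=1 serializes each event, and
    # trigger_mode="once" ignores new clicks while a run is still pending.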
    show_example.click(
        show_example_fn,
        inputs=[],
        outputs=[sketch_input, scribbles_input],
        concurrency_limit=1,
        trigger_mode="once",
    )

    generate_button.click(
        process_images,
        inputs=[
            sketch_input,
            scribbles_input,
            sampling_slider,
            seed_slider,
            upscale_button,
        ],
        outputs=output,
        concurrency_limit=1,
        trigger_mode="once",
    )
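
# Launch the Gradio app when this script is run directly.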
if __name__ == "__main__":
    demo.launch()