# inpainting / app.py
# Thomas Laurent
# Add app
# 3ccc838
# raw
# history blame contribute delete
# No virus
# 3.18 kB
# -*- coding: utf-8 -*-
"""in_painting_app
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1E_8Pn-aZGSq6Sf1DzYqSkcahkZu9AyDq
# In-painting pipeline for Stable Diffusion using 🧨 Diffusers
This notebook shows how to do text-guided in-painting with the Stable Diffusion model using the 🤗 Hugging Face [🧨 Diffusers library](https://github.com/huggingface/diffusers).
For a general introduction to the Stable Diffusion model please refer to this [colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb).
"""
# Install the runtime dependencies.
#
# The notebook this file was exported from used IPython shell magics
# (``!pip install ...``), which are a SyntaxError in a plain ``app.py``.
# Invoking pip through the running interpreter keeps the same install steps
# while leaving the module valid Python.  (On a Hugging Face Space, a
# requirements.txt is the more conventional place for these.)
import subprocess
import sys


def _pip_install(*args):
    """Run ``python -m pip install <args>`` with this interpreter; raise if pip fails."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", *args])


# gradio is capped below 4: the UI at the bottom of this file uses
# ``gr.Image(source=..., tool="sketch")``, keyword arguments Gradio 4 removed.
_pip_install("-qq", "-U", "diffusers==0.6.0", "transformers", "ftfy", "gradio<4")
# NOTE(review): this unpinned git install replaces the diffusers==0.6.0 pin
# above, so that pin has no lasting effect. Kept to preserve the original
# environment — TODO decide which diffusers release the app should target.
_pip_install("git+https://github.com/huggingface/diffusers.git")
# Presumably pulls in the optional torch extras for the diffusers already installed.
_pip_install("diffusers[torch]")
"""First , in order to use the model, you need to accept the model license before downloading or using the weights. In this post we'll use `runwayml/stable-diffusion-inpainting` model released by Runwayml so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-inpainting), read the license and tick the checkbox if you agree.
You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
"""
# from huggingface_hub import notebook_login
# notebook_login()
import inspect
from typing import List, Optional, Union
import numpy as np
import torch
import PIL
import gradio as gr
from diffusers import StableDiffusionInpaintPipeline
# Load the Stable Diffusion in-painting pipeline once at start-up so every
# request served by ``predict`` reuses the same weights.
#
# Fall back to the CPU when no CUDA device is available (e.g. a CPU-only
# host) instead of crashing. Half precision is only requested on the GPU,
# because many float16 kernels are not implemented on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "runwayml/stable-diffusion-inpainting"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_path,
    # Download the repository's "fp16" revision (half-precision weights).
    revision="fp16",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    # Authenticate with the locally stored Hugging Face token; the model card
    # requires accepting its license before the weights can be downloaded.
    use_auth_token=True,
    # The safety checker is deliberately disabled for this demo.
    safety_checker=None,
).to(device)
import requests
from io import BytesIO
def image_grid(imgs, rows, cols):
    """Paste ``rows * cols`` images into one RGB grid image, row by row.

    Image ``i`` is placed in row ``i // cols`` and column ``i % cols``. Every
    cell takes the size of ``imgs[0]``, so the images are expected to share
    one size.

    Args:
        imgs: Sequence of PIL images containing exactly ``rows * cols`` items.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        ValueError: if ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows * cols = {rows * cols} images, got {len(imgs)}"
        )
    if not imgs:
        raise ValueError("image_grid needs at least one image")

    # ``import PIL`` alone does not load the ``PIL.Image`` submodule.
    from PIL import Image

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server (passed to ``requests.get``).
            Without a timeout, a stalled server would block the call forever.

    Returns:
        The decoded ``PIL.Image.Image``, converted to RGB mode.

    Raises:
        requests.HTTPError: if the server answers with an error status,
            rather than trying to decode an error page as an image.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule.
    from PIL import Image

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# Sample picture and (judging by its file name) matching mask from the
# CompVis latent-diffusion in-painting examples. Not referenced by the Gradio
# demo in this file; handy for trying the pipeline by hand via download_image.
img_url = (
    "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/"
    "inpainting_examples/overture-creations-5sI6fQgYIuo.png"
)
mask_url = (
    "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/"
    "inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
)
"""### Gradio Demo"""
def predict(dict, prompt):
    """Gradio callback: in-paint the masked region of an uploaded image.

    Args:
        dict: Value produced by the sketch-enabled ``gr.Image`` input: a
            mapping holding the uploaded picture under ``'image'`` and the
            drawn mask under ``'mask'`` (PIL images, since the input uses
            ``type='pil'``), or ``None`` when nothing was uploaded. The name
            shadows the builtin ``dict``; it is kept for keyword callers.
        prompt: Text describing what to paint into the masked area.

    Returns:
        The first image returned by the in-painting pipeline.

    Raises:
        gr.Error: if no image was uploaded, so the UI shows a readable
            message instead of a ``TypeError`` traceback.
    """
    sketch = dict  # local alias so the body doesn't read like the builtin
    if sketch is None:
        raise gr.Error("Please upload an image to in-paint.")
    # Both inputs are forced to 512x512; the aspect ratio is not preserved.
    image = sketch["image"].convert("RGB").resize((512, 512))
    mask_image = sketch["mask"].convert("RGB").resize((512, 512))
    result = pipe(prompt=prompt, image=image, mask_image=mask_image)
    return result.images[0]
# Web UI: upload a picture, sketch over the area to replace, type a prompt;
# ``predict`` returns the in-painted result.
# The interface is bound to ``demo`` so ``gradio app.py`` (reload mode) can
# find it, and it is launched only when run as a script, not on import.
demo = gr.Interface(
    predict,
    title="Stable Diffusion In-Painting",
    inputs=[
        # type="pil" hands predict PIL images; tool="sketch" lets the user
        # draw the mask over the upload.
        gr.Image(source="upload", tool="sketch", type="pil"),
        gr.Textbox(label="prompt"),
    ],
    outputs=[
        gr.Image(),
    ],
)

if __name__ == "__main__":
    # debug=True keeps launch() blocking and surfaces errors in the console.
    demo.launch(debug=True)