# app.py — last updated by HopeKr in commit d0d8da6 ("Update app.py").
import inspect
import os
from io import BytesIO
from typing import List, Optional, Union

import gradio as gr
import numpy as np
import PIL
import requests
import torch
from diffusers import StableDiffusionInpaintPipeline
from huggingface_hub import login
from PIL import Image, ImageOps
from rembg import remove
# Authenticate this process with the Hugging Face Hub using the WRITE_TOKEN
# environment variable, and also register the token with the git credential
# helper. The model load below passes use_auth_token=True, which reuses the
# credentials stored here.
# NOTE(review): if WRITE_TOKEN is unset, `token` is None and login() falls back
# to its interactive prompt — confirm that is acceptable on the deployment host.
token = os.getenv("WRITE_TOKEN")
login(token=token, add_to_git_credential=True)
def image_grid(imgs, rows, cols):
    """Paste *imgs* row by row into one ``rows`` x ``cols`` RGB grid image.

    Every cell is the size of ``imgs[0]``. Each image is pasted at the
    top-left corner of its cell, filling the grid left-to-right and then
    top-to-bottom.

    Args:
        imgs: Sequence of PIL images, exactly ``rows * cols`` long.
        rows: Number of grid rows (must be positive).
        cols: Number of grid columns (must be positive).

    Returns:
        A new RGB PIL image of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is ``imgs[0].size``.

    Raises:
        ValueError: if ``rows``/``cols`` are not positive, or if the number of
            images does not equal ``rows * cols``.
    """
    # Explicit exceptions instead of `assert`: asserts vanish under `python -O`.
    if rows <= 0 or cols <= 0:
        raise ValueError(f"rows and cols must be positive, got {rows}x{cols}")
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, "
            f"got {len(imgs)}"
        )
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
def predict(dict, prompt):
    """Run the in-painting pipeline on one sketch-tool submission.

    Args:
        dict: Value produced by the ``gr.Image(tool="sketch", type="pil")``
            input, a mapping with an ``"image"`` entry (the uploaded picture)
            and a ``"mask"`` entry (the painted region) as PIL images. The name
            shadows the builtin but is kept for interface compatibility.
        prompt: Text describing what to paint into the masked area.

    Returns:
        The first PIL image returned by the module-level ``pipe``.

    Raises:
        ValueError: if no image was uploaded (the input is ``None``/empty or
            its ``"image"`` entry is ``None``).
    """
    # Submitting the form without uploading a picture yields None; fail with a
    # readable message instead of a NoneType subscript error.
    if not dict or dict.get("image") is None:
        raise ValueError("Please upload an image and paint the area to in-paint.")
    # Resize both to 512x512 so image and mask line up pixel for pixel
    # (aspect ratio is not preserved).
    image = dict["image"].convert("RGB").resize((512, 512))
    mask_image = dict["mask"].convert("RGB").resize((512, 512))
    images = pipe(prompt=prompt, image=image, mask_image=mask_image).images
    return images[0]
def download_image(url, timeout=30):
    """Fetch *url* over HTTP(S) and decode the body as an RGB PIL image.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server (passed to ``requests.get``).
            Without a timeout, requests can block indefinitely.

    Returns:
        The downloaded image converted to RGB mode.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures directly; otherwise an HTML error page would be
    # handed to PIL and fail later with a misleading "cannot identify image".
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# Hub repository of the Stable Diffusion in-painting checkpoint.
model_path = "runwayml/stable-diffusion-inpainting"
# use_auth_token=True reuses the Hub credentials stored by login() above.
# Half-precision loading (revision="fp16", torch_dtype=torch.float16) is not
# requested, so the weights load in the pipeline's default dtype.
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, use_auth_token=True)
# Start-up demo: repaint the background of a sample product photo, then
# assemble the source photo plus the generated samples into a single grid.
img_url = "https://cdn.faire.com/fastly/893b071985d70819da5f0d485f1b1bb97ee4f16a6e14ef1bdd4a086b3588be58.png"  # wino
image = download_image(img_url).resize((512, 512))
# rembg with only_mask=True returns the subject's mask. Inverting it selects
# everything except the subject, so the pipeline (which repaints the white
# area of mask_image) replaces the background.
inverted_mask_image = remove(data=image, only_mask=True)
mask_image = ImageOps.invert(inverted_mask_image)
prompt = "crazy portal universe"
guidance_scale = 7.5
num_samples = 3
generator = torch.Generator(device="cpu").manual_seed(0)  # change the seed to get different results
images = pipe(
    prompt=prompt,
    image=image,
    mask_image=mask_image,
    guidance_scale=guidance_scale,
    generator=generator,
    num_images_per_prompt=num_samples,
).images
# Put the original photo first so the grid reads: input, sample 1, sample 2, ...
images.insert(0, image)
# Keep the assembled grid in a name so the result is available for inspection.
demo_grid = image_grid(images, 1, num_samples + 1)
# Web UI: the user uploads a picture, paints the region to replace with the
# sketch tool, and enters a prompt; predict() returns the in-painted image.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(source="upload", tool="sketch", type="pil"),
        gr.Textbox(label="prompt"),
    ],
    outputs=[gr.Image()],
    title="Stable Diffusion In-Painting",
)
demo.launch(debug=True)