from PIL import Image
import numpy as np
import torch
import PIL
import os
import cv2
import mediapipe as mp
import gradio as gr
from diffusers import StableDiffusionInpaintPipeline
# Hugging Face auth token used to download the model weights.
YOUR_TOKEN = os.environ.get('HF_TOKEN_SD')

# Prefer the GPU when one is present; otherwise run on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "runwayml/stable-diffusion-inpainting"

if torch.cuda.is_available():
    # GPU path: half-precision weights to reduce memory use.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_path,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=YOUR_TOKEN,
    ).to(device)
else:
    # CPU path: default full-precision weights.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_path,
        use_auth_token=YOUR_TOKEN,
    ).to(device)
def image_grid(imgs, cols, rows=1):
    """Paste ``rows * cols`` equally-sized PIL images into one grid image.

    Args:
        imgs: sequence of PIL images, all the same size.
        cols: number of grid columns.
        rows: number of grid rows (default 1).

    Returns:
        A new RGB PIL image of size (cols * w, rows * h).

    Raises:
        AssertionError: if ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * w, rows * h))
    # Removed unused locals grid_w/grid_h from the original.
    for i, img in enumerate(imgs):
        # Fill the grid left-to-right, top-to-bottom.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
def mediapipe_segmentation(image_file, mask_file):
    """Write a black/white person-segmentation mask for an image to disk.

    Args:
        image_file: path of the input image (read with OpenCV, BGR).
        mask_file: path where the mask image is written
            (white = person foreground, black = background).
    """
    mp_selfie_segmentation = mp.solutions.selfie_segmentation
    BG_COLOR = (0, 0, 0)          # black background
    MASK_COLOR = (255, 255, 255)  # white foreground (person)
    with mp_selfie_segmentation.SelfieSegmentation(model_selection=0) as selfie_segmentation:
        image = cv2.imread(image_file)
        # MediaPipe expects RGB input; OpenCV loads BGR.
        results = selfie_segmentation.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # Broadcast the single-channel confidence mask to 3 channels and
        # threshold it to get a boolean person/background condition.
        condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1
        # Solid color images for the output segmentation mask.
        fg_image = np.zeros(image.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        bg_image = np.zeros(image.shape, dtype=np.uint8)
        bg_image[:] = BG_COLOR
        output_image = np.where(condition, fg_image, bg_image)
        cv2.imwrite(mask_file, output_image)
def image_inpainting(prompt, image_path, mask_image_path, num_samples=4, is_origin=False):
    """Run Stable Diffusion inpainting and return the samples as one grid image.

    Args:
        prompt: text prompt guiding the inpainting.
        image_path: path of the source image.
        mask_image_path: path of the mask image (white = region to repaint).
        num_samples: number of images to generate, clamped to 1..4.
        is_origin: when False, use a fixed seed and guidance scale 7.5 for
            reproducible output; when True, use the pipeline defaults.

    Returns:
        A PIL image with the generated samples laid out side by side.
    """
    image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
    mask_image = PIL.Image.open(mask_image_path).convert("RGB").resize((512, 512))
    # Clamp to 1..4: the original only capped values above 4, letting 0 or
    # negative counts through, which would produce an empty grid.
    num_samples = max(1, min(int(num_samples), 4))
    if not is_origin:
        guidance_scale = 7.5
        # Fixed seed for reproducible results; change it to vary the output.
        generator = torch.Generator(device=device).manual_seed(0)
        images = pipe(prompt=prompt, image=image, mask_image=mask_image,
                      guidance_scale=guidance_scale, generator=generator,
                      num_images_per_prompt=num_samples).images
    else:
        images = pipe(prompt=prompt, image=image, mask_image=mask_image,
                      num_images_per_prompt=num_samples).images
    return image_grid(images, num_samples, 1)
title = "Person Matting & Stable Diffusion In-Painting"
# The original string literal ran across several lines without closing
# quotes (a syntax error); collapsed into one valid string.
description = "Inpainting Stable Diffusion - mediapipe + Stable Diffusion"
def predict1(dict, prompt, num_samples):
    """Auto-mask flow: segment the person with MediaPipe, then inpaint."""
    dict['image'].save('image.png')
    # Derive the mask automatically rather than using any user sketch.
    mediapipe_segmentation('image.png', 'm_mask.png')
    return image_inpainting(prompt,
                            image_path='image.png',
                            mask_image_path='m_mask.png',
                            num_samples=num_samples,
                            is_origin=False)
def predict2(dict, prompt, num_samples):
    """Manual-mask flow: inpaint using the mask the user sketched.

    Raises:
        ValueError: if the input dict contains no 'mask' entry.
    """
    dict['image'].save('image.png')
    if 'mask' in dict:
        dict['mask'].save('mask.png')
    else:
        # Fail loudly: previously a missing mask silently reused a stale
        # mask.png left over from an earlier request (or crashed with
        # FileNotFoundError on first use).
        raise ValueError("No mask provided: sketch a mask on the image first.")
    image = image_inpainting(prompt, num_samples=num_samples,
                             image_path='image.png', mask_image_path='mask.png',
                             is_origin=True)
    return image
# Shared Gradio inputs: an image upload with a sketch tool so the user can
# draw a mask, a text prompt, and a sample-count slider (1..4).
image_input = gr.Image(source='upload', tool='sketch', type='pil')
prompt = gr.Textbox(label='prompt')
number = gr.Slider(1, 4, value=2, label='num_samples')
# Example payload shaped like the dict the sketch tool produces
# ({'image': ..., 'mask': ...}); example files live next to this script.
init = {'image': PIL.Image.open(os.path.join(os.path.dirname(__file__), 'example1.png')),
'mask': PIL.Image.open(os.path.join(os.path.dirname(__file__), 'example1_mask.png'))}
examples = [[init, 'a bench in a field', 2]]
# Two interfaces run in parallel on the same inputs: predict1 builds the mask
# automatically with MediaPipe, predict2 uses the hand-drawn mask.
greeter_1 = gr.Interface(predict1, inputs=[image_input, prompt, number], outputs=gr.Image(label='auto'))
greeter_2 = gr.Interface(predict2, inputs=[image_input, prompt, number], outputs=gr.Image(label='paint'))
demo = gr.Parallel(greeter_1, greeter_2, examples=examples)
if __name__ == "__main__":
    # enable_queue processes requests through Gradio's queue
    # (NOTE(review): deprecated in newer Gradio versions — confirm version).
    demo.launch(enable_queue=True)