# Hugging Face Space (page header captured with the source; the Space was
# showing "Runtime error" at capture time).
# --- Module setup: imports, auth token, device selection, inpainting pipeline ---
from PIL import Image
import numpy as np
import torch
import PIL
import os
import cv2
import mediapipe as mp
import gradio as gr
from diffusers import StableDiffusionInpaintPipeline

# Hugging Face token (Space secret) used to download the model weights.
YOUR_TOKEN = os.environ.get('HF_TOKEN_SD')
# Prefer GPU when available; the half-precision weights are only used on CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "runwayml/stable-diffusion-inpainting"
if torch.cuda.is_available():
    # fp16 weights halve GPU memory use.
    # NOTE(review): `revision="fp16"` and `use_auth_token=` are deprecated in
    # recent diffusers releases (replaced by `variant="fp16"` / `token=`) —
    # confirm the pinned diffusers version before migrating.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, revision="fp16", torch_dtype=torch.float16,
                                                          use_auth_token=YOUR_TOKEN).to(device)
else:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, use_auth_token=YOUR_TOKEN).to(device)
def image_grid(imgs, cols, rows=1):
    """Paste ``imgs`` into a single rows x cols contact-sheet image.

    Every image must have the size of the first one; exactly ``rows * cols``
    images are required.

    Args:
        imgs: sequence of PIL images, all the same size.
        cols: number of grid columns.
        rows: number of grid rows (default 1).

    Returns:
        A new RGB PIL image of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: if ``len(imgs) != rows * cols``.
    """
    # Raise instead of assert: asserts are stripped under `python -O`.
    if len(imgs) != rows * cols:
        raise ValueError(f"expected {rows * cols} images, got {len(imgs)}")
    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * w, rows * h))
    # Fill left-to-right, top-to-bottom.
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
def mediapipe_segmentation(image_file, mask_file):
    """Write a black/white person-segmentation mask of ``image_file`` to ``mask_file``.

    Person (foreground) pixels are written white (255, 255, 255), everything
    else black (0, 0, 0).
    """
    fg_color = (255, 255, 255)  # white: person
    bg_color = (0, 0, 0)        # black: background
    segmenter_cls = mp.solutions.selfie_segmentation.SelfieSegmentation
    with segmenter_cls(model_selection=0) as segmenter:
        bgr = cv2.imread(image_file)
        # OpenCV loads BGR; MediaPipe expects RGB.
        result = segmenter.process(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        # Threshold the soft mask at 0.1 and broadcast it to three channels.
        is_person = np.stack((result.segmentation_mask,) * 3, axis=-1) > 0.1
        white = np.full(bgr.shape, fg_color, dtype=np.uint8)
        black = np.full(bgr.shape, bg_color, dtype=np.uint8)
        cv2.imwrite(mask_file, np.where(is_person, white, black))
def image_inpainting(prompt, image_path, mask_image_path, num_samples=4, is_origin=False):
    """Inpaint the masked region of an image with Stable Diffusion.

    Args:
        prompt: text prompt guiding the inpainted content.
        image_path: path to the source image (resized to 512x512).
        mask_image_path: path to the mask image; white areas are repainted.
        num_samples: number of images to generate, clamped to 1..4.
        is_origin: if False, use guidance_scale=7.5 with a fixed seed for
            reproducible output; if True, run with the pipeline defaults.

    Returns:
        A single PIL image with the generated samples pasted side by side.
    """
    image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
    mask_image = PIL.Image.open(mask_image_path).convert("RGB").resize((512, 512))
    # Clamp to [1, 4]: the original only capped the upper bound, so 0 or a
    # negative value reached image_grid and failed its size check.
    num_samples = max(1, min(int(num_samples), 4))
    if not is_origin:
        guidance_scale = 7.5
        # Fixed seed so repeated runs with the same inputs match;
        # change the seed to get different results.
        generator = torch.Generator(device=device).manual_seed(0)
        images = pipe(prompt=prompt, image=image, mask_image=mask_image, guidance_scale=guidance_scale,
                      generator=generator, num_images_per_prompt=num_samples).images
    else:
        images = pipe(prompt=prompt, image=image, mask_image=mask_image, num_images_per_prompt=num_samples).images
    return image_grid(images, num_samples, 1)
# UI copy shown by the Gradio app.
title = "Person Matting & Stable Diffusion In-Painting"
# Fixed malformed closing bold tag: "<b/>" -> "</b>".
description = "Inpainting Stable Diffusion <br/><b>mediapipe + Stable Diffusion</b><br/>"
def predict1(dict, prompt, num_samples):
    """Auto mode: segment the person with MediaPipe, then inpaint around it."""
    # Persist the upload so the OpenCV/PIL file-based helpers can read it.
    dict['image'].save('image.png')
    # dict['mask'].save('mask.png')
    # Derive the mask automatically instead of using the hand-drawn one.
    mediapipe_segmentation('image.png', 'm_mask.png')
    result = image_inpainting(prompt,
                              num_samples=num_samples,
                              image_path='image.png',
                              mask_image_path='m_mask.png',
                              is_origin=False)
    return result
def predict2(dict, prompt, num_samples):
    """Manual mode: inpaint using the mask the user painted in the sketch tool.

    Raises:
        ValueError: if no mask was drawn. Previously a missing/None mask fell
            through to whatever 'mask.png' was left on disk from an earlier
            request (or crashed later with an unrelated error).
    """
    dict['image'].save('image.png')
    mask = dict.get('mask')
    if mask is None:
        # Fail loudly and early rather than silently reusing a stale mask.
        raise ValueError("Please draw a mask on the image before submitting.")
    mask.save('mask.png')
    return image_inpainting(prompt, num_samples=num_samples, image_path='image.png',
                            mask_image_path='mask.png', is_origin=True)
# --- Gradio UI wiring ---
# NOTE(review): the `source=`/`tool=` kwargs on gr.Image and the
# `enable_queue=` kwarg on launch() were removed in Gradio 4.x. If this
# Space runs a recent Gradio, these lines are the likely cause of the
# startup "Runtime error" — confirm the pinned gradio version.
image_input = gr.Image(source='upload', tool='sketch', type='pil')
prompt = gr.Textbox(label='prompt')
number = gr.Slider(1, 4, value=2, label='num_samples')
#
# init = {'image': PIL.Image.open(os.path.join(os.path.dirname(__file__), 'example1.png')),
#         'mask': PIL.Image.open(os.path.join(os.path.dirname(__file__), 'example1_mask.png'))}
# examples = [[init, 'a bench in a field', 2]]
# Run both modes on the same inputs and show the results side by side:
# 'auto' uses the MediaPipe mask, 'paint' uses the hand-drawn sketch mask.
greeter_1 = gr.Interface(predict1, inputs=[image_input, prompt, number], outputs=gr.Image(label='auto'))
greeter_2 = gr.Interface(predict2, inputs=[image_input, prompt, number], outputs=gr.Image(label='paint'))
demo = gr.Parallel(greeter_1, greeter_2)
if __name__ == "__main__":
    demo.launch(enable_queue=True)