Spaces:
Runtime error
Runtime error
File size: 4,449 Bytes
ffab2fd 931a3ea ffab2fd b3556d1 ffab2fd b3556d1 931a3ea b3556d1 ffab2fd b3556d1 ffab2fd 73e49ee ffab2fd 73e49ee ffab2fd 73e49ee ffab2fd 73e49ee 3183ff3 ffab2fd 73e49ee ffab2fd 73e49ee ffab2fd 73e49ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
from PIL import Image
import math
import numpy as np
import torch
import PIL
import os
import cv2
import mediapipe as mp
import gradio as gr
from diffusers import StableDiffusionInpaintPipeline
# Auth token for the model download: the API_TOKEN environment variable when
# it is set and non-empty, otherwise the literal True.
YOUR_TOKEN = os.environ.get("API_TOKEN") or True

# Query CUDA availability once and derive both the device and the loading
# options from that single answer.
_cuda_ready = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ready else 'cpu')
model_path = "runwayml/stable-diffusion-inpainting"

# On CUDA the half-precision ("fp16") revision is requested together with a
# float16 dtype; on CPU the pipeline is loaded with library defaults.
_pipe_options = {"use_auth_token": YOUR_TOKEN}
if _cuda_ready:
    _pipe_options.update(revision="fp16", torch_dtype=torch.float16)

pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, **_pipe_options).to(device)
def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` grid image.

    Images are pasted left to right, top to bottom. The cell size is taken
    from the first image, so all images are expected to share that size.

    Args:
        imgs: Non-empty sequence of PIL images; its length must equal
            ``rows * cols``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB PIL image of size ``(cols * w, rows * h)`` where ``(w, h)``
        is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if len(imgs) == 0:
        raise ValueError("image_grid requires at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows * cols = {rows * cols} images, got {len(imgs)}"
        )

    cell_w, cell_h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, img in enumerate(imgs):
        # Column comes from the remainder, row from the integer quotient.
        grid.paste(img, box=(index % cols * cell_w, index // cols * cell_h))
    return grid
def mediapipe_segmentation(image_file, mask_file):
    """Write a two-colour segmentation mask of ``image_file`` to ``mask_file``.

    The image is run through MediaPipe selfie segmentation. Pixels whose
    segmentation score is above 0.1 are written white (255, 255, 255); all
    other pixels are written black (0, 0, 0).

    Args:
        image_file: Path of the input image, read with ``cv2.imread``.
        mask_file: Path the mask image is written to with ``cv2.imwrite``.

    Raises:
        ValueError: If OpenCV cannot read ``image_file``.
        OSError: If OpenCV reports that ``mask_file`` was not written.
    """
    bg_color = (0, 0, 0)  # black
    mask_color = (255, 255, 255)  # white

    image = cv2.imread(image_file)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise ValueError(f"could not read image file: {image_file!r}")

    mp_selfie_segmentation = mp.solutions.selfie_segmentation
    with mp_selfie_segmentation.SelfieSegmentation(model_selection=0) as selfie_segmentation:
        # Convert the BGR image to RGB before processing.
        results = selfie_segmentation.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # Replicate the single-channel score map across three channels so it
        # lines up with the colour images below, then threshold it.
        condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1

    # Solid colour images to pick from when rendering the mask.
    fg_image = np.zeros(image.shape, dtype=np.uint8)
    fg_image[:] = mask_color
    bg_image = np.zeros(image.shape, dtype=np.uint8)
    bg_image[:] = bg_color
    output_image = np.where(condition, fg_image, bg_image)

    # cv2.imwrite returns False when nothing was written; surface that here
    # instead of letting a later reader fail on a missing or stale mask.
    if not cv2.imwrite(mask_file, output_image):
        raise OSError(f"could not write mask file: {mask_file!r}")
def image_inpainting(prompt, image_path, mask_image_path, num_samples=4, is_origin=False):
    """Run the inpainting pipeline and return the results as one grid image.

    Both the image and the mask are loaded as RGB and resized to 512x512
    before being passed to the module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        image_path: Path of the image to inpaint.
        mask_image_path: Path of the mask image.
        num_samples: Number of images requested per prompt.
        is_origin: When False, the pipeline is called with
            ``guidance_scale=7.5`` and a generator seeded with 0; when True,
            neither is passed and the pipeline's own defaults apply.

    Returns:
        A PIL image containing all generated images tiled by ``image_grid``:
        two rows for an even number of images, a single row otherwise.
    """
    image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
    mask_image = PIL.Image.open(mask_image_path).convert("RGB").resize((512, 512))

    pipe_kwargs = {
        "prompt": prompt,
        "image": image,
        "mask_image": mask_image,
        "num_images_per_prompt": num_samples,
    }
    if not is_origin:
        pipe_kwargs["guidance_scale"] = 7.5
        # Change the seed to get different results.
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(0)
    images = pipe(**pipe_kwargs).images

    # image_grid needs exactly rows * cols images. Two rows only divide an
    # even count, so an odd count (including 1) is laid out on a single row
    # instead of failing the size check.
    count = len(images)
    rows = 2 if count % 2 == 0 else 1
    return image_grid(images, rows, count // rows)
# Display strings and example inputs for the demo.
# NOTE(review): none of these three names is referenced by the interface
# setup visible in this file - confirm whether they should be passed to it.
title = "Person Matting & Stable Diffusion In-Painting"
# The bold element is closed with a proper </b> tag (was the malformed <b/>).
description = "Inpainting Stable Diffusion <br/><b>mediapipe + Stable Diffusion</b><br/>"
# Each example row is [prompt, image path].
# NOTE(review): the interfaces in this file take inputs as (image, prompt);
# confirm the column order before wiring these examples in.
examples = [
    [
        "a bench in a field",
        'example.png',
    ],
]
def predict1(dict, prompt):
    """Inpaint using a mask generated automatically by MediaPipe.

    The uploaded image is saved to ``image.png``, a segmentation mask is
    written to ``m_mask.png``, and both are fed to ``image_inpainting`` with
    ``is_origin=False``.

    Args:
        dict: Mapping with an ``'image'`` entry that exposes ``save(path)``
            (presumably the payload of the sketch-enabled image input -
            confirm against the Gradio version in use). The parameter name
            shadows the builtin but is kept so the signature is unchanged.
        prompt: Text prompt forwarded to ``image_inpainting``.

    Returns:
        The grid image returned by ``image_inpainting``.
    """
    image_path = 'image.png'
    mask_path = 'm_mask.png'
    dict['image'].save(image_path)
    # mediapipe_segmentation requires the output path as its second argument;
    # it must be the same file that is handed to image_inpainting below.
    mediapipe_segmentation(image_path, mask_path)
    return image_inpainting(prompt, image_path=image_path, mask_image_path=mask_path, is_origin=False)
def predict2(dict, prompt):
    """Inpaint using the mask supplied alongside the image.

    The ``'image'`` and ``'mask'`` entries of ``dict`` are saved to
    ``image.png`` and ``mask.png`` respectively, then passed to
    ``image_inpainting`` with ``is_origin=True``.

    Args:
        dict: Mapping with ``'image'`` and ``'mask'`` entries, each exposing
            ``save(path)``.
        prompt: Text prompt forwarded to ``image_inpainting``.

    Returns:
        The grid image returned by ``image_inpainting``.
    """
    # Entry key -> file it is written to; saved in this order (image first).
    targets = {'image': 'image.png', 'mask': 'mask.png'}
    for key, path in targets.items():
        dict[key].save(path)
    result = image_inpainting(
        prompt,
        image_path=targets['image'],
        mask_image_path=targets['mask'],
        is_origin=True,
    )
    return result
# Input components; the same two instances are handed to both interfaces.
image_input = gr.Image(type='pil', tool='sketch', source='upload')
prompt = gr.Textbox(label='prompt')

# One interface per prediction path: predict1 builds its mask automatically,
# predict2 uses the mask that arrives with the image.
greeter_1 = gr.Interface(
    fn=predict1,
    inputs=[image_input, prompt],
    outputs=gr.Image(label='auto'),
)
greeter_2 = gr.Interface(
    fn=predict2,
    inputs=[image_input, prompt],
    outputs=gr.Image(label='paint'),
)

# Combine both interfaces and start the app; ``demo`` is bound to whatever
# ``launch`` returns, not to the combined interface object itself.
# NOTE(review): max_threads is given a bool here - confirm the intended value.
demo = gr.Parallel(greeter_1, greeter_2).launch(max_threads=True)
|