import os

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
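# This Space chains two models: MediaPipe selfie segmentation produces a
# person mask, and Stable Diffusion inpainting repaints the masked (white)
# region according to a text prompt. Two Gradio interfaces run in parallel:
# one builds the mask automatically, the other uses a hand-drawn sketch mask.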
YOUR_TOKEN = os.environ.get('HF_TOKEN_SD')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "runwayml/stable-diffusion-inpainting"
if torch.cuda.is_available():
    # Load the fp16 weights on GPU to halve memory use.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_path, revision="fp16", torch_dtype=torch.float16, use_auth_token=YOUR_TOKEN
    ).to(device)
else:
    # Fall back to the full-precision weights on CPU.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, use_auth_token=YOUR_TOKEN).to(device)
def image_grid(imgs, cols, rows=1):
    """Paste a list of equally sized PIL images into a single rows x cols grid."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
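# Example: image_grid([img_a, img_b], cols=2) returns the two images pasted
# side by side as one 1x2 PIL image.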
def mediapipe_segmentation(image_file, mask_file):
    """Run MediaPipe selfie segmentation and write a black/white person mask to mask_file."""
    mp_selfie_segmentation = mp.solutions.selfie_segmentation
    BG_COLOR = (0, 0, 0)  # black
    MASK_COLOR = (255, 255, 255)  # white
    with mp_selfie_segmentation.SelfieSegmentation(model_selection=0) as selfie_segmentation:
        image = cv2.imread(image_file)
        # Convert the BGR image to RGB before processing.
        results = selfie_segmentation.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # To improve segmentation around boundaries, consider applying a joint
        # bilateral filter to "results.segmentation_mask" with "image".
        condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1
        # Generate solid-color images for the foreground/background of the mask.
        fg_image = np.zeros(image.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        bg_image = np.zeros(image.shape, dtype=np.uint8)
        bg_image[:] = BG_COLOR
        output_image = np.where(condition, fg_image, bg_image)
        cv2.imwrite(mask_file, output_image)
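# Standalone usage (hypothetical file names):
#   mediapipe_segmentation('portrait.jpg', 'portrait_mask.png')
# White pixels mark the detected person, i.e. the region Stable Diffusion
# will repaint; black pixels are preserved.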
def image_inpainting(prompt, image_path, mask_image_path, num_samples=4, is_origin=False):
    """Inpaint the white mask region of the image and return the samples as one grid."""
    image = Image.open(image_path).convert("RGB").resize((512, 512))
    mask_image = Image.open(mask_image_path).convert("RGB").resize((512, 512))
    num_samples = min(int(num_samples), 4)  # cap the batch at 4 samples
    if not is_origin:
        guidance_scale = 7.5
        generator = torch.Generator(device=device).manual_seed(0)  # change the seed to get different results
        images = pipe(prompt=prompt, image=image, mask_image=mask_image, guidance_scale=guidance_scale,
                      generator=generator, num_images_per_prompt=num_samples).images
    else:
        images = pipe(prompt=prompt, image=image, mask_image=mask_image, num_images_per_prompt=num_samples).images
    # insert initial image in the list so we can compare side by side
    # images.insert(0, image)
    return image_grid(images, num_samples, 1)
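# Example (assuming 'image.png' and 'm_mask.png' were produced as above):
#   grid = image_inpainting('a bench in a field', 'image.png', 'm_mask.png', num_samples=2)
#   grid.save('result.png')  # 1x2 grid of inpainted samples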
title = "Person Matting & Stable Diffusion In-Painting"
description = "Inpainting with Stable Diffusion<br/><b>mediapipe + Stable Diffusion</b><br/>"
def predict1(data, prompt, num_samples):
    # Auto mode: derive the mask from MediaPipe person segmentation.
    data['image'].save('image.png')
    mediapipe_segmentation('image.png', 'm_mask.png')
    return image_inpainting(prompt, num_samples=num_samples, image_path='image.png',
                            mask_image_path='m_mask.png', is_origin=False)
def predict2(data, prompt, num_samples):
    # Manual mode: use the mask sketched by the user in the image editor.
    data['image'].save('image.png')
    if 'mask' in data:
        data['mask'].save('mask.png')
    return image_inpainting(prompt, num_samples=num_samples, image_path='image.png',
                            mask_image_path='mask.png', is_origin=True)
image_input = gr.Image(source='upload', tool='sketch', type='pil')
prompt = gr.Textbox(label='prompt')
number = gr.Slider(1, 4, value=2, label='num_samples')
examples = [
[os.path.join(os.path.dirname(__file__), 'example1.png'), 'a bench in a field', 2],
# [os.path.join(os.path.dirname(__file__), 'example2.png'), 'a big ship parked on the shore', 2],
# [os.path.join(os.path.dirname(__file__), 'example3.png'), 'a palace with many steps', 2]
]
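# gr.Parallel renders both interfaces side by side and feeds them the same
# inputs, so the auto-masked and hand-sketched results can be compared directly.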
greeter_1 = gr.Interface(predict1, inputs=[image_input, prompt, number], outputs=gr.Image(label='auto'))
greeter_2 = gr.Interface(predict2, inputs=[image_input, prompt, number], outputs=gr.Image(label='paint'))
demo = gr.Parallel(greeter_1, greeter_2, title=title, description=description,
                   examples=examples, cache_examples=False)
if __name__ == "__main__":
    demo.launch(enable_queue=True)