RunningYou's picture
mask
dcfd465
raw
history blame contribute delete
No virus
5.03 kB
from PIL import Image
import numpy as np
import torch
import PIL
import os
import cv2
import mediapipe as mp
import gradio as gr
from diffusers import StableDiffusionInpaintPipeline
# Hugging Face access token read from the environment (None when unset); it is
# forwarded to from_pretrained when downloading the inpainting weights.
YOUR_TOKEN = os.environ.get('HF_TOKEN_SD')

# Query CUDA availability once and reuse it for both the device choice and the
# precision choice below.
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _use_cuda else 'cpu')
model_path = "runwayml/stable-diffusion-inpainting"

if not _use_cuda:
    # CPU: load the default (full-precision) weights.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_path,
        use_auth_token=YOUR_TOKEN,
    ).to(device)
else:
    # GPU: load the "fp16" revision as float16 tensors.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_path,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=YOUR_TOKEN,
    ).to(device)
def image_grid(imgs, cols, rows=1):
    """Paste ``rows * cols`` images into one RGB grid image.

    Images are placed row-major: image ``i`` lands in column ``i % cols`` and
    row ``i // cols``. Every cell has the size of the first image, so all
    images are expected to share that size.

    Args:
        imgs: Sequence of PIL images; exactly ``rows * cols`` are required.
        cols: Number of grid columns.
        rows: Number of grid rows (default 1).

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
        )
    if not imgs:
        raise ValueError("image_grid needs at least one image")
    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
def mediapipe_segmentation(image_file, mask_file, threshold=0.1):
    """Write a black/white person mask for ``image_file`` to ``mask_file``.

    Runs MediaPipe selfie segmentation (``model_selection=0``) on the image.
    Every pixel whose segmentation score exceeds ``threshold`` becomes white
    (255, 255, 255); all other pixels become black (0, 0, 0). The mask is a
    uint8 array with the same height and width as the image read by
    ``cv2.imread`` and is written with ``cv2.imwrite``.

    Args:
        image_file: Path of the image to segment.
        mask_file: Path the mask image is written to.
        threshold: Segmentation score above which a pixel counts as
            foreground (default 0.1).

    Raises:
        FileNotFoundError: If ``image_file`` cannot be read as an image.
        OSError: If the mask cannot be written to ``mask_file``.
    """
    mp_selfie_segmentation = mp.solutions.selfie_segmentation
    mask_color = np.array((255, 255, 255), dtype=np.uint8)  # white: foreground
    bg_color = np.array((0, 0, 0), dtype=np.uint8)  # black: background
    with mp_selfie_segmentation.SelfieSegmentation(model_selection=0) as selfie_segmentation:
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {image_file}")
        # OpenCV loads BGR; convert to RGB before handing the frame to MediaPipe.
        results = selfie_segmentation.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # Replicate the single-channel score map to 3 channels so the boolean
        # condition lines up with the 3-channel colour arrays.
        condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > threshold
        output_image = np.where(condition, mask_color, bg_color)
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(mask_file, output_image):
            raise OSError(f"could not write mask image: {mask_file}")
def image_inpainting(prompt, image_path, mask_image_path, num_samples=4, is_origin=False):
    """Inpaint an image with the module-level ``pipe`` and tile the results.

    Both the image and the mask are converted to RGB and resized to 512x512
    before being passed to the pipeline.

    Args:
        prompt: Text prompt for the inpainting pipeline.
        image_path: Path of the source image.
        mask_image_path: Path of the mask image.
        num_samples: Number of images to generate; converted to ``int`` and
            clamped to the range [1, 4].
        is_origin: When False, use ``guidance_scale=7.5`` and a generator
            seeded with 0 (same inputs give the same outputs). When True,
            call the pipeline without those arguments, so it uses its own
            defaults.

    Returns:
        A single PIL image with the generated samples side by side in one row.
    """
    with PIL.Image.open(image_path) as src:
        image = src.convert("RGB").resize((512, 512))
    with PIL.Image.open(mask_image_path) as src:
        mask_image = src.convert("RGB").resize((512, 512))
    # Cap at 4 as before, and also floor at 1: asking for zero images would make
    # image_grid fail. min() runs before int() so very large values just cap.
    num_samples = max(1, int(min(num_samples, 4)))
    pipe_kwargs = dict(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        num_images_per_prompt=num_samples,
    )
    if not is_origin:
        pipe_kwargs["guidance_scale"] = 7.5
        # Fixed seed for reproducible results; change it to get different ones.
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(0)
    images = pipe(**pipe_kwargs).images
    return image_grid(images, num_samples, 1)
# Title and HTML description text for the demo.
# NOTE(review): neither is passed to gr.Interface / gr.Parallel in this file,
# so they are not currently displayed — confirm whether that is intended.
title = "Person Matting & Stable Diffusion In-Painting"
description = "Inpainting Stable Diffusion <br/><b>mediapipe + Stable Diffusion</b><br/>"
def predict1(dict, prompt, num_samples):
    """Automatic mode: build the mask with MediaPipe, then inpaint.

    The uploaded image (``dict['image']``) is saved to ``image.png``.
    ``mediapipe_segmentation`` writes a mask for it to ``m_mask.png``, and both
    files are passed to ``image_inpainting`` with ``is_origin=False``. Any mask
    the user sketched in the UI is ignored here.

    Returns:
        The grid image produced by ``image_inpainting``.
    """
    source_path = 'image.png'
    auto_mask_path = 'm_mask.png'
    dict['image'].save(source_path)
    mediapipe_segmentation(source_path, auto_mask_path)
    return image_inpainting(
        prompt,
        source_path,
        auto_mask_path,
        num_samples=num_samples,
        is_origin=False,
    )
def predict2(dict, prompt, num_samples):
    """Manual mode: inpaint the region the user sketched in the UI.

    ``dict['image']`` is saved to ``image.png`` and ``dict['mask']`` to
    ``mask.png``. Both files are then passed to ``image_inpainting`` with
    ``is_origin=True``.

    Returns:
        The grid image produced by ``image_inpainting``.

    Raises:
        ValueError: If no mask is present in ``dict``. Previously the code
            went ahead anyway and could reuse a ``mask.png`` left over from an
            earlier request.
    """
    image_path = 'image.png'
    mask_path = 'mask.png'
    dict['image'].save(image_path)
    mask = dict.get('mask')
    if mask is None:
        raise ValueError("No mask was provided; draw over the region to inpaint.")
    mask.save(mask_path)
    return image_inpainting(
        prompt,
        image_path,
        mask_path,
        num_samples=num_samples,
        is_origin=True,
    )
# ---- Gradio UI ---------------------------------------------------------------
# Shared inputs: an uploaded image with a sketch tool (for drawing a mask),
# the text prompt, and a 1-4 slider for the number of samples.
image_input = gr.Image(source='upload', tool='sketch', type='pil')
prompt = gr.Textbox(label='prompt')
number = gr.Slider(1, 4, value=2, label='num_samples')

_example_dir = os.path.dirname(__file__)
examples = [
    [os.path.join(_example_dir, 'example1.png'), 'a bench in a field', 2],
    # [os.path.join(_example_dir, 'example2.png'), 'a big ship parked on the shore', 2],
    # [os.path.join(_example_dir, 'example3.png'), 'a palace with many steps', 2]
]

# Two interfaces over the same inputs: 'auto' uses the MediaPipe mask,
# 'paint' uses the user-drawn mask. gr.Parallel shows both side by side.
greeter_1 = gr.Interface(
    predict1,
    inputs=[image_input, prompt, number],
    outputs=gr.Image(label='auto'),
)
greeter_2 = gr.Interface(
    predict2,
    inputs=[image_input, prompt, number],
    outputs=gr.Image(label='paint'),
)
demo = gr.Parallel(greeter_1, greeter_2, examples=examples, cache_examples=False)

if __name__ == "__main__":
    demo.launch(enable_queue=True)