from PIL import Image
import math
import os

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import PIL
import torch
from diffusers import StableDiffusionInpaintPipeline

# Use the token from the environment if set; `True` makes diffusers fall back
# to the locally cached `huggingface-cli login` credentials.
YOUR_TOKEN = os.environ.get("API_TOKEN") or True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "runwayml/stable-diffusion-inpainting"

if torch.cuda.is_available():
    # Load the fp16 weights to halve GPU memory use.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_path, revision="fp16", torch_dtype=torch.float16,
        use_auth_token=YOUR_TOKEN).to(device)
else:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_path, use_auth_token=YOUR_TOKEN).to(device)


def image_grid(imgs, rows, cols):
    """Paste `rows * cols` equally sized images into a single grid image."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def mediapipe_segmentation(image_file, mask_file):
    """Run MediaPipe selfie segmentation on `image_file` and write a
    black-and-white person mask to `mask_file` (white = person)."""
    mp_selfie_segmentation = mp.solutions.selfie_segmentation
    BG_COLOR = (0, 0, 0)          # black background
    MASK_COLOR = (255, 255, 255)  # white foreground
    with mp_selfie_segmentation.SelfieSegmentation(
            model_selection=0) as selfie_segmentation:
        image = cv2.imread(image_file)
        # Convert the BGR image to RGB before processing.
        results = selfie_segmentation.process(
            cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # To improve segmentation around boundaries, consider applying a joint
        # bilateral filter to "results.segmentation_mask" with "image".
        condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1
        # Generate solid-color foreground/background images and blend them
        # according to the mask condition.
        fg_image = np.zeros(image.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        bg_image = np.zeros(image.shape, dtype=np.uint8)
        bg_image[:] = BG_COLOR
        output_image = np.where(condition, fg_image, bg_image)
        cv2.imwrite(mask_file, output_image)


def image_inpainting(prompt, image_path, mask_image_path, num_samples=4,
                     is_origin=False):
    """Inpaint the white region of the mask and return a grid of samples."""
    image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
    mask_image = PIL.Image.open(mask_image_path).convert("RGB").resize((512, 512))
    if not is_origin:
        guidance_scale = 7.5
        # Change the seed to get different results.
        generator = torch.Generator(device=device).manual_seed(0)
        images = pipe(prompt=prompt, image=image, mask_image=mask_image,
                      guidance_scale=guidance_scale, generator=generator,
                      num_images_per_prompt=num_samples).images
    else:
        images = pipe(prompt=prompt, image=image, mask_image=mask_image,
                      num_images_per_prompt=num_samples).images
    return image_grid(images, 2, math.ceil(num_samples / 2))
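# Standalone usage sketch for the two helpers above, without Gradio. The file
# names ('person.png', 'person_mask.png', 'result.png') are illustrative
# assumptions, not files shipped with this demo; uncomment to try it:
#
#   mediapipe_segmentation('person.png', 'person_mask.png')
#   grid = image_inpainting("a bench in a field", image_path='person.png',
#                           mask_image_path='person_mask.png', num_samples=4)
#   grid.save('result.png')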
title = "Person Matting & Stable Diffusion In-Painting"
description = "MediaPipe person matting + Stable Diffusion inpainting"
# Example row ordered to match the [image, prompt] inputs below.
examples = [['example.png', "a bench in a field"]]


def predict1(dict, prompt):
    # Automatic mask: segment the person with MediaPipe, then inpaint.
    dict['image'].save('image.png')
    mediapipe_segmentation('image.png', 'm_mask.png')
    return image_inpainting(prompt, image_path='image.png',
                            mask_image_path='m_mask.png', is_origin=False)


def predict2(dict, prompt):
    # Manual mask: inpaint the region the user sketched onto the image.
    dict['image'].save('image.png')
    dict['mask'].save('mask.png')
    return image_inpainting(prompt, image_path='image.png',
                            mask_image_path='mask.png', is_origin=True)


image_input = gr.Image(source='upload', tool='sketch', type='pil')
prompt = gr.Textbox(label='prompt')

greeter_1 = gr.Interface(predict1, inputs=[image_input, prompt],
                         outputs=gr.Image(label='auto'))
greeter_2 = gr.Interface(predict2, inputs=[image_input, prompt],
                         outputs=gr.Image(label='paint'))
demo = gr.Parallel(greeter_1, greeter_2, title=title, description=description,
                   examples=examples)
demo.launch()