"""Gradio demo: Stable Diffusion inpainting driven either by an automatic
MediaPipe selfie-segmentation mask or by a mask the user sketches."""
from PIL import Image
import numpy as np
import torch
import PIL
import os
import tempfile
import cv2
import mediapipe as mp
import gradio as gr
from diffusers import StableDiffusionInpaintPipeline

# Hugging Face access token read from HF_TOKEN_SD (None when unset); passed to
# from_pretrained as use_auth_token.
YOUR_TOKEN = os.environ.get('HF_TOKEN_SD')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "runwayml/stable-diffusion-inpainting"

# On GPU load the "fp16" revision in half precision; on CPU load the default
# revision without an explicit dtype.
if torch.cuda.is_available():
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_path,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=YOUR_TOKEN,
    ).to(device)
else:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_path, use_auth_token=YOUR_TOKEN
    ).to(device)

# Upper bound on the number of images generated per request (also the
# maximum of the UI slider).
MAX_SAMPLES = 4


def image_grid(imgs, cols, rows=1):
    """Tile ``imgs`` into one RGB image laid out as ``rows`` x ``cols``.

    Images are pasted left-to-right, top-to-bottom without resizing; each grid
    cell takes the size of the first image, so all images should share it.

    Raises:
        ValueError: if ``imgs`` is empty or ``len(imgs) != rows * cols``.
    """
    if not imgs or len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
        )
    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid


def mediapipe_segmentation(image_file, mask_file, threshold=0.1):
    """Write a black/white person mask for ``image_file`` to ``mask_file``.

    Runs MediaPipe selfie segmentation (``model_selection=0``) and writes a
    3-channel uint8 image with the same shape as the input: pixels whose
    ``segmentation_mask`` value exceeds ``threshold`` are white, all others
    black.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_file``.
        OSError: if OpenCV fails to write ``mask_file``.
    """
    mp_selfie_segmentation = mp.solutions.selfie_segmentation
    BG_COLOR = (0, 0, 0)  # black
    MASK_COLOR = (255, 255, 255)  # white

    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {image_file}")

    with mp_selfie_segmentation.SelfieSegmentation(model_selection=0) as selfie_segmentation:
        # OpenCV loads BGR; convert to RGB before processing.
        results = selfie_segmentation.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # To improve segmentation around boundaries, consider applying a joint
    # bilateral filter to "results.segmentation_mask" with "image".
    condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > threshold

    fg_image = np.zeros(image.shape, dtype=np.uint8)
    fg_image[:] = MASK_COLOR
    bg_image = np.zeros(image.shape, dtype=np.uint8)
    bg_image[:] = BG_COLOR
    output_image = np.where(condition, fg_image, bg_image)

    if not cv2.imwrite(mask_file, output_image):
        raise OSError(f"could not write mask: {mask_file}")


def image_inpainting(prompt, image_path, mask_image_path, num_samples=4, is_origin=False):
    """Inpaint an image with the global ``pipe`` and return the results as a grid.

    The image and mask are converted to RGB and resized to 512x512 first.

    Args:
        prompt: text prompt for the pipeline.
        image_path: path of the source image.
        mask_image_path: path of the mask image (presumably white marks the
            region to repaint, per the diffusers inpainting convention - verify).
        num_samples: images to generate; coerced to int and clamped to
            [1, MAX_SAMPLES].
        is_origin: when False, pass guidance_scale=7.5 and a generator seeded
            with 0; when True, call the pipeline with its own defaults.

    Returns:
        One PIL image with the ``num_samples`` results tiled in a single row.
    """
    # Context managers close the underlying files; convert() yields
    # independent in-memory copies.
    with PIL.Image.open(image_path) as src:
        image = src.convert("RGB").resize((512, 512))
    with PIL.Image.open(mask_image_path) as src:
        mask_image = src.convert("RGB").resize((512, 512))

    # Lower bound avoids asking for zero images (image_grid needs at least one).
    num_samples = max(1, min(int(num_samples), MAX_SAMPLES))

    if is_origin:
        images = pipe(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            num_images_per_prompt=num_samples,
        ).images
    else:
        guidance_scale = 7.5
        # Change the seed to get different results.
        generator = torch.Generator(device=device).manual_seed(0)
        images = pipe(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            guidance_scale=guidance_scale,
            generator=generator,
            num_images_per_prompt=num_samples,
        ).images
    return image_grid(images, num_samples, 1)


title = "Person Matting & Stable Diffusion In-Painting"
description = (
    "Inpainting Stable Diffusion\n"
    "mediapipe + Stable Diffusion\n"
)
examples = [
    [os.path.join(os.path.dirname(__file__), 'example1.png'), 'a bench in a field', 2],
    # [os.path.join(os.path.dirname(__file__), "example2.png"), 'a building with many steps', 2],
    # [os.path.join(os.path.dirname(__file__), "example3.png"), 'a big ship parked on the shore', 2]
]


def predict1(dict, prompt, num_samples):
    """Inpaint using an automatically generated MediaPipe person mask.

    ``dict`` is the value of the sketch-tool ``gr.Image`` input; only its
    ``'image'`` entry is used. (The name shadows the builtin but is kept for
    interface compatibility.)
    """
    # Per-call temp directory: both gr.Parallel interfaces and concurrent users
    # would otherwise share (and overwrite) the same files in the CWD.
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, 'image.png')
        mask_path = os.path.join(tmp_dir, 'm_mask.png')
        dict['image'].save(image_path)
        mediapipe_segmentation(image_path, mask_path)
        return image_inpainting(prompt, num_samples=num_samples, image_path=image_path,
                                mask_image_path=mask_path, is_origin=False)


def predict2(dict, prompt, num_samples):
    """Inpaint using the mask from the sketch-tool input.

    ``dict`` is the value of the sketch-tool ``gr.Image`` input; its
    ``'image'`` and ``'mask'`` entries are used (presumably the mask is what
    the user painted - verify against the Gradio version in use).
    """
    # Per-call temp directory; see predict1's comment on shared files.
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, 'image.png')
        mask_path = os.path.join(tmp_dir, 'mask.png')
        dict['image'].save(image_path)
        dict['mask'].save(mask_path)
        return image_inpainting(prompt, num_samples=num_samples, image_path=image_path,
                                mask_image_path=mask_path, is_origin=True)


image_input = gr.Image(source='upload', tool='sketch', type='pil')
prompt = gr.Textbox(label='prompt')
number = gr.Slider(1, MAX_SAMPLES, value=2, step=1, label='num_samples')

greeter_1 = gr.Interface(predict1, inputs=[image_input, prompt, number], outputs=gr.Image(label='auto'))
greeter_2 = gr.Interface(predict2, inputs=[image_input, prompt, number], outputs=gr.Image(label='paint'))
demo = gr.Parallel(greeter_1, greeter_2, examples=examples, title=title, description=description)

if __name__ == "__main__":
    demo.launch(enable_queue=True)