"""Gradio demo: extract the person from a selfie, generate a new background
with Stable Diffusion, and blend the two together."""

import logging
import os
import time

import cv2
from diffusers import StableDiffusionPipeline
import gradio as gr
# import mediapipe as mp
import numpy as np
import PIL.Image
import torch
from transformers import pipeline

os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    force=True)
LOG = logging.getLogger(__name__)

LOG.info("Loading image segmentation model")
seg_kwargs = {
    "task": "image-segmentation",
    "model": "nvidia/segformer-b0-finetuned-ade-512-512"
}
img_segmentation_model = pipeline(**seg_kwargs)
# mp_selfie_segmentation = mp.solutions.selfie_segmentation
# img_segmentation_model = mp_selfie_segmentation.SelfieSegmentation(model_selection=0)

LOG.info("Loading diffusion model")
diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
if torch.cuda.is_available():
    LOG.info("Moving diffusion model to GPU")
    diffusion.to('cuda')


def image_preprocess(image: PIL.Image.Image):
    """Convert the upload to RGB and resize it for the diffusion model."""
    LOG.info("Preprocessing image %s", image)
    start = time.time()
    # image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    image = resize_image(image)
    # image = np.array(image)
    # # Convert RGB to BGR
    # image = image[:, :, ::-1].copy()
    elapsed = time.time() - start
    LOG.info("Image preprocessed, %.2f seconds elapsed", elapsed)
    return image


def resize_image(image: PIL.Image.Image):
    """Scale the image so its longest side is 512 pixels, rounding both
    dimensions down to multiples of 8 as Stable Diffusion requires."""
    width, height = image.size
    ratio = max(width / 512, height / 512)
    width = int(width / ratio) // 8 * 8
    height = int(height / ratio) // 8 * 8
    return image.resize((width, height))


def extract_selfie_mask(threshold, image):
    """Return a soft single-channel float32 mask of the most confident
    'person' segment, or an all-zero mask if no person is found."""
    LOG.info("Extracting selfie mask")
    start = time.time()
    segments = img_segmentation_model(image)
    # Keep the highest-scoring 'person' segment, if any.
    kept = None
    for s in segments:
        if s['score'] is None:
            s['score'] = 1
        if s['label'] == 'person' and s['score'] > 0.99:
            if kept is None or kept['score'] < s['score']:
                kept = s
    if kept is None:
        LOG.info("No person found in the photo, skipping")
        # Single-channel zeros: cv2.blendLinear expects one-channel weights.
        mask = np.zeros((image.size[1], image.size[0]), dtype='float32')
    else:
        # The pipeline returns a PIL mask with values in [0, 255]; normalize
        # to [0, 1] so the threshold slider operates on the expected scale.
        mask = np.array(kept['mask'], dtype='float32') / 255.0
        cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY, dst=mask)
        # Dilate, then blur to feather the mask edges for a smoother blend.
        cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1, dst=mask)
        cv2.blur(mask, (10, 10), dst=mask)
    elapsed = time.time() - start
    LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
    return mask


def generate_background(prompt, num_inference_steps, height, width):
    """Generate a background with Stable Diffusion; returns a BGR uint8 array."""
    LOG.info("Generating background")
    start = time.time()
    background = diffusion(
        prompt=prompt,
        num_inference_steps=int(num_inference_steps),
        height=height,
        width=width
    )
    nsfw = background.nsfw_content_detected[0]
    background = background.images[0]
    if nsfw:
        LOG.info('NSFW detected, skipping')
        background = np.zeros((height, width, 3), dtype='uint8')
    else:
        background = np.array(background)
        # Convert RGB to BGR
        background = background[:, :, ::-1].copy()
    elapsed = time.time() - start
    LOG.info("Background generated, %.2f seconds elapsed", elapsed)
    return background


def merge_selfie_and_background(selfie, background, mask):
    """Alpha-blend the selfie over the generated background using the mask."""
    LOG.info("Merging extracted selfie and generated background")
    selfie = np.array(selfie)
    # Convert RGB to BGR
    selfie = selfie[:, :, ::-1].copy()
    cv2.blendLinear(selfie, background, mask, 1 - mask, dst=selfie)
    selfie = cv2.cvtColor(selfie, cv2.COLOR_BGR2RGB)
    return PIL.Image.fromarray(selfie)


def demo(threshold, image, prompt, num_inference_steps):
    LOG.info("Processing image")
    try:
        image = image_preprocess(image)
        mask = extract_selfie_mask(threshold, image)
        # PIL sizes are (width, height); the pipeline takes height first.
        background = generate_background(prompt, num_inference_steps, image.size[1], image.size[0])
        output = merge_selfie_and_background(image, background, mask)
    except Exception:
        LOG.exception("Unexpected error while processing image")
        raise
    return output


iface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Slider(minimum=0.1, maximum=1, step=0.05, label="Selfie segmentation threshold", value=0.8),
        gr.Image(type='pil', label="Upload your selfie"),
        gr.Text(value="a photo of the Eiffel tower on the right side", label="Background description"),
        gr.Slider(minimum=5, maximum=100, step=5, label="Diffusion inference steps", value=50)
    ],
    outputs=[
        gr.Image(label="Invent yourself a life :)")
    ])
# iface.launch(server_name="0.0.0.0", server_port=6443)
iface.launch()