import logging
import os
import time

import cv2
from diffusers import StableDiffusionPipeline
import gradio as gr
import mediapipe as mp
import numpy as np
import PIL.Image
import torch.cuda

# from transformers import pipeline

os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    force=True)
LOG = logging.getLogger(__name__)

LOG.info("Loading image segmentation model")
# seg_kwargs = {
#     "task": "image-segmentation",
#     "model": "nvidia/segformer-b0-finetuned-ade-512-512"
# }
#
# img_segmentation = pipeline(**seg_kwargs)
mp_selfie_segmentation = mp.solutions.selfie_segmentation
img_segmentation_model = mp_selfie_segmentation.SelfieSegmentation(model_selection=0)

LOG.info("Loading diffusion model")
diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
if torch.cuda.is_available():
    LOG.info("Moving diffusion model to GPU")
    diffusion.to('cuda')


def image_preprocess(image: PIL.Image.Image):
    LOG.info("Preprocessing image %s", image)
    start = time.time()
    # image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    image = resize_image(image)
    image = np.array(image)
    # Convert RGB to BGR
    image = image[:, :, ::-1].copy()
    elapsed = time.time() - start
    LOG.info("Image preprocessed, %.2f seconds elapsed", elapsed)
    return image


def resize_image(image: PIL.Image.Image):
    # Scale so the longest side is 512 px and both dimensions are
    # multiples of 8, as expected by the diffusion model.
    width, height = image.size
    ratio = max(width / 512, height / 512)
    width = int(width / ratio) // 8 * 8
    height = int(height / ratio) // 8 * 8
    image = image.resize((width, height))
    return image


def extract_selfie_mask(threshold, image):
    LOG.info("Extracting selfie mask")
    start = time.time()
    image = img_segmentation_model.process(image)
    mask = image.segmentation_mask
    # Binarize the soft segmentation mask, then dilate and blur it
    # so the blend with the generated background has smooth edges.
    cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY, dst=mask)
    cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1, dst=mask)
    cv2.blur(mask, (10, 10), dst=mask)
    elapsed = time.time() - start
    LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
    return mask


def generate_background(prompt, num_inference_steps, height, width):
    LOG.info("Generating background")
    start = time.time()
    background = diffusion(prompt=prompt,
                           num_inference_steps=int(num_inference_steps),
                           height=height,
                           width=width)
    nsfw = background.nsfw_content_detected[0]
    background = background.images[0]
    if nsfw:
        LOG.info('NSFW detected, skipping')
        background = np.zeros((height, width, 3), dtype='uint8')
    else:
        background = np.array(background)
        # Convert RGB to BGR
        background = background[:, :, ::-1].copy()
    elapsed = time.time() - start
    LOG.info("Background generated, %.2f seconds elapsed", elapsed)
    return background


def merge_selfie_and_background(selfie, background, mask):
    LOG.info("Merging extracted selfie and generated background")
    # Alpha-blend the selfie over the generated background using the soft mask.
    cv2.blendLinear(selfie, background, mask, 1 - mask, dst=selfie)
    selfie = cv2.cvtColor(selfie, cv2.COLOR_BGR2RGB)
    selfie = PIL.Image.fromarray(selfie)
    return selfie


def demo(threshold, image, prompt, num_inference_steps):
    LOG.info("Processing image")
    try:
        image = image_preprocess(image)
        mask = extract_selfie_mask(threshold, image)
        background = generate_background(prompt, num_inference_steps,
                                         image.shape[0], image.shape[1])
        output = merge_selfie_and_background(image, background, mask)
    except Exception as e:
        LOG.error("Some unexpected error occurred")
        LOG.exception(e)
        raise
    return output


iface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Slider(minimum=0.1, maximum=1, step=0.05,
                  label="Selfie segmentation threshold", value=0.8),
        gr.Image(type='pil', label="Upload your selfie"),
        gr.Text(value="a photo of the Eiffel tower on the right side",
                label="Background description"),
        gr.Slider(minimum=5, maximum=100, step=5,
                  label="Diffusion inference steps", value=50)
    ],
    outputs=[
        gr.Image(label="Invent yourself a life :)")
    ])

# iface.launch(server_name="0.0.0.0", server_port=6443)
iface.launch()