import logging
import os
import time
import cv2
from diffusers import StableDiffusionPipeline
import gradio as gr
import mediapipe as mp
import numpy as np
import PIL.Image
import torch.cuda
# from transformers import pipeline
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    force=True)
LOG = logging.getLogger(__name__)
LOG.info("Loading image segmentation model")
# seg_kwargs = {
# "task": "image-segmentation",
# "model": "nvidia/segformer-b0-finetuned-ade-512-512"
# }
#
# img_segmentation = pipeline(**seg_kwargs)
mp_selfie_segmentation = mp.solutions.selfie_segmentation
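# model_selection=0 picks MediaPipe's general-purpose selfie model
# (model_selection=1 is the landscape variant aimed at video-conferencing framing).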
img_segmentation_model = mp_selfie_segmentation.SelfieSegmentation(model_selection=0)
LOG.info("Loading diffusion model")
diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
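# Note: on a memory-constrained GPU the pipeline could instead be loaded in half
# precision, e.g. (untested sketch):
# diffusion = StableDiffusionPipeline.from_pretrained(
#     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)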
if torch.cuda.is_available():
    LOG.info("Moving diffusion model to GPU")
    diffusion.to('cuda')
def image_preprocess(image: PIL.Image.Image):
    """Convert the uploaded selfie into a resized BGR numpy array."""
    LOG.info("Preprocessing image %s", image)
    start = time.time()
    # image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    image = resize_image(image)
    image = np.array(image)
    # Convert RGB to BGR
    image = image[:, :, ::-1].copy()
    elapsed = time.time() - start
    LOG.info("Image preprocessed, %.2f seconds elapsed", elapsed)
    return image
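# Stable Diffusion's VAE downsamples by a factor of 8, so resize_image keeps the longer
# side at 512 px and rounds both dimensions down to multiples of 8
# (e.g. a 1024x768 upload becomes 512x384).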
def resize_image(image: PIL.Image.Image):
    """Scale the image so its longer side is 512 px, with both sides multiples of 8."""
    width, height = image.size
    ratio = max(width / 512, height / 512)
    width = int(width / ratio) // 8 * 8
    height = int(height / ratio) // 8 * 8
    image = image.resize((width, height))
    return image
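# The mask post-processing below (threshold, dilate, blur) feathers the selfie edges,
# which avoids hard cut-out borders when the selfie is blended over the new background.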
def extract_selfie_mask(threshold, image):
    """Segment the person with MediaPipe and return a softened foreground mask."""
    LOG.info("Extracting selfie mask")
    start = time.time()
    image = img_segmentation_model.process(image)
    mask = image.segmentation_mask
    cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY, dst=mask)
    cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1, dst=mask)
    cv2.blur(mask, (10, 10), dst=mask)
    elapsed = time.time() - start
    LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
    return mask
def generate_background(prompt, num_inference_steps, height, width):
    """Generate a background with Stable Diffusion as a BGR array (black if flagged NSFW)."""
    LOG.info("Generating background")
    start = time.time()
    background = diffusion(
        prompt=prompt,
        num_inference_steps=int(num_inference_steps),
        height=height,
        width=width
    )
    nsfw = background.nsfw_content_detected[0]
    background = background.images[0]
    if nsfw:
        LOG.info('NSFW detected, skipping')
        background = np.zeros((height, width, 3), dtype='uint8')
    else:
        background = np.array(background)
        # Convert RGB to BGR
        background = background[:, :, ::-1].copy()
    elapsed = time.time() - start
    LOG.info("Background generated, elapsed %.2f seconds", elapsed)
    return background
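# cv2.blendLinear computes a per-pixel weighted sum; since the weights sum to 1 here,
# it reduces to mask * selfie + (1 - mask) * background, so foreground pixels keep the
# selfie and the rest take the generated scene.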
def merge_selfie_and_background(selfie, background, mask):
    """Blend the selfie over the generated background with the soft mask, return a PIL image."""
    LOG.info("Merging extracted selfie and generated background")
    cv2.blendLinear(selfie, background, mask, 1 - mask, dst=selfie)
    selfie = cv2.cvtColor(selfie, cv2.COLOR_BGR2RGB)
    selfie = PIL.Image.fromarray(selfie)
    return selfie
def demo(threshold, image, prompt, num_inference_steps):
    """Gradio handler: preprocess, segment the selfie, generate a background and blend."""
    LOG.info("Processing image")
    try:
        image = image_preprocess(image)
        mask = extract_selfie_mask(threshold, image)
        background = generate_background(prompt, num_inference_steps,
                                         image.shape[0], image.shape[1])
        output = merge_selfie_and_background(image, background, mask)
    except Exception as e:
        LOG.error("Some unexpected error occurred")
        LOG.exception(e)
        raise
    return output
iface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Slider(minimum=0.1, maximum=1, step=0.05, label="Selfie segmentation threshold",
                  value=0.8),
        gr.Image(type='pil', label="Upload your selfie"),
        gr.Text(value="a photo of the Eiffel tower on the right side",
                label="Background description"),
        gr.Slider(minimum=5, maximum=100, step=5, label="Diffusion inference steps",
                  value=50)
    ],
    outputs=[
        gr.Image(label="Invent yourself a life :)")
    ])
# iface.launch(server_name="0.0.0.0", server_port=6443)
iface.launch()
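# To try it locally, run this script with Python (with the dependencies above installed)
# and open the URL Gradio prints, http://127.0.0.1:7860 by default.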