# Duplicated from raphael-gl/ai-days-image-background-substitution (commit 56bed35)
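"""Gradio demo: segment the person out of a selfie and replace the background
with a Stable Diffusion-generated scene described by a text prompt."""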
import logging
import os
import time
import cv2
from diffusers import StableDiffusionPipeline
import gradio as gr
# import mediapipe as mp
import numpy as np
import PIL.Image
import torch.cuda
from transformers import pipeline
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    force=True)
LOG = logging.getLogger(__name__)

LOG.info("Loading image segmentation model")
seg_kwargs = {
    "task": "image-segmentation",
    "model": "nvidia/segformer-b0-finetuned-ade-512-512"
}
img_segmentation_model = pipeline(**seg_kwargs)
# mp_selfie_segmentation = mp.solutions.selfie_segmentation
# img_segmentation_model = mp_selfie_segmentation.SelfieSegmentation(model_selection=0)
LOG.info("Loading diffusion model")
diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
if torch.cuda.is_available():
LOG.info("Moving diffusion model to GPU")
diffusion.to('cuda')
def image_preprocess(image: PIL.Image):
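    """Convert the uploaded selfie to RGB and resize it to fit the diffusion model."""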
LOG.info("Preprocessing image %s", image)
start = time.time()
# image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
image = resize_image(image)
# image = np.array(image)
# # Convert RGB to BGR
# image = image[:, :, ::-1].copy()
elapsed = time.time() - start
LOG.info("Image preprocessed, %.2f seconds elapsed", elapsed)
return image
def resize_image(image: PIL.Image):
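    """Scale the image so its longest side becomes 512 px, rounding each dimension
    down to a multiple of 8 as Stable Diffusion requires."""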
    width, height = image.size
    ratio = max(width / 512, height / 512)
    width = int(width / ratio) // 8 * 8
    height = int(height / ratio) // 8 * 8
    image = image.resize((width, height))
    return image


def extract_selfie_mask(threshold, image):
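    """Segment the image and return a soft float32 mask of the most confident
    'person' segment (all zeros if no person is detected)."""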
LOG.info("Extracting selfie mask")
start = time.time()
segments = img_segmentation_model(image)
kept = None
for s in segments:
if s['score'] is None:
s['score'] = 1
if s['label'] == 'person' and s['score'] > 0.99:
if not kept:
kept = s
elif kept['score'] < s['score']:
kept = s
if not kept:
LOG.info("No person found in the photo, skipping")
mask = np.zeros((image.size[1], image.size[0], 3), dtype='float32')
else:
mask = kept['mask']
mask = np.array(mask, dtype='float32')
cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY, dst=mask)
cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1, dst=mask)
cv2.blur(mask, (10, 10), dst=mask)
elapsed = time.time() - start
LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
return mask
def generate_background(prompt, num_inference_steps, height, width):
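    """Generate a background image from the prompt and return it as a BGR numpy
    array (a black image if the safety checker flags the output)."""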
LOG.info("Generating background")
start = time.time()
background = diffusion(
prompt=prompt,
num_inference_steps=int(num_inference_steps),
height=height,
width=width
)
nsfw = background.nsfw_content_detected[0]
background = background.images[0]
if nsfw:
LOG.info('NSFW detected, skipping')
background = np.zeros((height, width, 3), dtype='uint8')
else:
background = np.array(background)
# Convert RGB to BGR
background = background[:, :, ::-1].copy()
elapsed = time.time() - start
LOG.info("Background generated, elapsed %.2f seconds", elapsed)
return background
def merge_selfie_and_background(selfie, background, mask):
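    """Blend the selfie over the generated background with the soft mask and
    return the result as a PIL image."""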
LOG.info("Merging extracted selfie and generated background")
selfie = np.array(selfie)
# Convert RGB to BGR
selfie = selfie[:, :, ::-1].copy()
cv2.blendLinear(selfie, background, mask, 1 - mask, dst=selfie)
selfie = cv2.cvtColor(selfie, cv2.COLOR_BGR2RGB)
selfie = PIL.Image.fromarray(selfie)
return selfie
def demo(threshold, image, prompt, num_inference_steps):
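    """Gradio callback: preprocess the selfie, extract the person mask, generate
    the background from the prompt and merge everything."""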
LOG.info("Processing image")
try:
image = image_preprocess(image)
mask = extract_selfie_mask(threshold, image)
background = generate_background(prompt, num_inference_steps,
image.size[1], image.size[0])
output = merge_selfie_and_background(image, background, mask)
except Exception as e:
LOG.error("Some unexpected error occured")
LOG.exception(e)
raise
return output
iface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Slider(minimum=0.1, maximum=1, step=0.05, label="Selfie segmentation threshold",
                  value=0.8),
        gr.Image(type='pil', label="Upload your selfie"),
        gr.Text(value="a photo of the Eiffel tower on the right side",
                label="Background description"),
        gr.Slider(minimum=5, maximum=100, step=5, label="Diffusion inference steps",
                  value=50)
    ],
    outputs=[
        gr.Image(label="Invent yourself a life :)")
    ])
# iface.launch(server_name="0.0.0.0", server_port=6443)
iface.launch()