turbo_hc / app_hairdye.py
zhiweili
remove enhance
a44eef2
raw
history blame
2.02 kB
import numpy as np
import mediapipe as mp
import gradio as gr
import cv2
import spaces
from PIL import Image, ImageColor
from segment_utils import (
segmenter
)
@spaces.GPU(duration=1)
def do_nothing():
    """Do nothing; a ``spaces.GPU``-decorated placeholder callback.

    It is chained after a successful ``hair_dye`` click in ``create_demo``.
    NOTE(review): presumably this exists so the Spaces runtime sees at least
    one GPU-decorated function even though ``hair_dye`` itself is not
    decorated -- confirm before removing.
    """
    return None
def hair_dye(
    input_image: Image.Image,
    color: str,
) -> Image.Image:
    """Recolor the hair region of a picture with the requested color.

    The hair region is every pixel whose value in the ``segmenter`` category
    mask equals 1. Inside that region the hue and saturation of each pixel
    are replaced by those of ``color`` while the pixel's own HSV value
    (brightness) is kept, so shading and highlights are preserved. Pixels
    outside the mask are returned unchanged.

    Args:
        input_image: Picture to edit. Converted to RGB if it is in any other
            mode (RGBA, L, P, ...).
        color: A color string understood by ``PIL.ImageColor.getcolor``,
            e.g. ``"#ff8080"``.

    Returns:
        A new RGB PIL image with the hair recolored.

    Raises:
        gr.Error: If no input image was supplied.
        ValueError: If ``color`` cannot be parsed by ``ImageColor``.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of an opaque
        # low-level failure from the array/segmentation code.
        raise gr.Error("Please provide an input image.")

    # mp.ImageFormat.SRGB and the RGB<->HSV conversions below both require
    # exactly three 8-bit channels.
    rgb_image = input_image.convert("RGB")

    mp_image = mp.Image(
        image_format=mp.ImageFormat.SRGB,
        data=np.asarray(rgb_image),
    )
    segmentation_result = segmenter.segment(mp_image)
    hair_mask = segmentation_result.category_mask.numpy_view() == 1

    # Hue and saturation of the requested color, in OpenCV's 8-bit HSV scale.
    target_rgb = ImageColor.getcolor(color, "RGB")
    target_pixel = np.array([[target_rgb]], dtype=np.uint8)
    target_hue, target_sat, _ = cv2.cvtColor(target_pixel, cv2.COLOR_RGB2HSV)[0, 0]

    # Writable working copy of the picture; only hair pixels get modified.
    pixels = np.array(rgb_image)
    hsv = cv2.cvtColor(pixels, cv2.COLOR_RGB2HSV)
    hsv[hair_mask, 0] = target_hue
    hsv[hair_mask, 1] = target_sat
    recolored = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    # Paste back only the masked pixels: the 8-bit HSV round trip is lossy,
    # so everything outside the hair must keep its original values.
    pixels[hair_mask] = recolored[hair_mask]
    return Image.fromarray(pixels)
def create_demo() -> gr.Blocks:
    """Build the hair-dye Gradio interface and return it without launching.

    Layout: a first row holding the color picker and the trigger button, and
    a second row holding the input image next to the read-only result image.
    Clicking the button runs ``hair_dye``; when that succeeds, ``do_nothing``
    is invoked afterwards.
    """
    with gr.Blocks() as demo:
        # Top row: controls.
        with gr.Row():
            with gr.Column():
                hair_color_picker = gr.ColorPicker(
                    label="Hair Color",
                    value="#ff8080",
                )
            with gr.Column():
                edit_button = gr.Button("Edit Image")

        # Bottom row: source picture on the left, result on the right.
        with gr.Row():
            with gr.Column():
                source_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                result_image = gr.Image(
                    label="Generated Image",
                    type="pil",
                    interactive=False,
                )

        # Event wiring: recolor on click, then run the follow-up callback
        # only if the recoloring step did not raise.
        click_event = edit_button.click(
            fn=hair_dye,
            inputs=[source_image, hair_color_picker],
            outputs=[result_image],
        )
        click_event.success(fn=do_nothing)

    return demo