|
import numpy as np |
|
import mediapipe as mp |
|
import gradio as gr |
|
import cv2 |
|
import spaces |
|
|
|
from PIL import Image, ImageColor |
|
from segment_utils import ( |
|
segmenter |
|
) |
|
|
|
@spaces.GPU(duration=1)
def do_nothing():
    """Do nothing and return ``None``.

    The function is wrapped with ``spaces.GPU(duration=1)`` and is used as a
    follow-up callback in the demo's event chain.
    NOTE(review): presumably it exists only so that a GPU-decorated function
    is registered with the ``spaces`` runtime -- confirm.
    """
    return None
|
|
|
def hair_dye(
    input_image: Image.Image,
    color: str,
) -> Image.Image:
    """Recolor the hair region of ``input_image`` toward ``color``.

    The image is segmented with the module-level ``segmenter``; every pixel
    whose category equals 1 is treated as hair. For those pixels the hue and
    saturation are replaced by the hue and saturation of ``color`` while the
    original value (brightness) channel is kept. All other pixels are
    returned unchanged.

    Args:
        input_image: The picture to edit. Any PIL mode is accepted; it is
            converted to RGB before processing.
        color: A colour specification understood by
            ``PIL.ImageColor.getcolor`` (for example ``"#ff8080"``).

    Returns:
        A new RGB ``PIL.Image.Image`` with the hair recoloured.

    Raises:
        gr.Error: If ``input_image`` is ``None`` (no image was provided).
        ValueError: If ``color`` cannot be parsed by ``ImageColor.getcolor``.
    """
    if input_image is None:
        # Fail with a readable message instead of an opaque error raised
        # from deep inside numpy / the segmenter.
        raise gr.Error("Please provide an input image.")

    # Normalise to 3-channel RGB once: the segmenter input is declared as
    # SRGB and the colour conversions below expect exactly three channels.
    rgb = np.array(input_image.convert("RGB"))

    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
    segmentation_result = segmenter.segment(mp_image)
    # NOTE(review): category index 1 is assumed to be the hair class of the
    # model loaded in segment_utils -- confirm against that module.
    hair_mask = segmentation_result.category_mask.numpy_view() == 1

    if not hair_mask.any():
        # Nothing classified as hair: the result equals the (RGB) input.
        return Image.fromarray(rgb)

    target_rgb = ImageColor.getcolor(color, "RGB")
    target_hsv = cv2.cvtColor(
        np.array([[target_rgb]], dtype=np.uint8),
        cv2.COLOR_RGB2HSV,
    )[0, 0]

    # Swap hue and saturation for the target's; channel 2 (value) is left
    # untouched so the hair keeps its original shading.
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)
    hsv[..., 0] = target_hsv[0]
    hsv[..., 1] = target_hsv[1]
    recolored = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    # Only hair pixels are taken from the recoloured image.
    rgb[hair_mask] = recolored[hair_mask]

    return Image.fromarray(rgb)
|
|
|
|
|
def create_demo() -> gr.Blocks:
    """Assemble the hair-dye Gradio interface and return it.

    Layout: a first row holding a colour picker and an "Edit Image" button,
    and a second row holding the input image next to the read-only generated
    image. Clicking the button runs ``hair_dye`` on the input image and the
    picked colour; when that call succeeds, ``do_nothing`` is invoked.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    with gr.Blocks() as demo:
        # Controls row: colour selection beside the trigger button.
        with gr.Row():
            with gr.Column():
                hair_color = gr.ColorPicker(label="Hair Color", value="#ff8080")
            with gr.Column():
                run_button = gr.Button("Edit Image")

        # Images row: source on the left, result on the right.
        with gr.Row():
            with gr.Column():
                source_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                result_image = gr.Image(
                    label="Generated Image",
                    type="pil",
                    interactive=False,
                )

        dye_event = run_button.click(
            fn=hair_dye,
            inputs=[source_image, hair_color],
            outputs=[result_image],
        )
        # Chained step that runs only after hair_dye completes successfully.
        dye_event.success(fn=do_nothing)

    return demo