import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET
# Load U²-Net model
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
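# Checkpoints saved through nn.DataParallel prefix every key with "module.";
# strip it so the keys match a plain single-device U2NET.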
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
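
# Pipeline: U2NET produces a coarse dress mask (segment_dress), GrabCut sharpens
# its boundary against the image (apply_grabcut), and the recolor step shifts the
# LAB a/b channels toward the target color with feathered blending (recolor_dress).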

def refine_mask(mask):
    """Enhanced mask refinement with erosion and morphological operations"""
    # First closing to fill small holes
    close_kernel = np.ones((5, 5), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
    # Erosion to remove small protrusions and extra areas
    erode_kernel = np.ones((3, 3), np.uint8)
    mask = cv2.erode(mask, erode_kernel, iterations=1)
    # Second closing to refine edges after erosion
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel)
    # Final blur to smooth edges while preserving shape
    mask = cv2.GaussianBlur(mask, (5, 5), 1.5)
    return mask
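
# Note: the final Gaussian blur leaves soft (non-binary) values at the mask edges.
# Downstream code treats any value > 0 as foreground, and recolor_dress() reuses
# the soft values as per-pixel blend weights.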

def segment_dress(image_np):
    """Improved dress segmentation with adaptive thresholding"""
    transform_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))
    ])
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = transform_pipeline(image).unsqueeze(0)
    with torch.no_grad():
        # U2NET returns seven side outputs; index 0 is the fused final map
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()
    # Adaptive threshold: normalize to [0, 1], then cut just above the mean
    output = (output - output.min()) / (output.max() - output.min() + 1e-8)
    adaptive_thresh = np.mean(output) + 0.2  # Increased threshold for a tighter mask
    dress_mask = (output > adaptive_thresh).astype(np.uint8) * 255
    # Preserve hard edges when resizing back to the original resolution
    dress_mask = cv2.resize(dress_mask, (image_np.shape[1], image_np.shape[0]),
                            interpolation=cv2.INTER_NEAREST)
    return refine_mask(dress_mask)

def apply_grabcut(image_np, dress_mask):
    """Mask refinement using GrabCut"""
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    # Seed GrabCut: probable foreground where U2NET found the dress, sure background elsewhere
    mask = np.where(dress_mask > 0, cv2.GC_PR_FGD, cv2.GC_BGD).astype('uint8')
    # Get bounding box coordinates; run GrabCut only if the mask is non-empty
    coords = cv2.findNonZero(dress_mask)
    if coords is not None:
        x, y, w, h = cv2.boundingRect(coords)
        rect = (x, y, w, h)
        # rect is only used in GC_INIT_WITH_RECT mode, but the signature requires it
        cv2.grabCut(image_np, mask, rect, bgd_model, fgd_model, 3, cv2.GC_INIT_WITH_MASK)
    refined_mask = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype("uint8")
    return refine_mask(refined_mask)

def recolor_dress(image_np, dress_mask, target_color):
    """Color transformation with improved blending"""
    # Convert colors to LAB space (the target color is a BGR tuple for OpenCV)
    target_color_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
    # Calculate the a/b channel shifts from the dress's mean color to the target
    dress_pixels = img_lab[dress_mask > 0]
    if len(dress_pixels) == 0:
        return image_np
    _, mean_A, mean_B = np.mean(dress_pixels, axis=0)
    a_shift = target_color_lab[1] - mean_A
    b_shift = target_color_lab[2] - mean_B
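    # Illustrative numbers (approximate): pure red (BGR 0, 0, 255) maps to roughly
    # (L=136, a=208, b=195) in OpenCV's 8-bit LAB. If the dress's mean (a, b) is
    # (140, 160), the shifts are about +68 and +35, scaled per pixel by the soft
    # mask below so the change fades out at the garment's edges.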
    # Apply the color transformation, weighted by the (soft) mask; lightness is untouched
    img_lab[..., 1] = np.clip(img_lab[..., 1] + (dress_mask / 255.0) * a_shift, 0, 255)
    img_lab[..., 2] = np.clip(img_lab[..., 2] + (dress_mask / 255.0) * b_shift, 0, 255)
    # Create an adaptive blending mask: feather the edges and weight by lightness
    img_recolored = cv2.cvtColor(img_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)
    feathered_mask = cv2.GaussianBlur(dress_mask, (21, 21), 7)
    lightness_mask = (img_lab[..., 0] / 255.0) ** 0.7
    adaptive_feather = (feathered_mask * lightness_mask).astype(np.uint8)
    # Smooth blending between the original and the recolored image
    alpha = adaptive_feather[..., None] / 255.0
    return (image_np * (1 - alpha) + img_recolored * alpha).astype(np.uint8)

def change_dress_color(img, color):
    """Main processing function with error handling"""
    if img is None:
        return None
    # BGR tuples, as expected by OpenCV's LAB conversion
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0),
        "Yellow": (0, 255, 255), "Purple": (128, 0, 128), "Orange": (0, 165, 255),
        "Cyan": (255, 255, 0), "Magenta": (255, 0, 255), "White": (255, 255, 255),
        "Black": (0, 0, 0)
    }
    new_color_bgr = color_map.get(color, (0, 0, 255))
    # Ensure a 3-channel RGB array (uploads may be RGBA or grayscale)
    img_np = np.array(img.convert("RGB"))
    try:
        dress_mask = segment_dress(img_np)
        if np.sum(dress_mask) < 1000:  # Minimum mask area threshold
            return img
        dress_mask = apply_grabcut(img_np, dress_mask)
        img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)
        return Image.fromarray(img_recolored)
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return img

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# AI Dress Color Changer")
    gr.Markdown("Upload a dress image and select a new color for realistic recoloring")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            color_choice = gr.Dropdown(
                choices=["Red", "Blue", "Green", "Yellow", "Purple",
                         "Orange", "Cyan", "Magenta", "White", "Black"],
                value="Red",
                label="Select New Color"
            )
            process_btn = gr.Button("Recolor Dress")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Result")
    process_btn.click(
        fn=change_dress_color,
        inputs=[input_image, color_choice],
        outputs=output_image
    )

if __name__ == "__main__":
    demo.launch()
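
# Programmatic use without the UI (hypothetical file names):
#   from PIL import Image
#   result = change_dress_color(Image.open("dress.jpg"), "Blue")
#   if result is not None:
#       result.save("dress_blue.jpg")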