File size: 5,774 Bytes
984b1c3
 
 
 
 
 
0bda8ba
984b1c3
d00e30a
6e5e70e
984b1c3
1c2f991
ba17c5b
1c2f991
984b1c3
 
91a732d
0bda8ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12f978b
91a732d
984b1c3
0bda8ba
984b1c3
 
 
 
12f978b
984b1c3
 
d3f9ca8
984b1c3
 
d00e30a
0bda8ba
 
 
 
 
 
 
 
d00e30a
91a732d
 
 
0bda8ba
91a732d
 
 
4175fd1
0bda8ba
 
4175fd1
0bda8ba
 
 
 
 
4175fd1
91a732d
d3f9ca8
ba17c5b
0bda8ba
 
d3f9ca8
ba17c5b
91a732d
0bda8ba
ba17c5b
908fc7b
0bda8ba
d00e30a
908fc7b
12f978b
 
0bda8ba
 
12f978b
 
d3f9ca8
0bda8ba
12f978b
0bda8ba
 
a22d3b1
908fc7b
0bda8ba
 
95d0b08
a22d3b1
0bda8ba
a22d3b1
984b1c3
 
7da70bd
 
 
 
 
 
 
 
 
 
0bda8ba
 
 
 
 
 
 
 
 
 
 
 
984b1c3
0bda8ba
7da70bd
0bda8ba
 
7da70bd
 
 
0bda8ba
7da70bd
 
 
 
0bda8ba
7da70bd
0bda8ba
7da70bd
 
 
 
0bda8ba
7da70bd
 
 
 
984b1c3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET

# Build the single-output U²-Net (3 input channels, 1 output channel) and load
# its pretrained checkpoint onto the CPU.
model_path = "cloth_segmentation/networks/u2net.pth"
model = U2NET(3, 1)
# Presumably the checkpoint was saved from an nn.DataParallel wrapper, whose
# keys carry a "module." prefix; strip it so they match the bare network.
# NOTE(review): str.replace removes every occurrence, not just the prefix.
state_dict = {
    key.replace('module.', ''): weight
    for key, weight in torch.load(model_path, map_location=torch.device('cpu')).items()
}
model.load_state_dict(state_dict)
# Switch to inference mode before any forward pass.
model.eval()

def refine_mask(mask):
    """Clean up a mask with morphology, then soften its edges.

    Steps: 5x5 closing -> one 3x3 erosion -> 5x5 closing -> 5x5 Gaussian
    blur (sigma 1.5).

    Args:
        mask: single-channel mask image accepted by OpenCV (callers in this
            file pass uint8 0/255 masks).

    Returns:
        The refined mask, same size and type as the input.
    """
    square5 = np.ones((5, 5), np.uint8)
    square3 = np.ones((3, 3), np.uint8)

    # Fill small holes, trim thin protrusions, then re-close any gaps the
    # erosion opened up.
    cleaned = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, square5)
    cleaned = cv2.erode(cleaned, square3, iterations=1)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, square5)

    # Smooth the boundary while keeping the overall shape.
    return cv2.GaussianBlur(cleaned, (5, 5), 1.5)

def segment_dress(image_np):
    """Predict a dress mask for an image using the module-level U²-Net model.

    The image is converted to RGB, turned into a tensor, and resized to
    320x320. The network's first output is min-max normalised to [0, 1] and
    thresholded at (mean + 0.2). The binary mask is resized back to the input
    size with nearest-neighbour interpolation and passed through refine_mask.

    Args:
        image_np: image array (e.g. ``np.array(pil_image)``).

    Returns:
        uint8 mask with the input's height and width: 0 for background, up to
        255 for dress (edges softened by refine_mask's blur).
    """
    to_network_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((320, 320))
    ])

    rgb = Image.fromarray(image_np).convert("RGB")
    batch = to_network_input(rgb).unsqueeze(0)

    with torch.no_grad():
        saliency = model(batch)[0][0].squeeze().cpu().numpy()

    # Stretch the scores to [0, 1]; the epsilon guards against a constant map.
    lo, hi = saliency.min(), saliency.max()
    saliency = (saliency - lo) / (hi - lo + 1e-8)

    # Place the cut 0.2 above the mean score for a tighter mask.
    cutoff = np.mean(saliency) + 0.2
    binary = (saliency > cutoff).astype(np.uint8) * 255

    # Nearest-neighbour keeps the mask strictly 0/255 while resizing back.
    height, width = image_np.shape[:2]
    binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)

    return refine_mask(binary)

def apply_grabcut(image_np, dress_mask):
    """Refine a coarse dress mask with OpenCV GrabCut.

    Pixels with ``dress_mask > 0`` are seeded as probable foreground
    (GC_PR_FGD) and every other pixel as definite background (GC_BGD).
    GrabCut then runs 3 iterations initialised from that seed mask. Pixels it
    labels definite or probable foreground become 255 and the rest 0, and the
    result is passed through refine_mask.

    Args:
        image_np: image handed to cv2.grabCut, which requires an 8-bit
            3-channel array.
        dress_mask: single-channel mask, non-zero where the dress is.

    Returns:
        Refined uint8 mask with dress_mask's height and width.
        - If the seed has no foreground pixels, the mask is empty.
        - If the seed has no background pixels, GrabCut cannot fit its
          background model (OpenCV asserts on an empty sample set). In that
          case the seed mask is kept instead of raising.
    """
    seed = np.where(dress_mask > 0, cv2.GC_PR_FGD, cv2.GC_BGD).astype(np.uint8)

    has_fg = bool(np.any(seed == cv2.GC_PR_FGD))
    has_bg = bool(np.any(seed == cv2.GC_BGD))
    if has_fg and has_bg:
        # GMM state buffers in the 1x65 float64 shape grabCut expects.
        bgd_model = np.zeros((1, 65), np.float64)
        fgd_model = np.zeros((1, 65), np.float64)
        # GC_INIT_WITH_MASK ignores the rect argument, so no bounding box is
        # needed. grabCut updates `seed` in place.
        cv2.grabCut(image_np, seed, None, bgd_model, fgd_model, 3, cv2.GC_INIT_WITH_MASK)

    is_fg = (seed == cv2.GC_FGD) | (seed == cv2.GC_PR_FGD)
    refined_mask = np.where(is_fg, 255, 0).astype(np.uint8)
    return refine_mask(refined_mask)

def recolor_dress(image_np, dress_mask, target_color):
    """Shift the chroma of the masked region toward ``target_color``.

    Works in OpenCV's 8-bit LAB space:
    1. The mean a/b of the pixels where ``dress_mask > 0`` is moved onto the
       target colour's a/b. The shift is weighted per pixel by
       ``dress_mask / 255``. Lightness (L) is not modified.
    2. The recoloured image is blended over the original. The blend weight is
       a blurred copy of the mask scaled by ``(L / 255) ** 0.7``, so darker
       pixels change less.

    Args:
        image_np: image converted with COLOR_RGB2LAB (assumed RGB uint8
            (H, W, 3) — TODO confirm for callers outside this file).
        dress_mask: uint8 mask (H, W); 0 = untouched, 255 = full effect.
        target_color: (B, G, R) triple, converted with COLOR_BGR2LAB.

    Returns:
        uint8 image with the same shape as ``image_np``, or ``image_np``
        itself when the mask selects no pixels.
    """
    target_lab = cv2.cvtColor(np.uint8([[target_color]]), cv2.COLOR_BGR2LAB)[0][0]
    img_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)

    inside = img_lab[dress_mask > 0]
    if inside.shape[0] == 0:
        return image_np

    # Only the chroma means are needed; L is kept as-is.
    _, mean_a, mean_b = inside.mean(axis=0)
    shift_a = target_lab[1] - mean_a
    shift_b = target_lab[2] - mean_b

    # Values are clipped, then truncated to uint8 on assignment into img_lab.
    weight = dress_mask / 255.0
    img_lab[..., 1] = np.clip(img_lab[..., 1] + weight * shift_a, 0, 255)
    img_lab[..., 2] = np.clip(img_lab[..., 2] + weight * shift_b, 0, 255)

    recolored_rgb = cv2.cvtColor(img_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)

    # Soft edges from the blur, damped further in dark regions.
    soft_mask = cv2.GaussianBlur(dress_mask, (21, 21), 7)
    brightness = (img_lab[..., 0] / 255.0) ** 0.7
    alpha = (soft_mask * brightness).astype(np.uint8)[..., None] / 255

    blended = image_np * (1 - alpha) + recolored_rgb * alpha
    return blended.astype(np.uint8)

def change_dress_color(img, color):
    """Recolour the dress in ``img`` to the colour named by ``color``.

    Pipeline: segment_dress (U²-Net) -> apply_grabcut -> recolor_dress.

    Args:
        img: input image from the Gradio component (a PIL image), or None
            when nothing has been uploaded.
        color: a key of ``color_map`` below; unknown names fall back to red.

    Returns:
        - ``None`` if ``img`` is None.
        - The untouched ``img`` if the detected dress covers fewer than
          ``min_mask_pixels`` pixels, or if any processing step raises.
        - Otherwise, a new PIL image with the dress recoloured.
    """
    if img is None:
        return None

    # BGR triples: recolor_dress converts the target with COLOR_BGR2LAB.
    color_map = {
        "Red": (0, 0, 255), "Blue": (255, 0, 0), "Green": (0, 255, 0),
        "Yellow": (0, 255, 255), "Purple": (128, 0, 128), "Orange": (0, 165, 255),
        "Cyan": (255, 255, 0), "Magenta": (255, 0, 255), "White": (255, 255, 255),
        "Black": (0, 0, 0)
    }
    # Minimum dress area, in pixels, for the segmentation to be trusted.
    min_mask_pixels = 1000

    new_color_bgr = color_map.get(color, (0, 0, 255))

    # GrabCut and the RGB blending need a 3-channel image, so PIL inputs in
    # other modes (RGBA, L, P, ...) are converted first.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img_np = np.array(img.convert("RGB"))
    else:
        img_np = np.array(img)

    try:
        dress_mask = segment_dress(img_np)
        # Count mask pixels rather than summing 0-255 intensities. A summed
        # threshold of 1000 would correspond to only ~4 fully-set pixels.
        if np.count_nonzero(dress_mask) < min_mask_pixels:
            return img

        dress_mask = apply_grabcut(img_np, dress_mask)
        img_recolored = recolor_dress(img_np, dress_mask, new_color_bgr)
        return Image.fromarray(img_recolored)

    except Exception as e:
        # Top-level UI boundary: report the error and show the input unchanged.
        print(f"Error processing image: {str(e)}")
        return img

# Gradio Interface
# Two-column layout. The left column holds the upload widget, the colour
# dropdown and the trigger button; the right column shows the result. The
# layout follows the nesting and order of the `with` blocks below.
with gr.Blocks() as demo:
    gr.Markdown("# AI Dress Color Changer")
    gr.Markdown("Upload a dress image and select a new color for realistic recoloring")
    
    with gr.Row():
        with gr.Column():
            # type="pil" hands change_dress_color a PIL image (or None).
            input_image = gr.Image(type="pil", label="Input Image")
            # These choices mirror the keys of color_map inside
            # change_dress_color; keep the two lists in sync.
            color_choice = gr.Dropdown(
                choices=["Red", "Blue", "Green", "Yellow", "Purple", 
                        "Orange", "Cyan", "Magenta", "White", "Black"],
                value="Red",
                label="Select New Color"
            )
            process_btn = gr.Button("Recolor Dress")
        
        with gr.Column():
            output_image = gr.Image(type="pil", label="Result")

    # Clicking the button calls change_dress_color(image, colour_name) and
    # shows whatever it returns (a PIL image, or None) in output_image.
    process_btn.click(
        fn=change_dress_color,
        inputs=[input_image, color_choice],
        outputs=output_image
    )

# Launch the web app only when the file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()