import re

import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
import torchvision.transforms as T
from skimage import color
from cloth_segmentation.networks.u2net import U2NET

# Load the pretrained cloth-segmentation U2NET (1-channel mask output) on CPU.
model = U2NET(3, 1)
model.load_state_dict(torch.load("cloth_segmentation/networks/u2net.pth", map_location=torch.device('cpu')))
model.eval()

# Preprocessing
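# 320x320 is U2NET's standard input size; the mean/std below are the usual
# ImageNet statistics.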
transform = T.Compose([
    T.Resize((320, 320)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Segmentation mask
@torch.no_grad()
def get_dress_mask(image_pil):
    img = transform(image_pil).unsqueeze(0)
    # U2NET returns several side outputs; the first is the fused prediction.
    pred = model(img)[0]
    pred = pred.squeeze().cpu().numpy()
    mask = (pred > 0.5).astype(np.uint8)
    # cv2.resize takes (width, height), which is exactly what PIL's .size
    # already is, so no axis swap; nearest-neighbour keeps the mask binary.
    mask = cv2.resize(mask, image_pil.size, interpolation=cv2.INTER_NEAREST)
    return mask

# Color parsing (extract target color from prompt)
def extract_target_color(prompt):
    # Basic keyword matching (could be replaced with NLP-based color detection).
    # Longest names first, so "sky blue" is not shadowed by the "blue" match.
    colors = ["sky blue", "yellow", "purple", "green", "black", "white",
              "blue", "pink", "red"]
    for c in colors:
        if re.search(c, prompt.lower()):
            return c
    return "red"  # default fallback

# Recoloring function
def recolor_dress(image_pil, prompt):
    image_np = np.array(image_pil.convert("RGB")) / 255.0
    lab = color.rgb2lab(image_np)

    mask = get_dress_mask(image_pil)
    if mask.sum() == 0:
        # Nothing was segmented: return the input unchanged instead of
        # taking the mean of an empty region, which would be NaN.
        return image_pil

    # Get mean a, b values in masked region
    a_mean = lab[:, :, 1][mask == 1].mean()
    b_mean = lab[:, :, 2][mask == 1].mean()

    # Target a, b (from a small predefined palette)
    target_color_map = {
        "red": [60, 40],
        "blue": [20, -60],
        "green": [-60, 60],
        "yellow": [10, 70],
        "pink": [50, 10],
        "purple": [40, -40],
        "black": [0, 0],
        "white": [0, 0],
        "sky blue": [0, -50],
    }
    target = extract_target_color(prompt)
    target_a, target_b = target_color_map.get(target, [60, 40])

    # Apply color shift only to dress region
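    # Shifting only the a/b chroma channels keeps L (lightness) intact, so
    # shading and fabric texture are preserved.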
    lab_new = lab.copy()
    delta_a = target_a - a_mean
    delta_b = target_b - b_mean
    lab_new[:, :, 1][mask == 1] += delta_a
    lab_new[:, :, 2][mask == 1] += delta_b

    rgb_new = color.lab2rgb(lab_new)
    # lab2rgb can return values slightly outside [0, 1] for out-of-gamut
    # colors, so clip before converting back to 8-bit.
    rgb_new = (np.clip(rgb_new, 0, 1) * 255).astype(np.uint8)
    return Image.fromarray(rgb_new)
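
# Quick local sanity check (a sketch, not part of the app); "sample.jpg" is a
# hypothetical test file:
#   img = Image.open("sample.jpg")
#   recolor_dress(img, "make the dress sky blue").save("sample_recolored.jpg")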

# Gradio UI
def interface_fn(image, prompt):
    return recolor_dress(image, prompt)

interface = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Textbox(label="Prompt", placeholder="e.g. make the dress sky blue")
    ],
    outputs=gr.Image(label="Edited Image"),
    title="Dress Recolor Editor",
    description="Segments the garment with U2NET and recolors it to the color named in the prompt."
)

interface.launch(show_error=True)