import gradio as gr
import numpy as np
import cv2
from PIL import Image
import re
import torch
import torchvision.transforms as T
from skimage import color

from cloth_segmentation.networks.u2net import U2NET
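# Load the U2-Net cloth-segmentation checkpoint for CPU inference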
model = U2NET(3, 1)
model.load_state_dict(torch.load("cloth_segmentation/networks/u2net.pth", map_location=torch.device('cpu')))
model.eval()
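# Optional sketch (an assumption, not part of the original app): run on GPU when
# one is available. Inputs inside get_dress_mask would then need .to(device) too.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)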
# Preprocessing: resize to the network input size and normalize with ImageNet stats
transform = T.Compose([
    T.Resize((320, 320)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Segmentation: predict a binary dress mask at the original image resolution
@torch.no_grad()
def get_dress_mask(image_pil):
    img = transform(image_pil).unsqueeze(0)  # (1, 3, 320, 320)
    pred = model(img)[0]  # first (fused) side output of U2-Net
    pred = pred.squeeze().cpu().numpy()
    mask = (pred > 0.5).astype(np.uint8)
    # cv2.resize takes (width, height), which is exactly what PIL's .size holds,
    # so no reversal is needed; nearest-neighbor keeps the mask binary
    mask = cv2.resize(mask, image_pil.size, interpolation=cv2.INTER_NEAREST)
    return mask
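# Debug helper (an addition, not used by the app itself): overlay the predicted
# mask in red so the segmentation quality can be inspected on its own.
def visualize_mask(image_pil):
    overlay = np.array(image_pil.convert("RGB"))
    mask = get_dress_mask(image_pil)
    # Blend the masked pixels 50/50 with pure red
    red = np.array([255, 0, 0], dtype=np.float64)
    overlay[mask == 1] = (0.5 * overlay[mask == 1] + 0.5 * red).astype(np.uint8)
    return Image.fromarray(overlay)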
# Color parsing: extract the target color keyword from the prompt
def extract_target_color(prompt):
    # Basic keyword matching (could be replaced with NLP-based color detection).
    # Multi-word colors come first so "sky blue" is not shadowed by "blue".
    colors = ["sky blue", "red", "blue", "green", "yellow", "pink", "black", "white", "purple"]
    for c in colors:
        if re.search(c, prompt.lower()):
            return c
    return "red"  # default fallback
# Recoloring: shift the a/b channels of the dress region in LAB space
def recolor_dress(image_pil, prompt):
    image_np = np.array(image_pil.convert("RGB")) / 255.0
    lab = color.rgb2lab(image_np)
    mask = get_dress_mask(image_pil)
    if not mask.any():
        return image_pil  # nothing was segmented; return the input unchanged
    # Mean a/b values in the masked region
    a_mean = lab[:, :, 1][mask == 1].mean()
    b_mean = lab[:, :, 2][mask == 1].mean()
    # Target a/b values from a small predefined palette
    target_color_map = {
        "red": [60, 40],
        "blue": [20, -60],
        "green": [-60, 60],
        "yellow": [10, 70],
        "pink": [50, 10],
        "purple": [40, -40],
        "black": [0, 0],
        "white": [0, 0],
        "sky blue": [0, -50],
    }
    target = extract_target_color(prompt)
    target_a, target_b = target_color_map.get(target, [60, 40])
    # Apply the color shift only to the dress region
    lab_new = lab.copy()
    delta_a = target_a - a_mean
    delta_b = target_b - b_mean
    lab_new[:, :, 1][mask == 1] += delta_a
    lab_new[:, :, 2][mask == 1] += delta_b
    rgb_new = color.lab2rgb(lab_new)
    rgb_new = (np.clip(rgb_new, 0, 1) * 255).astype(np.uint8)
    return Image.fromarray(rgb_new)
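# Example usage outside Gradio (hypothetical file paths):
#   img = Image.open("model_in_dress.jpg")
#   recolor_dress(img, "make the dress green").save("dress_green.png")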
# Gradio UI
def interface_fn(image, prompt):
    return recolor_dress(image, prompt)

interface = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Textbox(label="Prompt", placeholder="e.g. make the dress sky blue"),
    ],
    outputs=gr.Image(label="Edited Image"),
    title="Dress Recoloring Editor",
    description="Segments the dress with U2-Net and recolors it in LAB space, guided by a color keyword in the prompt.",
)

interface.launch(show_error=True)
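# When running locally rather than on Spaces, a temporary public URL can be
# requested with interface.launch(show_error=True, share=True).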