Spaces:
Sleeping
Sleeping
File size: 16,465 Bytes
475e066 adc77ba 475e066 adc77ba 475e066 adc77ba 475e066 adc77ba 9c5e178 adc77ba 475e066 adc77ba 475e066 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 |
import torch
from torch import autocast
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr
import traceback
import base64
from io import BytesIO
import os
# import sys
import PIL
import json
import requests
import logging
import time
import warnings
import numpy as np
from PIL import Image, ImageDraw
import cv2
warnings.filterwarnings("ignore")
# sys.path.insert(1, './parser')
# from parser.schp_masker import *
from parser.segformer_parser import SegformerParser
# --- Logging ---------------------------------------------------------------
# Root handler: INFO and above, each line stamped with time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('clothquill')

# --- Model identifiers -----------------------------------------------------
# Hugging Face Hub repositories for the clothing parser and the inpainting model.
SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"
STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"

# --- Process-wide singletons -----------------------------------------------
# Filled in by init(); None until it has run.
parser = model = inpainter = None
# Module-level slot for the uploaded image (the inpainter object also keeps its own copy).
original_image = None

# --- Overlay palette -------------------------------------------------------
# RGBA overlay colour per parser label. Every label is drawn half-transparent
# (alpha 128) except 'Background', which is fully transparent.
CLOTHING_COLORS = {
    'Background': (0, 0, 0, 0),
    **{
        label: rgb + (128,)
        for label, rgb in (
            ('Hat', (255, 0, 0)),              # red
            ('Hair', (0, 255, 0)),             # green
            ('Glove', (0, 0, 255)),            # blue
            ('Sunglasses', (255, 255, 0)),     # yellow
            ('Upper-clothes', (255, 0, 255)),  # magenta
            ('Dress', (0, 255, 255)),          # cyan
            ('Coat', (128, 0, 0)),             # dark red
            ('Socks', (0, 128, 0)),            # dark green
            ('Pants', (0, 0, 128)),            # dark blue
            ('Jumpsuits', (128, 128, 0)),      # dark yellow
            ('Scarf', (128, 0, 128)),          # dark magenta
            ('Skirt', (0, 128, 128)),          # dark cyan
            ('Face', (192, 192, 192)),         # light gray
            ('Left-arm', (64, 64, 64)),        # dark gray
            ('Right-arm', (64, 64, 64)),       # dark gray
            ('Left-leg', (32, 32, 32)),        # very dark gray
            ('Right-leg', (32, 32, 32)),       # very dark gray
            ('Left-shoe', (16, 16, 16)),       # almost black
            ('Right-shoe', (16, 16, 16)),      # almost black
        )
    },
}
def get_device():
    """Return the torch device string to run on: "cuda" if a GPU is visible, else "cpu".

    The choice is also logged at INFO level.
    """
    use_gpu = torch.cuda.is_available()
    logger.info("Using GPU" if use_gpu else "Using CPU")
    return "cuda" if use_gpu else "cpu"
def init():
    """Load the clothing parser and inpainting pipeline into the module globals.

    Sets the globals ``parser``, ``model`` and ``inpainter``. When no local
    ``models`` directory exists, ``download_models.download_models()`` runs first.

    Raises:
        Exception: any loading failure is logged together with its traceback
            and then re-raised unchanged.
    """
    global parser, model, inpainter
    start_time = time.time()
    logger.info("Starting application initialization")
    try:
        device = get_device()

        # First run: fetch model files when the local "models" directory is missing.
        if not os.path.exists("models"):
            logger.info("Creating models directory...")
            from download_models import download_models
            download_models()

        logger.info("Initializing Segformer parser...")
        parser = SegformerParser(SEGFORMER_MODEL)

        logger.info("Initializing Stable Diffusion model...")
        # Half precision on GPU only; float32 on CPU.
        # NOTE(review): newer diffusers releases select fp16 weights with
        # variant="fp16" rather than revision="fp16" - confirm for the pinned version.
        model = StableDiffusionInpaintPipeline.from_pretrained(
            STABLE_DIFFUSION_MODEL,
            safety_checker=None,
            revision="fp16" if device == "cuda" else None,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)

        logger.info("Initializing inpainter...")
        inpainter = ClothingInpainter(model=model, parser=parser)
        logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        # logger.exception records the full traceback; a bare raise re-raises as-is.
        logger.exception(f"Error initializing application: {str(e)}")
        raise
class ClothingInpainter:
    """Repaint selected clothing regions of a photo with Stable Diffusion inpainting.

    Wraps an inpainting pipeline (``self.pipe``) and a clothing parser
    (``self.parser``). Parsing, overlays and diffusion all operate on a
    512x512 square canvas. Results are cropped back to the input's aspect ratio.

    ``last_mask`` and ``original_image`` start as None and are written by the
    module's Gradio callbacks: the latest parse result and the un-annotated upload.
    """

    # Side length (pixels) of the square canvas used for parsing and diffusion.
    CANVAS_SIZE = 512
    # Parser labels repainted by inpaint() when the caller selects none.
    DEFAULT_PARTS = ('upper-clothes', 'dress', 'coat', 'pants', 'skirt')
    # 5x5 structuring element used to grow part masks before inpainting.
    _DILATE_KERNEL = np.ones((5, 5), np.uint8)

    def __init__(self, model_path=None, model=None, parser=None):
        """Build the inpainter from a checkpoint path or an existing pipeline.

        Args:
            model_path: checkpoint for ``StableDiffusionInpaintPipeline.from_pretrained``;
                takes precedence over ``model`` when both are given.
            model: an already constructed inpainting pipeline.
            parser: clothing parser exposing ``get_all_masks(image)``; required by
                :meth:`inpaint`.

        Raises:
            ValueError: if neither ``model_path`` nor ``model`` is provided.
        """
        self.device = get_device()
        self.last_mask = None  # latest parser masks; written by callers
        self.original_image = None  # un-annotated upload; written by callers
        if model_path is None and model is None:
            raise ValueError('No model provided!')
        if model_path is not None:
            # Half precision on GPU only; float32 on CPU.
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_path,
                safety_checker=None,
                revision="fp16" if self.device == "cuda" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            ).to(self.device)
        else:
            self.pipe = model
        self.parser = parser

    def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
        """Center ``im`` on a square canvas and return it as RGB.

        The canvas side is ``max(min_size, width, height)``. Padding uses
        ``fill_color``, which becomes black for the default after the RGB conversion.
        """
        x, y = im.size
        size = max(min_size, x, y)
        new_im = PIL.Image.new('RGBA', (size, size), fill_color)
        new_im.paste(im, (int((size - x) / 2), int((size - y) / 2)))
        return new_im.convert('RGB')

    def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
        """Undo :meth:`make_square` on an ``rs_size`` x ``rs_size`` result.

        Crops ``op_im`` to the region ``init_im`` occupied on the square canvas,
        scaled by ``rs_size / max(min_size, width, height)``.
        """
        x, y = init_im.size
        size = max(min_size, x, y)
        factor = rs_size / size
        return op_im.crop((
            int((size - x) * factor / 2),
            int((size - y) * factor / 2),
            int((size + x) * factor / 2),
            int((size + y) * factor / 2),
        ))

    @staticmethod
    def _lookup_color(part_name, palette):
        """Return ``palette``'s colour for ``part_name`` (case-insensitive), or None."""
        wanted = part_name.lower()
        for key, color in palette.items():
            if key.lower() == wanted:
                return color
        return None

    @staticmethod
    def _paint_masks(layers, size=512):
        """Rasterise ``(mask, rgba)`` layers onto a transparent ``size`` x ``size`` RGBA array.

        A pixel is painted where the mask is > 0 (for multi-channel masks, in any
        channel). Later layers overwrite earlier ones. Mask pixels beyond the
        canvas are dropped.
        """
        canvas = np.zeros((size, size, 4), dtype=np.uint8)
        for mask, color in layers:
            hit = np.asarray(mask) > 0
            if hit.ndim > 2:
                hit = hit.any(axis=tuple(range(2, hit.ndim)))
            rows = min(hit.shape[0], size)
            cols = min(hit.shape[1], size)
            # Basic slicing yields a view, so the masked assignment writes into canvas.
            canvas[:rows, :cols][hit[:rows, :cols]] = color
        return canvas

    def visualize_segmentation(self, image, masks, selected_parts=None):
        """Return an RGBA copy of the base image with translucent part overlays.

        The base image is ``self.original_image`` when set, otherwise ``image``.
        Masks named in ``selected_parts`` are drawn in their ``CLOTHING_COLORS``
        colour, or magenta if the label has no entry. All other masks are drawn
        faint gray. Masks are drawn on a 512x512 canvas (presumably the parser's
        resolution; pixels beyond it are dropped). The overlay is then resized
        to the base image's size.
        """
        base = self.original_image if self.original_image is not None else image
        layers = []
        for part_name, mask in masks.items():
            if selected_parts and part_name in selected_parts:
                # Case-insensitive match: callers use lowercase labels such as
                # 'upper-clothes', while palette keys are capitalised ('Upper-clothes').
                color = self._lookup_color(part_name, CLOTHING_COLORS)
                if color is None:
                    color = (255, 0, 255, 128)  # magenta for labels without a palette entry
            else:
                color = (180, 180, 180, 80)  # faint gray for unselected parts
            layers.append((mask, color))
        # One vectorised pass instead of a Python-level draw call per mask pixel.
        overlay = Image.fromarray(self._paint_masks(layers, self.CANVAS_SIZE))
        overlay = overlay.resize(base.size, Image.Resampling.LANCZOS)
        return Image.alpha_composite(base.copy().convert('RGBA'), overlay)

    def _union_dilated(self, masks, parts, dilation_iterations):
        """Dilate each mask named in ``parts`` and union them into one 'L' mask (255 = repaint)."""
        side = self.CANVAS_SIZE
        white = Image.new('L', (side, side), 255)
        combined = Image.new('L', (side, side), 0)
        for part in parts:
            if part not in masks:
                continue
            grown = cv2.dilate(np.array(masks[part]), self._DILATE_KERNEL,
                               iterations=dilation_iterations)
            # Paint white where the dilated mask is set, blending by its value.
            combined = Image.composite(white, combined, Image.fromarray(grown))
        return combined

    def inpaint(self, prompt, init_image, selected_parts=None, dilation_iterations=2,
                num_samples=3, guidance_scale=7.5, num_inference_steps=50) -> list:
        """Generate variants of ``init_image`` with the chosen clothing parts repainted.

        Args:
            prompt: dict whose ``'pos'`` entry is the text prompt.
            init_image: source PIL image of any size.
            selected_parts: parser labels to repaint. When empty or None, every
                ``DEFAULT_PARTS`` label found in the image is used.
            dilation_iterations: number of 5x5 dilation passes applied to each mask.
            num_samples: images generated per prompt.
            guidance_scale: classifier-free guidance strength passed to the pipeline.
            num_inference_steps: denoising steps passed to the pipeline.

        Returns:
            List of ``num_samples`` PIL images, cropped back to ``init_image``'s
            aspect ratio. Pixels outside the mask are copied from the source.

        Raises:
            ValueError: if the inpainter has no parser.
        """
        if self.parser is None:
            raise ValueError('Image Parser is Missing')
        side = self.CANVAS_SIZE
        image = self.make_square(init_image).resize((side, side))
        masks = {k: v.resize((side, side)) for k, v in self.parser.get_all_masks(image).items()}
        logger.info(f'[generated required mask(s) at {time.time()}]')

        if selected_parts:
            parts = selected_parts
        else:
            parts = [p for p in masks if p in self.DEFAULT_PARTS]
        mask = self._union_dilated(masks, parts, dilation_iterations)
        if mask.getbbox() is None:
            logger.warning('Inpainting mask is empty; none of %s found in the image', parts)

        # Mixed precision only on GPU; autocast is explicitly disabled on CPU.
        with autocast(self.device, enabled=self.device == "cuda"), torch.inference_mode():
            generated = self.pipe(
                num_inference_steps=num_inference_steps,
                prompt=prompt['pos'],
                image=image,
                mask_image=mask,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_samples,
            ).images

        # Keep source pixels outside the mask; only masked pixels come from the model.
        return [
            self.unmake_square(init_image, PIL.Image.composite(img, image, mask))
            for img in generated
        ]
def process_segmentation(image, dilation_iterations=2):
    """Parse a freshly uploaded image and return a preview of its clothing parts.

    Stores the upload on ``inpainter.original_image`` and the parser's masks on
    ``inpainter.last_mask`` for the other callbacks. Returns two things:
    - a preview with every detected main clothing part shaded gray, dilated by
      ``dilation_iterations`` like the slider-driven redraws;
    - a checkbox update listing those parts, with nothing selected.

    Raises:
        gr.Error: when no image is given, nothing is detected, or parsing fails.
    """
    try:
        if image is None:
            raise gr.Error("Please upload an image")
        # Drop masks from a previous upload so a failed parse of this image
        # cannot leave stale masks paired with the new original.
        inpainter.last_mask = None
        inpainter.original_image = image.copy()

        # The parser works on a 512x512 copy; the original keeps its size.
        proc_image = image.resize((512, 512), Image.Resampling.LANCZOS)
        all_masks = inpainter.parser.get_all_masks(proc_image)
        if not all_masks:
            logger.error("No clothing detected in the image")
            raise gr.Error("No clothing detected in the image. Please try a different image.")
        inpainter.last_mask = all_masks

        # Only the main garments are offered for selection. Dilate them like the
        # slider callbacks do, so the first preview matches the slider value.
        main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        kernel = np.ones((5, 5), np.uint8)
        masks = {
            k: Image.fromarray(cv2.dilate(np.array(v), kernel, iterations=dilation_iterations))
            for k, v in all_masks.items()
            if k in main_parts
        }
        vis_image = inpainter.visualize_segmentation(image, masks, selected_parts=None)
        return vis_image, gr.update(choices=list(masks), value=[])
    except gr.Error:
        # User-facing messages pass through unchanged.
        raise
    except Exception as e:
        logger.exception(f"Error processing segmentation: {str(e)}")
        raise gr.Error("Error processing the image. Please try a different image.") from e
def update_dilation(image, selected_parts, dilation_iterations):
    """Redraw the segmentation preview after the dilation slider moves.

    Dilates the stored masks of the main clothing parts and renders them over
    ``inpainter.original_image``. ``selected_parts`` are drawn in colour, the
    rest in gray. This is best-effort: ``image`` is returned unchanged when
    there is no image or parse result yet, or when rendering fails. Failures
    are logged with their traceback.
    """
    try:
        if image is None or inpainter.last_mask is None:
            return image
        main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        kernel = np.ones((5, 5), np.uint8)
        masks = {}
        for part in main_parts:
            if part in inpainter.last_mask:
                grown = cv2.dilate(np.array(inpainter.last_mask[part]), kernel,
                                   iterations=dilation_iterations)
                masks[part] = Image.fromarray(grown)
        # Compare labels in lowercase, as update_selected_parts and process_image do.
        selected_parts = [p.lower() for p in selected_parts] if selected_parts else []
        return inpainter.visualize_segmentation(inpainter.original_image, masks,
                                                selected_parts=selected_parts)
    except Exception as e:
        logger.exception(f"Error updating dilation: {str(e)}")
        return image
def process_image(prompt, image, selected_parts, dilation_iterations):
    """Handle the Generate button: inpaint the selected parts and return the images.

    After an upload, the ``image`` input holds the annotated preview, because the
    segmentation callbacks write their overlay back into the input component.
    The clean upload stored on ``inpainter.original_image`` is used instead when
    it matches ``image``'s size. This keeps overlay colours out of the parser,
    the model input and the untouched output pixels.

    Raises:
        gr.Error: for missing inputs or an empty result (message passed through
            unchanged), or wrapping any other failure during inpainting.
    """
    start_time = time.time()
    logger.info(f"Processing new request - Prompt: {prompt}, Image size: {image.size if image is not None else 'None'}")
    try:
        if image is None:
            logger.error("No image provided")
            raise gr.Error("Please upload an image")
        if not prompt:
            logger.error("No prompt provided")
            raise gr.Error("Please enter a prompt")
        if not selected_parts:
            logger.error("No parts selected")
            raise gr.Error("Please select at least one clothing part to modify")

        original = inpainter.original_image
        # The size check guards against a stale original from an earlier upload.
        source = original if original is not None and original.size == image.size else image

        prompt_dict = {'pos': prompt}
        logger.info("Starting inpainting process")
        # Checkbox labels are matched against the parser's lowercase/dash labels.
        selected_parts = [p.lower() for p in selected_parts]
        images = inpainter.inpaint(prompt_dict, source, selected_parts, dilation_iterations)
        if not images:
            logger.error("Inpainting failed to produce results")
            raise gr.Error("Failed to generate images. Please try again.")
        logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
        return images
    except gr.Error:
        # Validation messages reach the user as written, not re-wrapped.
        raise
    except Exception as e:
        logger.exception(f"Error processing image: {str(e)}")
        raise gr.Error(f"Error processing image: {str(e)}") from e
def update_selected_parts(image, selected_parts, dilation_iterations):
    """Redraw the preview so the ticked parts are coloured and the others stay gray.

    Masks come from ``inpainter.last_mask`` and are dilated by
    ``dilation_iterations``. They are drawn over ``inpainter.original_image``.
    ``image`` is returned untouched when there is no image or parse result yet,
    or when drawing fails (the failure is logged).
    """
    try:
        if image is None or inpainter.last_mask is None:
            return image
        stored = inpainter.last_mask
        kernel = np.ones((5, 5), np.uint8)
        dilated = {
            name: Image.fromarray(cv2.dilate(np.array(stored[name]), kernel, iterations=dilation_iterations))
            for name in ('upper-clothes', 'dress', 'coat', 'pants', 'skirt')
            if name in stored
        }
        # Labels are compared in lowercase.
        wanted = [name.lower() for name in selected_parts or []]
        return inpainter.visualize_segmentation(inpainter.original_image, dilated, selected_parts=wanted)
    except Exception as e:
        logger.error(f"Error updating selected parts: {str(e)}")
        return image
# Load the parser and inpainting pipeline into the module globals at import time.
# init() re-raises on failure, so the UI below is never built without models.
init()
# Gradio UI. The segmentation callbacks write their annotated preview back into
# `input_image`; the clean upload is kept on `inpainter.original_image`.
with gr.Blocks(title="ClothQuill - AI Clothing Inpainting") as demo:
    gr.Markdown("# ClothQuill - AI Clothing Inpainting")
    gr.Markdown("Upload an image to see segmented clothing parts, then select parts to modify and describe your changes")
    with gr.Row():
        # Left column: upload, mask controls and prompt.
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                scale=1, # Relative width vs. sibling components (Gradio `scale`); does not set aspect ratio
                height=None # Allow dynamic height based on content
            )
            # Number of 5x5 dilation passes applied to each part mask.
            dilation_slider = gr.Slider(
                minimum=0,
                maximum=5,
                value=2,
                step=1,
                label="Mask Dilation",
                info="Adjust the mask dilation to control the area of modification"
            )
            # Starts empty; process_segmentation fills in the detected parts.
            selected_parts = gr.CheckboxGroup(
                choices=[],
                label="Select parts to modify",
                value=[]
            )
            prompt = gr.Textbox(
                label="Describe the clothing you want to generate",
                placeholder="e.g., A stylish black leather jacket"
            )
            generate_btn = gr.Button("Generate")
        # Right column: generated results.
        with gr.Column():
            gallery = gr.Gallery(
                label="Generated Results",
                show_label=False,
                columns=2,
                height=None, # Allow dynamic height
                object_fit="contain" # Maintain aspect ratio
            )
    # Upload: parse the image, replace the displayed image with the preview,
    # and populate the part checkboxes.
    input_image.upload(
        fn=process_segmentation,
        inputs=[input_image, dilation_slider],
        outputs=[input_image, selected_parts]
    )
    # Slider moved: redraw the preview with the new dilation.
    dilation_slider.change(
        fn=update_dilation,
        inputs=[input_image, selected_parts,dilation_slider],
        outputs=input_image
    )
    # Generate: run inpainting and show the variants in the gallery.
    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image, selected_parts, dilation_slider],
        outputs=gallery
    )
    # Checkbox selection changed: recolour the preview.
    selected_parts.change(
        fn=update_selected_parts,
        inputs=[input_image, selected_parts, dilation_slider],
        outputs=input_image
    )
if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True)
|