Loomi-Clothing-Detection-API / clothing_detector.py
kabancov_et
Optimize color analysis performance: add caching, smart image resizing, and faster KMeans
eea39e9
raw
history blame
56.1 kB
import hashlib
import time
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import torch.nn as nn
from io import BytesIO
import numpy as np
from collections import Counter
import logging
import base64
import warnings
import traceback
# Quiet known-noisy transformers warnings so the service logs stay readable.
for _noisy_pattern in (
    ".*feature_extractor_type.*",
    ".*reduce_labels.*",
    ".*TRANSFORMERS_CACHE.*",
):
    warnings.filterwarnings("ignore", message=_noisy_pattern)
del _noisy_pattern

# Module-wide logging configuration and logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide segmentation result cache, keyed by image hash.
_segmentation_cache = dict()
# Hit / miss counters for the segmentation cache.
_cache_hits = _cache_misses = 0
# Per-hash access bookkeeping (tracks when items were last accessed).
_cache_access_times = dict()
# New cache for pred_seg data (for analyze endpoint)
class SegmentationDataCache:
    """In-memory cache of segmentation data keyed by image hash.

    Each entry is stored together with the time it was written and is
    considered expired once ``ttl_hours`` have elapsed since then. When a
    *new* key is stored while the cache is full, expired entries are purged
    first and, if that is not enough, the least recently accessed entry is
    evicted.
    """

    def __init__(self, max_size=50, ttl_hours=2):
        """Create an empty cache.

        Args:
            max_size: Maximum number of entries kept at once.
            ttl_hours: Lifetime of an entry, in hours, counted from ``set``.
        """
        self.cache = {}  # image_hash -> (data, stored_at)
        self.max_size = max_size
        self.ttl_seconds = ttl_hours * 3600
        self.access_times = {}  # image_hash -> timestamp of last get/set

    def _discard(self, image_hash):
        """Remove an entry and its access record, tolerating missing keys."""
        self.cache.pop(image_hash, None)
        self.access_times.pop(image_hash, None)

    def get(self, image_hash):
        """Get cached segmentation data if not expired.

        Returns:
            The stored data, or ``None`` when the key is unknown or expired
            (expired entries are removed as a side effect).
        """
        entry = self.cache.get(image_hash)
        if entry is None:
            return None
        data, timestamp = entry
        now = time.time()
        if now - timestamp < self.ttl_seconds:
            self.access_times[image_hash] = now
            return data
        # Expired, remove it
        self._discard(image_hash)
        return None

    def set(self, image_hash, data):
        """Store segmentation data with timestamp.

        Overwriting an existing key never evicts another entry; room is only
        made when a new key would push the cache past ``max_size``.
        """
        current_time = time.time()
        if image_hash not in self.cache and len(self.cache) >= self.max_size:
            # Prefer dropping dead entries over evicting live ones.
            expired = [
                key
                for key, (_, stored_at) in self.cache.items()
                if current_time - stored_at >= self.ttl_seconds
            ]
            for key in expired:
                self._discard(key)
            # Evict least recently accessed entries until there is room.
            # The `self.cache` guard avoids min() on an empty cache when
            # max_size <= 0.
            while self.cache and len(self.cache) >= self.max_size:
                oldest_hash = min(
                    self.cache,
                    key=lambda k: self.access_times.get(k, 0.0),
                )
                self._discard(oldest_hash)
        self.cache[image_hash] = (data, current_time)
        self.access_times[image_hash] = current_time
        logger.info(f"Stored segmentation data in cache for hash: {image_hash[:8]}... (cache size: {len(self.cache)})")

    def get_stats(self):
        """Get cache statistics."""
        return {
            "size": len(self.cache),
            "max_size": self.max_size,
            "ttl_hours": self.ttl_seconds / 3600
        }
# Global instance
# Shared SegmentationDataCache created with the class defaults; used by the
# detector's optimized detect / analyze methods to pass pred_seg between calls.
_segmentation_data_cache = SegmentationDataCache()
# Cache configuration
# Entry limit for `_segmentation_cache`; `_cleanup_cache` evicts down to it.
MAX_CACHE_SIZE = 15 # Increased from 10
# NOTE(review): not referenced in the visible part of this module - confirm
# whether it is still needed.
CACHE_PRIORITY_THRESHOLD = 5 # Minimum access count to keep in cache
def _cleanup_cache():
    """Evict the lowest-priority entries until the cache fits MAX_CACHE_SIZE.

    Priority is ``access_count / (1 + hours_since_last_access)``. Cached
    entries that have no access record get priority 0 so they are evicted
    first instead of being impossible to evict. Access records whose cache
    entry no longer exists are dropped as well.
    """
    items_to_remove = len(_segmentation_cache) - MAX_CACHE_SIZE
    if items_to_remove <= 0:
        return

    current_time = time.time()

    def _priority(key):
        access_info = _cache_access_times.get(key)
        if access_info is None:
            return 0.0
        # Clamp at zero so a backwards clock step cannot flip the sign.
        recency = max(0.0, current_time - access_info['last_access'])
        return access_info['access_count'] / (1 + recency / 3600)  # Normalize by hour

    # Remove lowest priority items
    for key in sorted(_segmentation_cache, key=_priority)[:items_to_remove]:
        del _segmentation_cache[key]
        _cache_access_times.pop(key, None)
        logger.info(f"Removed low-priority cache item: {key}")

    # Drop bookkeeping for entries that are no longer cached.
    orphaned = [key for key in _cache_access_times if key not in _segmentation_cache]
    for key in orphaned:
        del _cache_access_times[key]

    logger.info(f"Cache cleaned up: {len(_segmentation_cache)} items remaining")
def _update_cache_access(image_hash: str):
    """Record one more access to ``image_hash`` and stamp it with the current time."""
    now = time.time()
    # First access starts from a zero count so the increment below yields 1.
    record = _cache_access_times.setdefault(
        image_hash,
        {'access_count': 0, 'last_access': now},
    )
    record['access_count'] += 1
    record['last_access'] = now
# Temporarily clear cache for testing improved quality
# NOTE(review): this runs once at import. `_segmentation_cache` is created
# empty earlier in this module and nothing fills it before this point, so the
# clear() and counter resets do nothing on first import; only the log line is
# observable. Looks like leftover test scaffolding - confirm before removing.
_segmentation_cache.clear()
_cache_hits = 0
_cache_misses = 0
logger.info("Cache cleared for testing improved quality")
class ClothingDetector:
def __init__(self):
    """Load the clothes-segmentation processor and model on CPU and build the label tables."""
    checkpoint = "mattmdjaga/segformer_b2_clothes"

    # Inference is pinned to the CPU (free tier).
    self.device = torch.device("cpu")
    logger.info(f"Using device: {self.device} (free tier optimization)")

    logger.info("Loading SegformerImageProcessor...")
    self.processor = SegformerImageProcessor.from_pretrained(checkpoint)

    logger.info("Loading AutoModelForSemanticSegmentation...")
    self.model = AutoModelForSemanticSegmentation.from_pretrained(
        checkpoint,
        torch_dtype=torch.float32,  # FP32 for CPU stability
        low_cpu_mem_usage=True,  # reduce memory usage while loading
    )

    logger.info(f"Moving model to {self.device}...")
    self.model.to(self.device)
    self.model.eval()

    # Cap the CPU thread count for stability.
    torch.set_num_threads(4)
    logger.info("Clothing detector initialized successfully (CPU optimized)")

    # Class id -> label name; the position in this tuple is the class id.
    label_names = (
        "Background",
        "Hat",
        "Hair",
        "Sunglasses",
        "Upper-clothes",
        "Skirt",
        "Pants",
        "Dress",
        "Belt",
        "Left-shoe",
        "Right-shoe",
        "Face",
        "Left-leg",
        "Right-leg",
        "Left-arm",
        "Right-arm",
        "Bag",
        "Scarf",
    )
    self.labels = dict(enumerate(label_names))

    # Clothing class ids (body parts and background excluded):
    # Upper-clothes, Skirt, Pants, Dress, Belt, Left-shoe, Right-shoe, Bag, Scarf
    self.clothing_classes = [4, 5, 6, 7, 8, 9, 10, 16, 17]
def _get_image_hash(self, image_bytes: bytes) -> str:
"""Create image hash to use as cache key."""
return hashlib.md5(image_bytes).hexdigest()
def _segment_image(self, image_bytes: bytes):
    """Run image segmentation with caching.

    Args:
        image_bytes: Raw encoded image bytes; hashed with ``_get_image_hash``
            to form the cache key.

    Returns:
        dict: ``{'pred_seg': ..., 'image': ...}`` where ``pred_seg`` is the
        NumPy array of per-pixel class ids obtained by arg-maxing the logits
        after bilinear upsampling to the image size, and ``image`` is the
        decoded RGB PIL image. On a cache hit the cached dict is returned.

    Raises:
        Exception: Any decoding or inference error is logged and re-raised.
    """
    global _cache_hits, _cache_misses

    image_hash = self._get_image_hash(image_bytes)

    # Check cache (re-enabled now that quality is improved)
    if image_hash in _segmentation_cache:
        _cache_hits += 1
        _update_cache_access(image_hash)  # Update access statistics
        logger.info("⏱️ Using cached high-quality segmentation result")
        return _segmentation_cache[image_hash]

    _cache_misses += 1

    # Run segmentation
    logger.info("Performing new high-quality segmentation")
    seg_start = time.time()
    try:
        # Imported as `F` so it does not shadow the module-level `nn`
        # (torch.nn) inside this function.
        import torch.nn.functional as F

        # Load and preprocess image
        preprocess_start = time.time()
        image = Image.open(BytesIO(image_bytes))
        image = image.convert('RGB')
        preprocess_time = time.time() - preprocess_start
        logger.info(f"⏱️ Image preprocessing completed in {preprocess_time:.2f}s")

        # Prepare inputs for the model
        inputs_start = time.time()
        inputs = self.processor(images=image, return_tensors="pt")
        inputs_time = time.time() - inputs_start
        logger.info(f"⏱️ Input preparation completed in {inputs_time:.2f}s")

        # Move inputs to device
        device_start = time.time()
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        device_time = time.time() - device_start
        logger.info(f"⏱️ Device transfer completed in {device_time:.2f}s")

        # Run inference
        inference_start = time.time()
        with torch.no_grad():
            outputs = self.model(**inputs)
        inference_time = time.time() - inference_start
        logger.info(f"⏱️ Model inference completed in {inference_time:.2f}s")

        # Get predictions
        postprocess_start = time.time()
        raw_logits = outputs.logits
        logits = raw_logits
        # Ensure logits have correct shape (N, C, H, W)
        if logits.dim() == 3:
            logits = logits.unsqueeze(0)  # Add batch dimension if missing

        # PIL reports size as (width, height)
        img_width, img_height = image.size
        logger.info(f"Logits shape: {logits.shape}, Target size: ({img_height}, {img_width})")

        if img_height <= 0 or img_width <= 0:
            logger.warning(f"Invalid image dimensions: {img_height}x{img_width}, using original segmentation")
            # The low-resolution argmax is only needed on this fallback path,
            # so it is computed here rather than on every cache miss.
            pred_seg_high_quality = torch.argmax(raw_logits, dim=1).squeeze().cpu().numpy()
        else:
            # Upsample logits to original image size for better quality
            logits_upsampled = F.interpolate(
                logits,
                size=(img_height, img_width),  # (height, width) order
                mode="bilinear",
                align_corners=False,
            )
            pred_seg_high_quality = logits_upsampled.argmax(dim=1)[0].cpu().numpy()
            logger.info(f"Created high-quality segmentation: {pred_seg_high_quality.shape} for image size {image.size}")
        postprocess_time = time.time() - postprocess_start
        logger.info(f"⏱️ Postprocessing completed in {postprocess_time:.2f}s")

        # Store result in cache
        cache_start = time.time()
        _segmentation_cache[image_hash] = {
            'pred_seg': pred_seg_high_quality,
            'image': image
        }
        # Update cache access and cleanup if needed
        _update_cache_access(image_hash)
        _cleanup_cache()
        cache_time = time.time() - cache_start
        logger.info(f"⏱️ Cache operations completed in {cache_time:.2f}s")

        # Total segmentation time
        total_seg_time = time.time() - seg_start
        logger.info(f"⏱️ TOTAL segmentation completed in {total_seg_time:.2f}s (preprocess: {preprocess_time:.2f}s, inputs: {inputs_time:.2f}s, device: {device_time:.2f}s, inference: {inference_time:.2f}s, postprocess: {postprocess_time:.2f}s, cache: {cache_time:.2f}s)")
        return {
            'pred_seg': pred_seg_high_quality,
            'image': image
        }
    except Exception as e:
        total_seg_time = time.time() - seg_start
        logger.error(f"❌ Error in segmentation after {total_seg_time:.2f}s: {e}")
        raise
def detect_clothing(self, image_bytes: bytes) -> dict:
    """
    Detect clothing types on image and return coordinates.

    Every label except background, face, hair, arms and legs is reported
    when at least one pixel carries it. Bounding boxes are padded by 10% of
    the item's extent and clamped to the image; ``x``/``y`` are the top-left
    corner and ``width``/``height`` include the last masked row and column,
    so a one-pixel-wide item has width 1.

    Args:
        image_bytes: Raw image bytes
    Returns:
        dict: ``message``, ``total_detected``, ``clothing_instances``
        (sorted by area, largest first) and ``image_info``. On failure the
        same shape is returned with zeroed values plus an ``error`` string;
        no exception is raised.
    """
    non_clothing = {
        "Background", "Face", "Hair",
        "Left-arm", "Right-arm", "Left-leg", "Right-leg",
    }
    try:
        # Get cached segmentation result
        seg_result = self._segment_image(image_bytes)
        pred_seg = seg_result['pred_seg']
        image = seg_result['image']
        total_pixels = pred_seg.size

        clothing_instances = []
        for class_id, label_name in self.labels.items():
            if label_name in non_clothing:
                continue
            mask = (pred_seg == class_id)
            count = int(np.count_nonzero(mask))
            if count == 0:
                continue

            # Indices of rows / columns containing at least one pixel.
            rows = np.flatnonzero(mask.any(axis=1))
            cols = np.flatnonzero(mask.any(axis=0))
            y_first, y_last = int(rows[0]), int(rows[-1])
            x_first, x_last = int(cols[0]), int(cols[-1])

            # Inclusive extent: first and last index both belong to the item.
            clothing_width = x_last - x_first + 1
            clothing_height = y_last - y_first + 1
            padding_x = int(clothing_width * 0.1)
            padding_y = int(clothing_height * 0.1)

            # Upper bounds are exclusive (hence +1), clamped to the image.
            x_min = max(0, x_first - padding_x)
            y_min = max(0, y_first - padding_y)
            x_max = min(image.width, x_last + 1 + padding_x)
            y_max = min(image.height, y_last + 1 + padding_y)

            clothing_instances.append({
                "type": label_name,
                "class_id": class_id,
                "bbox": {
                    "x": x_min,
                    "y": y_min,
                    "width": x_max - x_min,
                    "height": y_max - y_min
                },
                "area_pixels": count,
                "area_percentage": round(count / total_pixels * 100, 2)
            })

        # Largest area first; the sort is stable, so ties keep label order.
        clothing_instances.sort(key=lambda item: item["area_percentage"], reverse=True)

        return {
            "message": f"Clothing detection completed. Found {len(clothing_instances)} items",
            "total_detected": len(clothing_instances),
            "clothing_instances": clothing_instances,
            "image_info": {
                "width": image.width,
                "height": image.height,
                "total_pixels": image.width * image.height
            }
        }
    except Exception as e:
        # Deliberate best-effort: callers receive an error payload, not an exception.
        logger.error(f"Error in clothing detection: {str(e)}")
        return {
            "message": "Error in clothing detection",
            "total_detected": 0,
            "clothing_instances": [],
            "image_info": {
                "width": 0,
                "height": 0,
                "total_pixels": 0
            },
            "error": str(e)
        }
def create_clothing_only_image(self, image_bytes: bytes, selected_clothing: str = None) -> str:
    """
    Create clothing-only image with transparent background.

    The label is matched case-insensitively, as in the other methods that
    accept a clothing name. An unknown label falls back to all clothing
    classes. When a known label is given the result is cropped around it.

    Args:
        image_bytes: Raw image bytes
        selected_clothing: Optional clothing label to isolate
    Returns:
        str: Base64-encoded PNG data URL, or "" if anything fails (the
        error is logged, not raised).
    """
    try:
        # Get cached segmentation
        seg_result = self._segment_image(image_bytes)
        pred_seg = seg_result['pred_seg']
        image = seg_result['image']

        # Resolve the requested label to a class id, if possible.
        selected_class_id = None
        if selected_clothing:
            wanted = selected_clothing.lower()
            selected_class_id = next(
                (cid for cid, label in self.labels.items() if label.lower() == wanted),
                None,
            )

        if selected_class_id is not None:
            clothing_mask = (pred_seg == selected_class_id)
        else:
            # No selection, or label not found: use all clothing classes.
            clothing_mask = np.isin(pred_seg, self.clothing_classes)

        # Compose RGBA: original colours, opaque on clothing, transparent elsewhere.
        image_array = np.array(image)
        clothing_only_rgba = np.zeros((image_array.shape[0], image_array.shape[1], 4), dtype=np.uint8)
        clothing_only_rgba[..., :3] = image_array
        clothing_only_rgba[..., 3] = 255
        clothing_only_rgba[~clothing_mask, 3] = 0

        clothing_image = Image.fromarray(clothing_only_rgba, 'RGBA')

        # Only a specific, recognised item is cropped.
        if selected_class_id is not None:
            clothing_image = self._crop_with_padding(clothing_image, clothing_mask)

        # Encode to base64
        buffer = BytesIO()
        clothing_image.save(buffer, format='PNG')
        img_str = base64.b64encode(buffer.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"
    except Exception as e:
        logger.error(f"Error in creating clothing-only image: {str(e)}")
        return ""
def _crop_with_padding(self, image: Image.Image, mask: np.ndarray, padding_percent: float = 0.1) -> Image.Image:
"""
Crop image around clothing mask with padding.
Args:
image: PIL image
mask: Clothing mask
padding_percent: Padding percentage relative to clothing size
Returns:
Image.Image: Cropped image
"""
try:
# Find clothing bounds
rows = np.any(mask, axis=1)
cols = np.any(mask, axis=0)
if not np.any(rows) or not np.any(cols):
return image # If no clothing found, return original
# Get bounds
y_min, y_max = np.where(rows)[0][[0, -1]]
x_min, x_max = np.where(cols)[0][[0, -1]]
# Compute clothing size
clothing_width = x_max - x_min
clothing_height = y_max - y_min
# Compute padding
padding_x = int(clothing_width * padding_percent)
padding_y = int(clothing_height * padding_percent)
# Apply padding within image bounds
x_min = max(0, x_min - padding_x)
y_min = max(0, y_min - padding_y)
x_max = min(image.width, x_max + padding_x)
y_max = min(image.height, y_max + padding_y)
# Crop
cropped_image = image.crop((x_min, y_min, x_max, y_max))
return cropped_image
except Exception as e:
logger.error(f"Error in cropping with padding: {str(e)}")
return image
def detect_clothing_with_segmentation(self, image_bytes: bytes) -> dict:
    """
    Run detection and bundle it with reusable segmentation data.

    The result merges the ``detect_clothing`` payload with the full
    ``pred_seg`` array (as nested lists), the original image as a PNG data
    URL, and one highlight image per detected type plus an ``'all'`` entry.
    If a highlight cannot be built, the raw base64 of the original image is
    stored in its place. Other errors are logged and re-raised.
    """
    try:
        seg_result = self._segment_image(image_bytes)
        pred_seg, image = seg_result['pred_seg'], seg_result['image']
        clothing_result = self.detect_clothing(image_bytes)

        # Encode the original image once; reused for fallbacks and the response.
        png_buffer = BytesIO()
        image.save(png_buffer, format='PNG')
        original_image_base64 = base64.b64encode(png_buffer.getvalue()).decode()

        highlighted_images = {}

        # Highlight covering every clothing type (None selects all).
        try:
            highlighted_images['all'] = self.create_clothing_highlight_image(image_bytes, None)
            logger.info("Created highlight for all clothing types")
        except Exception as e:
            logger.warning(f"Could not create highlight for all clothing types: {e}")
            highlighted_images['all'] = original_image_base64

        # One highlight per detected type, in detection order.
        for instance in clothing_result.get('clothing_instances', []):
            type_name = instance.get('type', '')
            if not type_name:
                continue
            try:
                highlighted_images[type_name] = self.create_clothing_highlight_image(image_bytes, type_name)
                logger.info(f"Created highlight for {type_name}")
            except Exception as e:
                logger.warning(f"Could not create highlight for {type_name}: {e}")
                highlighted_images[type_name] = original_image_base64

        # Everything below is converted to JSON-serializable types.
        return {
            **clothing_result,
            "segmentation_data": {
                "pred_seg": pred_seg.tolist(),
                "image_size": list(image.size),
                "image_hash": self._get_image_hash(image_bytes),
                "original_image": f"data:image/png;base64,{original_image_base64}"
            },
            "highlighted_images": highlighted_images,
            "original_image": f"data:image/png;base64,{original_image_base64}"
        }
    except Exception as e:
        logger.error(f"Error in clothing detection with segmentation: {e}")
        raise
def detect_clothing_with_segmentation_optimized(self, image_bytes: bytes) -> dict:
    """
    Fast detection path: return compressed masks instead of highlight images.

    Runs segmentation and detection, encodes one mask per detected type plus
    a combined ``'all'`` mask, and stores ``pred_seg``, the image size and
    the original bytes in the server-side segmentation data cache under the
    image hash. The response carries the hash, not ``pred_seg``. Errors are
    logged with the elapsed time and re-raised.
    """
    start_time = time.time()
    try:
        # Step 1: Segmentation
        step_started = time.time()
        seg_result = self._segment_image(image_bytes)
        pred_seg, image = seg_result['pred_seg'], seg_result['image']
        seg_time = time.time() - step_started
        logger.info(f"⏱️ Segmentation completed in {seg_time:.2f}s")

        # Step 2: Clothing detection
        step_started = time.time()
        clothing_result = self.detect_clothing(image_bytes)
        detect_time = time.time() - step_started
        logger.info(f"⏱️ Clothing detection completed in {detect_time:.2f}s")

        # Step 3: Create masks
        step_started = time.time()
        detected = clothing_result.get('clothing_instances', [])
        logger.info(f"Creating masks for {len(detected)} clothing types...")
        masks = {}
        for instance in detected:
            type_name = instance.get('type', '')
            if not type_name:
                continue
            type_mask = self._get_clothing_mask(pred_seg, type_name)
            if type_mask is None:
                continue
            masks[type_name] = self._mask_to_base64(type_mask)

        logger.info("Creating combined mask for all clothing...")
        masks['all'] = self._mask_to_base64(self._get_all_clothing_mask(pred_seg))
        logger.info("All masks created successfully")
        masks_time = time.time() - step_started
        logger.info(f"⏱️ Masks creation completed in {masks_time:.2f}s")

        # Step 4: Cache storage (original bytes kept for background removal)
        step_started = time.time()
        image_hash = self._get_image_hash(image_bytes)
        cache_payload = {
            "pred_seg": pred_seg,
            "image_size": list(image.size),
            "original_image_bytes": image_bytes
        }
        _segmentation_data_cache.set(image_hash, cache_payload)
        cache_time = time.time() - step_started
        logger.info(f"⏱️ Cache storage completed in {cache_time:.2f}s")

        # Total time
        total_time = time.time() - start_time
        logger.info(f"🚀 TOTAL /detect completed in {total_time:.2f}s (seg: {seg_time:.2f}s, detect: {detect_time:.2f}s, masks: {masks_time:.2f}s, cache: {cache_time:.2f}s)")

        # pred_seg is deliberately not returned; it lives in the server cache.
        segmentation_data = {
            "masks": masks,
            "image_size": list(image.size),
            "image_hash": image_hash
        }
        return {**clothing_result, "segmentation_data": segmentation_data}
    except Exception as e:
        total_time = time.time() - start_time
        logger.error(f"❌ Error in optimized clothing detection after {total_time:.2f}s: {e}")
        raise
def _get_clothing_mask(self, pred_seg: np.ndarray, clothing_type: str) -> np.ndarray:
"""Get binary mask for specific clothing type."""
try:
# Map clothing type to class ID
class_mapping = {
'Hat': 0, 'Hair': 1, 'Glove': 2, 'Sunglasses': 3, 'Upper-clothes': 4,
'Skirt': 5, 'Pants': 6, 'Dress': 7, 'Belt': 8, 'Left-shoe': 9, 'Right-shoe': 10,
'Left-sock': 11, 'Right-sock': 12, 'Left-bag': 13, 'Right-bag': 14, 'Scarf': 15
}
class_id = class_mapping.get(clothing_type)
if class_id is not None:
mask = (pred_seg == class_id).astype(np.uint8)
return mask
return None
except Exception as e:
logger.error(f"Error getting mask for {clothing_type}: {e}")
return None
def _get_all_clothing_mask(self, pred_seg: np.ndarray) -> np.ndarray:
"""Get combined mask for all clothing types."""
try:
# Combine all clothing class IDs (0-15)
all_clothing_mask = np.zeros_like(pred_seg, dtype=np.uint8)
for class_id in range(16): # 0-15 are clothing classes
all_clothing_mask = np.logical_or(all_clothing_mask, pred_seg == class_id)
return all_clothing_mask.astype(np.uint8)
except Exception as e:
logger.error(f"Error getting all clothing mask: {e}")
return np.zeros_like(pred_seg, dtype=np.uint8)
def _mask_to_base64(self, mask: np.ndarray) -> str:
"""Convert numpy mask to compressed base64 string."""
try:
import gzip
# Convert mask to bytes
mask_bytes = mask.tobytes()
# Compress with gzip
compressed_bytes = gzip.compress(mask_bytes, compresslevel=9)
# Encode to base64
mask_base64 = base64.b64encode(compressed_bytes).decode('utf-8')
logger.info(f"Mask compressed: {len(mask_bytes)} -> {len(compressed_bytes)} bytes ({(1 - len(compressed_bytes)/len(mask_bytes))*100:.1f}% reduction)")
return mask_base64
except Exception as e:
logger.error(f"Error converting mask to base64: {e}")
return ""
def analyze_from_segmentation(self, segmentation_data: dict, selected_clothing: str = None) -> dict:
    """
    Analyze an image from segmentation data already held in the server cache.

    ``segmentation_data`` must carry the ``image_hash`` under which the
    segmentation was cached. The cached ``pred_seg`` and original bytes are
    used to build a clothing-only image, whose dominant colour is then
    computed.

    Raises:
        ValueError: If no hash is given or nothing is cached for it.
        Exception: Any other failure, after it has been logged.
    """
    start_time = time.time()
    try:
        # Step 1: Get data from cache
        step_started = time.time()
        image_hash = segmentation_data.get("image_hash")
        if not image_hash:
            raise ValueError("No image_hash provided in segmentation_data")

        cached_data = _segmentation_data_cache.get(image_hash)
        if not cached_data:
            raise ValueError(f"Segmentation data not found in cache for hash: {image_hash[:8]}...")

        pred_seg = cached_data["pred_seg"]
        image_size = cached_data["image_size"]
        original_image_bytes = cached_data["original_image_bytes"]
        cache_time = time.time() - step_started
        logger.info(f"⏱️ Cache retrieval completed in {cache_time:.2f}s for hash: {image_hash[:8]}...")

        # Step 2: Create clothing-only image
        step_started = time.time()
        clothing_only_image = self._create_real_clothing_only_image(
            original_image_bytes,
            pred_seg,
            selected_clothing,
        )
        image_time = time.time() - step_started
        logger.info(f"⏱️ Clothing-only image creation completed in {image_time:.2f}s")

        # Step 3: Analyze dominant color (project helper, imported on demand)
        step_started = time.time()
        from process import get_dominant_color_from_base64
        color = get_dominant_color_from_base64(clothing_only_image)
        color_time = time.time() - step_started
        logger.info(f"⏱️ Dominant color analysis completed in {color_time:.2f}s")

        # Total time
        total_time = time.time() - start_time
        logger.info(f"🚀 TOTAL /analyze completed in {total_time:.2f}s (cache: {cache_time:.2f}s, image: {image_time:.2f}s, color: {color_time:.2f}s)")

        result = {
            "dominant_color": color,
            "clothing_only_image": clothing_only_image,
            "selected_clothing": selected_clothing,
            "processing_note": "Used pre-computed segmentation data with original image"
        }
        return result
    except Exception as e:
        total_time = time.time() - start_time
        logger.error(f"❌ Error in analysis from segmentation after {total_time:.2f}s: {e}")
        raise
def _create_segmentation_visualization(self, pred_seg: np.ndarray, image_size: tuple, selected_clothing: str = None) -> str:
    """Create a visualization of the segmentation mask.

    Draws a semi-transparent blue overlay (alpha 128) over the selected
    clothing class, or over all clothing classes when no label is given or
    the label is unknown, on a fully transparent canvas.

    Args:
        pred_seg: 2-D array of per-pixel class ids.
        image_size: Output size as (width, height).
        selected_clothing: Optional label, matched case-insensitively.
    Returns:
        str: PNG data URL; a fixed 1x1 PNG data URL if anything fails (the
        error is logged).
    """
    try:
        from PIL import Image
        import base64
        from io import BytesIO

        canvas_size = tuple(image_size)
        img = Image.new('RGBA', canvas_size, (0, 0, 0, 0))

        # Resolve the requested label to a class id, if any.
        class_id = None
        if selected_clothing:
            for cid, label in self.labels.items():
                if label.lower() == selected_clothing.lower():
                    class_id = cid
                    break

        if class_id is not None:
            mask = (pred_seg == class_id)
        else:
            # No selection or unknown label: all clothing types.
            mask = np.isin(pred_seg, self.clothing_classes)

        # The mask becomes the overlay's alpha channel. 128 (not 255) keeps
        # the overlay semi-transparent, because putalpha() replaces the alpha
        # of the fill colour entirely.
        mask_img = Image.fromarray(mask.astype(np.uint8) * 128, mode='L')
        if mask_img.size != canvas_size:
            # NEAREST keeps the alpha strictly 0 or 128.
            mask_img = mask_img.resize(canvas_size, Image.NEAREST)

        overlay = Image.new('RGBA', canvas_size, (100, 150, 255, 128))  # Blue
        overlay.putalpha(mask_img)

        # Composite with transparent background
        result = Image.alpha_composite(img, overlay)

        buffer = BytesIO()
        result.save(buffer, format='PNG')
        img_str = base64.b64encode(buffer.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"
    except Exception as e:
        logger.error(f"Error creating segmentation visualization: {e}")
        # Return a fixed 1x1 PNG as fallback
        return "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
def _create_real_clothing_only_image(self, original_image_bytes: bytes, pred_seg: np.ndarray, selected_clothing: str = None) -> str:
    """Create real clothing-only image using original image and segmentation mask.

    The original image is decoded (and downscaled when either side exceeds
    800 px), the mask for the selected class (or for all clothing classes)
    is resized to the image, smoothed, and used to make every non-clothing
    pixel transparent.

    Args:
        original_image_bytes: Encoded original image.
        pred_seg: 2-D array of per-pixel class ids.
        selected_clothing: Optional label, matched case-insensitively; an
            unknown label falls back to all clothing classes.
    Returns:
        str: PNG data URL. On any error the mask visualization from
        ``_create_segmentation_visualization`` is returned instead.
    """
    try:
        # Load original image directly from bytes
        original_image = Image.open(BytesIO(original_image_bytes))

        # Downscale large images so the later color analysis stays fast.
        if original_image.width > 800 or original_image.height > 800:
            source_width, source_height = original_image.size
            max_dim = max(source_width, source_height)
            if max_dim > 2000:
                target_size = (800, 800)  # Very large images
            elif max_dim > 1200:
                target_size = (1000, 1000)  # Large images
            else:
                target_size = (1200, 1200)  # Medium-large images
            # thumbnail() resizes in place and keeps the aspect ratio, so
            # the size is captured before the call and re-read after it.
            original_image.thumbnail(target_size, Image.LANCZOS)
            logger.info(f"🔄 Optimized image size from {source_width}x{source_height} to {original_image.width}x{original_image.height} for faster processing")

        # Create mask for selected clothing or all clothing
        class_id = None
        if selected_clothing:
            for cid, label in self.labels.items():
                if label.lower() == selected_clothing.lower():
                    class_id = cid
                    break
            if class_id is not None:
                mask = (pred_seg == class_id)
                logger.info(f"Selected clothing type '{selected_clothing}' mapped to class ID {class_id}")
            else:
                # Fallback to all clothing if selected type not found
                mask = np.isin(pred_seg, self.clothing_classes)
                logger.warning(f"Could not find class ID for '{selected_clothing}', using all clothing types")
        else:
            # All clothing types
            mask = np.isin(pred_seg, self.clothing_classes)

        # Ensure mask and image have compatible dimensions
        mask_height, mask_width = pred_seg.shape
        img_width, img_height = original_image.size
        if mask_height != img_height or mask_width != img_width:
            logger.info(f"Resizing mask from {mask_height}x{mask_width} to {img_height}x{img_width}")
            mask_img = Image.fromarray(mask.astype(np.uint8) * 255, mode='L')
            mask_img = mask_img.resize((img_width, img_height), Image.LANCZOS)
            mask = np.array(mask_img) > 128  # Threshold back to a binary mask

        # Smooth the mask edges to reduce blockiness
        try:
            from scipy import ndimage
            mask_float = mask.astype(float)
            # Two Gaussian passes, then re-threshold.
            mask_smooth1 = ndimage.gaussian_filter(mask_float, sigma=0.8)
            mask_smooth2 = ndimage.gaussian_filter(mask_smooth1, sigma=1.0)
            mask = mask_smooth2 > 0.5
            logger.info("Applied advanced smoothing to mask for smoother edges")
        except ImportError:
            logger.info("scipy not available, using basic smoothing")
            # Basic smoothing without scipy
            from PIL import ImageFilter
            mask_img = Image.fromarray(mask.astype(np.uint8) * 255, mode='L')
            mask_img = mask_img.filter(ImageFilter.GaussianBlur(radius=1.0))
            mask = np.array(mask_img) > 128

        # Convert original image to RGBA if it's not already
        if original_image.mode != 'RGBA':
            original_image = original_image.convert('RGBA')

        # Copy the pixels and zero the alpha wherever the mask is off.
        result_array = np.array(original_image).copy()
        mask_array = mask.astype(np.uint8)
        result_array[mask_array == 0, 3] = 0

        result = Image.fromarray(result_array, 'RGBA')

        # Convert to base64
        buffer = BytesIO()
        result.save(buffer, format='PNG')
        img_str = base64.b64encode(buffer.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"
    except Exception as e:
        logger.error(f"Error creating real clothing-only image: {e}")
        # Fallback to visualization; size is passed as (width, height)
        return self._create_segmentation_visualization(pred_seg, (pred_seg.shape[1], pred_seg.shape[0]), selected_clothing)
def create_clothing_highlight_image(self, image_bytes: bytes, selected_clothing: str = None) -> str:
    """
    Create image with highlighted selected clothing type.

    The image is segmented via ``self._segment_image``, a binary mask is
    built for the requested label (or for every class listed in
    ``self.clothing_classes`` when no label is given or the label is not
    found in ``self.labels``), the mask is smoothed, and a green outline is
    painted around it. If scipy is not installed, or the outline turns out
    empty, a semi-transparent overlay is applied instead.

    Args:
        image_bytes: Encoded image, forwarded to ``self._segment_image``.
        selected_clothing: Label to highlight, compared case-insensitively
            with the values of ``self.labels``.

    Returns:
        A ``data:image/png;base64,...`` string. When highlighting fails the
        un-highlighted image is returned in the same format.

    Raises:
        Exception: The original error is re-raised when highlighting failed
            and ``image_bytes`` could not be decoded for the fallback either.
    """
    outline_rgba = (34, 197, 94, 255)  # #22c55e, fully opaque outline
    halo_rgba = (34, 197, 94, 128)  # same green at alpha 128 for the soft edge

    # Bound before the try block so the error handler can tell whether
    # segmentation got far enough to produce an image.
    image = None
    try:
        logger.info(f"Creating highlight image for clothing type: {selected_clothing}")
        seg_result = self._segment_image(image_bytes)
        pred_seg = seg_result['pred_seg']
        image = seg_result['image']
        logger.info(f"Segmentation shape: {pred_seg.shape}, Image size: {image.size}")

        # Find class ID for selected clothing (case-insensitive label match)
        class_id = None
        if selected_clothing:
            wanted = selected_clothing.lower()
            class_id = next(
                (cid for cid, label in self.labels.items() if label.lower() == wanted),
                None,
            )
            logger.info(f"Selected clothing '{selected_clothing}' mapped to class ID {class_id}")

        if class_id is None:
            # Use all clothing types if none selected
            mask = np.isin(pred_seg, self.clothing_classes)
            logger.info(f"Using all clothing types, mask sum: {np.sum(mask)}")
        else:
            # Use selected clothing type
            mask = (pred_seg == class_id)
            logger.info(f"Using selected clothing type {class_id}, mask sum: {np.sum(mask)}")

        # Work in RGBA so the outline colours can carry an alpha channel
        if image.mode != 'RGBA':
            image = image.convert('RGBA')
        highlighted_image = image.copy()

        mask_array = mask.astype(np.uint8) * 255
        mask_img = Image.fromarray(mask_array, mode='L')

        # Resize mask to match image if needed
        if mask_img.size != image.size:
            logger.info(f"Resizing mask from {mask_img.size} to {image.size}")
            mask_img = mask_img.resize(image.size, Image.LANCZOS)
            # LANCZOS produces intermediate grey levels; re-binarise
            mask_array = (np.array(mask_img) > 128).astype(np.uint8) * 255

        # scipy is optional: import once and branch on availability below
        try:
            from scipy import ndimage
        except ImportError:
            ndimage = None

        if ndimage is not None:
            # Three Gaussian passes soften the blocky segmentation edges
            mask_float = mask_array.astype(float) / 255.0
            mask_smooth1 = ndimage.gaussian_filter(mask_float, sigma=0.8)
            mask_smooth2 = ndimage.gaussian_filter(mask_smooth1, sigma=1.2)
            mask_final = ndimage.gaussian_filter(mask_smooth2, sigma=0.6)
            # Single fixed threshold back to a binary mask
            mask_clean = mask_final > 0.3
            # Opening drops isolated specks, closing fills pin-holes
            mask_clean = ndimage.binary_opening(mask_clean, iterations=1)
            mask_clean = ndimage.binary_closing(mask_clean, iterations=1)
            mask_array = mask_clean.astype(np.uint8) * 255
            logger.info("Applied advanced smoothing and morphological operations for high-quality mask")
        else:
            logger.info("scipy not available, using basic smoothing")
            # Basic smoothing without scipy
            from PIL import ImageFilter
            mask_img = mask_img.filter(ImageFilter.GaussianBlur(radius=1.5))
            mask_array = (np.array(mask_img) > 128).astype(np.uint8) * 255

        if ndimage is not None:
            # Outline = ring of pixels gained by dilating the mask twice
            mask_bool = mask_array > 0
            dilated = ndimage.binary_dilation(mask_bool, iterations=2)
            outline = dilated & ~mask_bool
            outline_count = int(np.count_nonzero(outline))
            logger.info(f"Outline created, outline pixels: {outline_count}")
            if outline_count > 0:
                logger.info(f"Drawing {outline_count} outline pixels")
                # Halo = 8-neighbourhood of the outline, excluding the outline
                # itself. Painting the halo first and the outline second keeps
                # every outline pixel fully opaque regardless of scan order.
                halo = ndimage.binary_dilation(
                    outline, structure=np.ones((3, 3), dtype=bool)
                ) & ~outline
                pixels = np.array(highlighted_image)
                # Pixels are overwritten (not alpha-blended), as putpixel did
                pixels[halo] = halo_rgba
                pixels[outline] = outline_rgba
                highlighted_image = Image.fromarray(pixels)
            else:
                logger.warning("No outline pixels found!")
                # Fallback: create semi-transparent overlay
                self._create_semi_transparent_overlay(highlighted_image, mask_array, selected_clothing)
        else:
            logger.warning("scipy not available, using semi-transparent overlay method")
            self._create_semi_transparent_overlay(highlighted_image, mask_array, selected_clothing)

        # Convert to base64
        buffer = BytesIO()
        highlighted_image.save(buffer, format='PNG')
        img_str = base64.b64encode(buffer.getvalue()).decode()
        logger.info("Highlight image created successfully")
        return f"data:image/png;base64,{img_str}"
    except Exception as e:
        logger.error(f"Error creating highlighted image: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        # Fallback to original image
        if image is None:
            # Segmentation failed before yielding an image: decode the raw
            # bytes directly so the caller still gets the picture back.
            try:
                image = Image.open(BytesIO(image_bytes))
            except Exception as decode_error:
                logger.error(f"Could not decode image for fallback: {decode_error}")
                raise e
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        img_str = base64.b64encode(buffer.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"
def _create_semi_transparent_overlay(self, image, mask_array, selected_clothing):
    """Tint the masked region of ``image`` green, modifying it in place.

    A transparent RGBA layer the size of ``image`` is filled with the tint
    wherever ``mask_array`` is positive, alpha-composited over ``image``,
    and the result is pasted back into ``image``. Nothing is returned.
    If any step raises, the error is logged and
    ``self._create_simple_border`` is called with the same arguments.
    """
    try:
        # One colour for every clothing type: #22c55e at alpha 80 of 255
        tint_rgba = (34, 197, 94, 80)

        # Start from a fully transparent layer and colour only masked pixels
        layer = np.array(Image.new('RGBA', image.size, (0, 0, 0, 0)))
        layer[mask_array > 0] = tint_rgba

        # Blend the layer over the picture, then write the blend back so the
        # caller's image object carries the change
        blended = Image.alpha_composite(image, Image.fromarray(layer, 'RGBA'))
        image.paste(blended, (0, 0))
        logger.info("Semi-transparent overlay created successfully with unified color")
    except Exception as e:
        logger.error(f"Error creating overlay: {e}")
        # Last resort: frame the region instead of tinting it
        self._create_simple_border(image, mask_array, selected_clothing)
def _create_simple_border(self, image, mask_array, selected_clothing):
    """Draw a rectangular frame around the bounding box of the mask.

    The frame is painted directly onto ``image``; nothing is returned.
    ``selected_clothing`` is accepted but not read by this method.
    An empty mask only produces a warning. Any exception is logged and
    swallowed.
    """
    try:
        from PIL import ImageDraw

        # Row/column indices of every positive mask pixel
        hit_rows_cols = np.nonzero(mask_array > 0)
        if hit_rows_cols[0].size == 0:
            logger.warning("No clothing detected for border creation")
            return

        y_min, y_max = hit_rows_cols[0].min(), hit_rows_cols[0].max()
        x_min, x_max = hit_rows_cols[1].min(), hit_rows_cols[1].max()

        # Grow the box by a few pixels, clamped to the image bounds
        margin = 5
        y_min, y_max = max(0, y_min - margin), min(image.height - 1, y_max + margin)
        x_min, x_max = max(0, x_min - margin), min(image.width - 1, x_max + margin)

        # Same green (#22c55e) for every clothing type
        frame_rgba = (34, 197, 94, 255)
        painter = ImageDraw.Draw(image)

        # Three nested 1-pixel rectangles, each one pixel further out
        for grow in range(3):
            painter.rectangle(
                [x_min - grow, y_min - grow, x_max + grow, y_max + grow],
                outline=frame_rgba,
                width=1
            )
        logger.info(f"Simple border created around clothing: ({x_min}, {y_min}) to ({x_max}, {y_max}) with unified color")
    except Exception as e:
        # Nothing further to fall back to; the image stays as it is
        logger.error(f"Error creating simple border: {e}")
def process_multiple_images(self, image_bytes_list: list) -> list:
    """
    Run clothing detection over several images using a small thread pool.

    Args:
        image_bytes_list: Encoded images; each one is passed to
            ``self.detect_clothing_with_segmentation``.

    Returns:
        One entry per input, in input order: the detection result, or
        ``{"error": "<message>"}`` when processing that image raised.
        If the thread pool itself fails, the images are processed
        sequentially with the same per-image error handling.
    """
    import concurrent.futures

    def _result_or_error(index, produce):
        """Call ``produce`` and convert any exception into an error entry."""
        try:
            return produce()
        except Exception as e:
            logger.error(f"Error processing image {index}: {e}")
            return {"error": str(e)}

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            futures = [
                executor.submit(self.detect_clothing_with_segmentation, img_bytes)
                for img_bytes in image_bytes_list
            ]
            # Waiting on the futures in submission order keeps the results
            # aligned with the inputs without tracking indices or sorting.
            return [
                _result_or_error(index, future.result)
                for index, future in enumerate(futures)
            ]
    except Exception as e:
        logger.error(f"Error in batch processing: {e}")
        # Fallback to sequential processing; one bad image must not abort
        # the rest, matching the behaviour of the threaded path.
        return [
            _result_or_error(
                index,
                lambda data=img_bytes: self.detect_clothing_with_segmentation(data),
            )
            for index, img_bytes in enumerate(image_bytes_list)
        ]
def detect_clothing_types_with_segmentation(image_bytes: bytes) -> dict:
    """Run ``detect_clothing_with_segmentation`` on the shared detector.

    Module-level shortcut: obtains the global detector and forwards
    ``image_bytes`` unchanged, returning whatever that method returns.
    """
    return get_clothing_detector().detect_clothing_with_segmentation(image_bytes)
def analyze_from_segmentation(segmentation_data: dict, selected_clothing: str = None) -> dict:
    """Analyze using segmentation data that was computed beforehand.

    Module-level shortcut: forwards both arguments to the shared detector's
    ``analyze_from_segmentation`` method and returns its result.
    """
    return get_clothing_detector().analyze_from_segmentation(segmentation_data, selected_clothing)
def detect_clothing_types_optimized(image_bytes: bytes) -> dict:
    """Run ``detect_clothing_with_segmentation_optimized`` on the shared detector.

    Module-level shortcut: forwards ``image_bytes`` unchanged and returns
    the method's result. Visualization is left to the client.
    """
    return get_clothing_detector().detect_clothing_with_segmentation_optimized(image_bytes)
# Module-wide detector instance, created on first use so the model is reused
_detector = None


def get_clothing_detector():
    """Return the shared ClothingDetector, constructing it on the first call."""
    global _detector
    if _detector is not None:
        return _detector
    # First request: build the detector once and keep it for later callers
    _detector = ClothingDetector()
    return _detector
def detect_clothing_types(image_bytes: bytes) -> dict:
    """Run ``detect_clothing`` on the shared detector and return its result."""
    return get_clothing_detector().detect_clothing(image_bytes)
def create_clothing_only_image(image_bytes: bytes, selected_clothing: str = None) -> str:
    """Build a clothing-only image through the shared detector.

    Module-level shortcut: forwards both arguments to the detector's
    ``create_clothing_only_image`` method and returns its result.
    """
    return get_clothing_detector().create_clothing_only_image(image_bytes, selected_clothing)