kabancov_et committed on
Commit 6e164a8 · 1 Parent(s): 5c496a9

Deploy clothing detection API to HF Spaces

Files changed (5)
  1. Dockerfile +40 -0
  2. app.py +168 -0
  3. clothing_detector.py +331 -0
  4. process.py +79 -0
  5. requirements.txt +10 -0
Dockerfile ADDED
@@ -0,0 +1,40 @@
FROM python:3.11-slim

# Install system dependencies as root, before dropping privileges
# (apt-get would fail with "Permission denied" after USER user)
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    && rm -rf /var/lib/apt/lists/*

# Create non-root user as required by HF Spaces
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    HOST=0.0.0.0 \
    PORT=7860 \
    WARMUP_ON_STARTUP=true

WORKDIR /app

# Copy requirements and install Python dependencies
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Copy app code
COPY --chown=user . /app

# Create results directory
RUN mkdir -p results

EXPOSE 7860

# HF Spaces requires port 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
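For a quick local check before pushing to Spaces, the image can be built and run with the usual docker build / docker run pair; the stdlib-only sketch below then probes the health and root endpoints. The image tag "fashionai" and the port mapping are assumptions for illustration, not part of this commit.

# smoke_test.py — a minimal sketch; assumes the container was started with
# something like: docker build -t fashionai . && docker run -p 7860:7860 fashionai
import json
import urllib.request

BASE = "http://localhost:7860"  # assumed local mapping of the HF-required port

for path in ("/healthz", "/"):
    with urllib.request.urlopen(f"{BASE}{path}") as resp:
        print(path, resp.status, json.loads(resp.read()))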
app.py ADDED
@@ -0,0 +1,168 @@
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List
from process import get_dominant_color_from_base64
from clothing_detector import (
    detect_clothing_types,
    create_clothing_only_image,
    get_clothing_detector,
)
import logging
import os
import base64
from starlette import status

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="FashionAI API", description="Clothing analysis & segmentation API")

# CORS (configure with env ALLOWED_ORIGINS="http://localhost:5173,https://your-site")
allowed_origins_env = os.getenv("ALLOWED_ORIGINS", "*")
allow_origins: List[str]
if allowed_origins_env.strip() == "*":
    allow_origins = ["*"]
else:
    allow_origins = [o.strip() for o in allowed_origins_env.split(",") if o.strip()]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# API settings
MAX_UPLOAD_MB = int(os.getenv("MAX_UPLOAD_MB", "10"))
MAX_UPLOAD_BYTES = MAX_UPLOAD_MB * 1024 * 1024
ALLOWED_CONTENT_TYPES = {
    c.strip()
    for c in os.getenv("ALLOWED_CONTENT_TYPES", "image/jpeg,image/png,image/webp").split(",")
    if c.strip()
}


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    logger.exception("Unhandled server error: %s", exc)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "Internal Server Error"},
    )


@app.on_event("startup")
async def maybe_warmup_model():
    if os.getenv("WARMUP_ON_STARTUP", "true").lower() in {"1", "true", "yes"}:
        # Warm up the model on startup to reduce first-request latency
        get_clothing_detector()


@app.get("/")
async def api_root():
    return JSONResponse({
        "name": "FashionAI API",
        "status": "ok",
        "docs": "/docs",
        "endpoints": ["/clothing", "/analyze", "/analyze/base64", "/labels", "/healthz"],
    })


@app.get("/healthz")
async def health_check():
    return {"status": "ok"}


@app.post("/clothing")
async def get_clothing_list(file: UploadFile = File(...)):
    """Detect all clothing types in an image and return their coordinates."""
    logger.info(f"Processing clothing detection for file: {file.filename}")
    # Validation
    if file.content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(status_code=415, detail=f"Unsupported content-type: {file.content_type}")
    # Read with size guard
    image_bytes = await file.read()
    if len(image_bytes) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail=f"File too large. Max {MAX_UPLOAD_MB}MB")
    clothing_result = detect_clothing_types(image_bytes)
    logger.info(f"Clothing detection completed. Found {clothing_result.get('total_detected', 0)} items")
    return clothing_result


@app.post("/analyze")
async def analyze_image(
    file: UploadFile = File(...),
    selected_clothing: Optional[str] = Form(None)
):
    """
    Full image analysis: clothing detection, clothing-only image, dominant color.

    - selected_clothing: optional clothing type to focus on
    - dominant_color: dominant color of the clothing
    - clothing_analysis: detected clothing types with stats
    - clothing_only_image: base64 PNG with transparent background
    """
    logger.info(f"Processing full analysis for file: {file.filename}, selected_clothing: {selected_clothing}")
    if file.content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(status_code=415, detail=f"Unsupported content-type: {file.content_type}")
    image_bytes = await file.read()
    if len(image_bytes) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail=f"File too large. Max {MAX_UPLOAD_MB}MB")

    # Step 1: Detect clothing types (cached segmentation)
    logger.info("Detecting clothing types...")
    clothing_result = detect_clothing_types(image_bytes)

    # Step 2: Create clothing-only image (cached segmentation)
    logger.info("Creating clothing-only image...")
    clothing_only_image = create_clothing_only_image(image_bytes, selected_clothing)

    # Step 3: Get dominant color from clothing-only image (no background)
    logger.info("Getting dominant color from clothing-only image...")
    color = get_dominant_color_from_base64(clothing_only_image)

    logger.info("Full analysis completed successfully")
    return JSONResponse(content={
        "dominant_color": color,
        "clothing_analysis": clothing_result,
        "clothing_only_image": clothing_only_image,
        "selected_clothing": selected_clothing
    })


class Base64AnalyzeRequest(BaseModel):
    image_base64: str
    selected_clothing: Optional[str] = None


@app.post("/analyze/base64")
async def analyze_image_base64(payload: Base64AnalyzeRequest):
    """Analyze a base64-encoded image (handy for React Native)."""
    # Decode image from base64, stripping an optional data URL prefix
    if payload.image_base64.startswith("data:image"):
        base64_data = payload.image_base64.split(",", 1)[1]
    else:
        base64_data = payload.image_base64

    try:
        image_bytes = base64.b64decode(base64_data)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid base64 image data")
    if len(image_bytes) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail=f"File too large. Max {MAX_UPLOAD_MB}MB")

    # 1) Clothing detection
    clothing_result = detect_clothing_types(image_bytes)

    # 2) Clothing-only image
    clothing_only_image = create_clothing_only_image(image_bytes, payload.selected_clothing)

    # 3) Dominant color from clothing-only image
    color = get_dominant_color_from_base64(clothing_only_image)

    return JSONResponse(content={
        "dominant_color": color,
        "clothing_analysis": clothing_result,
        "clothing_only_image": clothing_only_image,
        "selected_clothing": payload.selected_clothing,
    })


@app.get("/labels")
async def get_labels():
    detector = get_clothing_detector()
    return {"labels": list(detector.labels.values())}
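The endpoints above can be exercised end to end with a small client script. A sketch, assuming the API runs locally on port 7860, that the requests package is installed (it is not part of requirements.txt), and that a test photo exists at person.jpg (a hypothetical path):

import requests  # assumed extra dependency, not part of this commit

API = "http://localhost:7860"

# List detected clothing types and bounding boxes
with open("person.jpg", "rb") as f:
    r = requests.post(f"{API}/clothing", files={"file": ("person.jpg", f, "image/jpeg")})
r.raise_for_status()
print(r.json()["main_clothing"])

# Full analysis, focused on one label via the optional form field
with open("person.jpg", "rb") as f:
    r = requests.post(
        f"{API}/analyze",
        files={"file": ("person.jpg", f, "image/jpeg")},
        data={"selected_clothing": "Upper-clothes"},
    )
r.raise_for_status()
result = r.json()
print(result["dominant_color"], result["clothing_analysis"]["total_detected"])

The same flow works over /analyze/base64 by POSTing JSON of the form {"image_base64": "...", "selected_clothing": "..."}.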
clothing_detector.py ADDED
@@ -0,0 +1,331 @@
import hashlib
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import torch.nn as nn
from io import BytesIO
import numpy as np
import logging
import base64

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global cache for segmentation results
_segmentation_cache = {}


class ClothingDetector:
    def __init__(self):
        """Initialize the clothing segmentation model."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Load processor and model
        self.processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
        self.model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
        self.model.to(self.device)
        self.model.eval()

        # Clothing labels mapping
        self.labels = {
            0: "Background",
            1: "Hat",
            2: "Hair",
            3: "Sunglasses",
            4: "Upper-clothes",
            5: "Skirt",
            6: "Pants",
            7: "Dress",
            8: "Belt",
            9: "Left-shoe",
            10: "Right-shoe",
            11: "Face",
            12: "Left-leg",
            13: "Right-leg",
            14: "Left-arm",
            15: "Right-arm",
            16: "Bag",
            17: "Scarf"
        }

        # Clothing classes (exclude body parts and background):
        # Upper-clothes, Skirt, Pants, Dress, Belt, Left-shoe, Right-shoe, Bag, Scarf
        self.clothing_classes = [4, 5, 6, 7, 8, 9, 10, 16, 17]

        logger.info("Clothing detector initialized successfully")

    def _get_image_hash(self, image_bytes: bytes) -> str:
        """Create an image hash to use as a cache key."""
        return hashlib.md5(image_bytes).hexdigest()

    def _segment_image(self, image_bytes: bytes):
        """Run image segmentation with caching."""
        image_hash = self._get_image_hash(image_bytes)

        # Check cache
        if image_hash in _segmentation_cache:
            logger.info("Using cached segmentation result")
            return _segmentation_cache[image_hash]

        # Run segmentation
        logger.info("Performing new segmentation")
        image = Image.open(BytesIO(image_bytes)).convert("RGB")

        # Prepare inputs
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Forward pass
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits.cpu()

        # Upsample logits to the original image size
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],  # PIL size is (width, height); interpolate wants (height, width)
            mode="bilinear",
            align_corners=False,
        )

        # Get predicted mask
        pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

        # Save to cache
        result = {
            'pred_seg': pred_seg,
            'image': image,
            'image_size': image.size
        }
        _segmentation_cache[image_hash] = result

        # Limit cache size (keep the 10 most recent entries)
        if len(_segmentation_cache) > 10:
            oldest_key = next(iter(_segmentation_cache))
            del _segmentation_cache[oldest_key]

        return result

    def detect_clothing(self, image_bytes: bytes) -> dict:
        """
        Detect clothing types in an image and return their coordinates.

        Args:
            image_bytes: Raw image bytes

        Returns:
            dict: Clothing types with pixel stats and bounding boxes
        """
        try:
            # Get cached segmentation result
            seg_result = self._segment_image(image_bytes)
            pred_seg = seg_result['pred_seg']
            image = seg_result['image']

            # Count pixels per class and compute bounding boxes
            clothing_types = {}
            coordinates = {}
            total_pixels = pred_seg.size

            for class_id, label_name in self.labels.items():
                if label_name not in ["Background", "Face", "Hair", "Left-arm", "Right-arm", "Left-leg", "Right-leg"]:
                    # Create mask for this class
                    mask = (pred_seg == class_id)

                    if np.any(mask):
                        # Count pixels
                        count = np.sum(mask)
                        percentage = (count / total_pixels) * 100

                        clothing_types[label_name] = {
                            "pixels": int(count),
                            "percentage": round(percentage, 2)
                        }

                        # Compute bounding box
                        rows = np.any(mask, axis=1)
                        cols = np.any(mask, axis=0)

                        if np.any(rows) and np.any(cols):
                            y_min, y_max = np.where(rows)[0][[0, -1]]
                            x_min, x_max = np.where(cols)[0][[0, -1]]

                            # Add padding (10% of the clothing size)
                            clothing_width = x_max - x_min
                            clothing_height = y_max - y_min
                            padding_x = int(clothing_width * 0.1)
                            padding_y = int(clothing_height * 0.1)

                            # Apply padding, clamped to image bounds
                            x_min = max(0, x_min - padding_x)
                            y_min = max(0, y_min - padding_y)
                            x_max = min(image.width, x_max + padding_x)
                            y_max = min(image.height, y_max + padding_y)

                            coordinates[label_name] = {
                                "x_min": int(x_min),
                                "y_min": int(y_min),
                                "x_max": int(x_max),
                                "y_max": int(y_max),
                                "width": int(x_max - x_min),
                                "height": int(y_max - y_min)
                            }

            # Sort by percentage area
            sorted_clothing = dict(sorted(
                clothing_types.items(),
                key=lambda x: x[1]["percentage"],
                reverse=True
            ))

            return {
                "clothing_types": sorted_clothing,
                "coordinates": coordinates,
                "total_detected": len(sorted_clothing),
                "main_clothing": list(sorted_clothing.keys())[:3] if sorted_clothing else []
            }

        except Exception as e:
            logger.error(f"Error in clothing detection: {str(e)}")
            return {
                "clothing_types": {},
                "coordinates": {},
                "total_detected": 0,
                "main_clothing": [],
                "error": str(e)
            }

    def create_clothing_only_image(self, image_bytes: bytes, selected_clothing: str = None) -> str:
        """
        Create a clothing-only image with a transparent background.

        Args:
            image_bytes: Raw image bytes
            selected_clothing: Optional clothing label to isolate

        Returns:
            str: Base64-encoded PNG data URL
        """
        try:
            # Get cached segmentation
            seg_result = self._segment_image(image_bytes)
            pred_seg = seg_result['pred_seg']
            image = seg_result['image']

            # Create clothing-only mask
            clothing_mask = np.zeros_like(pred_seg, dtype=bool)
            selected_class_id = None

            if selected_clothing:
                # If a specific clothing type is selected, find its class id
                for class_id, label_name in self.labels.items():
                    if label_name == selected_clothing:
                        selected_class_id = class_id
                        break

                if selected_class_id is not None:
                    # Build mask only for the selected class
                    clothing_mask = (pred_seg == selected_class_id)
                else:
                    # If not found, fall back to all clothing classes
                    for class_id in self.clothing_classes:
                        clothing_mask |= (pred_seg == class_id)
            else:
                # Otherwise, use all clothing classes
                for class_id in self.clothing_classes:
                    clothing_mask |= (pred_seg == class_id)

            # Convert image to a numpy array
            image_array = np.array(image)

            # Compose RGBA with a transparent background
            clothing_only_rgba = np.zeros((image_array.shape[0], image_array.shape[1], 4), dtype=np.uint8)
            clothing_only_rgba[..., :3] = image_array  # RGB channels
            clothing_only_rgba[..., 3] = 255  # Alpha channel (opaque)
            clothing_only_rgba[~clothing_mask, 3] = 0  # Transparent for non-clothing

            # Create PIL image
            clothing_image = Image.fromarray(clothing_only_rgba, 'RGBA')

            # If a specific clothing type was selected and found, crop with padding
            if selected_clothing and selected_class_id is not None:
                clothing_image = self._crop_with_padding(clothing_image, clothing_mask)

            # Encode to base64
            buffer = BytesIO()
            clothing_image.save(buffer, format='PNG')
            img_str = base64.b64encode(buffer.getvalue()).decode()

            return f"data:image/png;base64,{img_str}"

        except Exception as e:
            logger.error(f"Error in creating clothing-only image: {str(e)}")
            return ""

    def _crop_with_padding(self, image: Image.Image, mask: np.ndarray, padding_percent: float = 0.1) -> Image.Image:
        """
        Crop the image around the clothing mask with padding.

        Args:
            image: PIL image
            mask: Clothing mask
            padding_percent: Padding relative to the clothing size

        Returns:
            Image.Image: Cropped image
        """
        try:
            # Find clothing bounds
            rows = np.any(mask, axis=1)
            cols = np.any(mask, axis=0)

            if not np.any(rows) or not np.any(cols):
                return image  # If no clothing found, return the original

            # Get bounds
            y_min, y_max = np.where(rows)[0][[0, -1]]
            x_min, x_max = np.where(cols)[0][[0, -1]]

            # Compute clothing size
            clothing_width = x_max - x_min
            clothing_height = y_max - y_min

            # Compute padding
            padding_x = int(clothing_width * padding_percent)
            padding_y = int(clothing_height * padding_percent)

            # Apply padding within image bounds
            x_min = max(0, x_min - padding_x)
            y_min = max(0, y_min - padding_y)
            x_max = min(image.width, x_max + padding_x)
            y_max = min(image.height, y_max + padding_y)

            # Crop
            cropped_image = image.crop((x_min, y_min, x_max, y_max))

            return cropped_image

        except Exception as e:
            logger.error(f"Error in cropping with padding: {str(e)}")
            return image


# Global detector singleton (to reuse the loaded model)
_detector = None


def get_clothing_detector():
    """Get the global detector instance (lazy-init)."""
    global _detector
    if _detector is None:
        _detector = ClothingDetector()
    return _detector


def detect_clothing_types(image_bytes: bytes) -> dict:
    """Convenience wrapper for clothing detection."""
    detector = get_clothing_detector()
    return detector.detect_clothing(image_bytes)


def create_clothing_only_image(image_bytes: bytes, selected_clothing: str = None) -> str:
    """Convenience wrapper for clothing-only image creation."""
    detector = get_clothing_detector()
    return detector.create_clothing_only_image(image_bytes, selected_clothing)
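The module can also be used directly, without the HTTP layer. A minimal sketch, assuming a test image at person.jpg (a hypothetical path); the second call reuses the in-process segmentation cache because the same bytes are passed twice:

from clothing_detector import detect_clothing_types, create_clothing_only_image

with open("person.jpg", "rb") as f:  # hypothetical test image
    image_bytes = f.read()

# First call runs the SegFormer model and caches the predicted mask
result = detect_clothing_types(image_bytes)
for label, stats in result["clothing_types"].items():
    print(f"{label}: {stats['percentage']}% of image, bbox={result['coordinates'].get(label)}")

# Second call hits the cache for the same image bytes
data_url = create_clothing_only_image(image_bytes, selected_clothing="Pants")
print(data_url[:48], "...")  # "data:image/png;base64,..."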
process.py ADDED
@@ -0,0 +1,79 @@
from rembg import remove
from PIL import Image
from io import BytesIO
from sklearn.cluster import KMeans
import base64

import os
import uuid
import numpy as np


def get_dominant_color(processed_bytes, k=3):
    # Step 1: Load the transparent image
    image = Image.open(BytesIO(processed_bytes)).convert("RGBA")
    image = image.resize((100, 100))  # Resize to speed up

    # Step 2: Keep only visible (non-transparent) pixels
    np_image = np.array(image)
    rgb_pixels = np_image[..., :3]  # Ignore alpha channel
    alpha = np_image[..., 3]
    rgb_pixels = rgb_pixels[alpha > 0]  # Keep only pixels where alpha > 0

    if len(rgb_pixels) == 0:
        return "rgb(0, 0, 0)"  # Fallback to black if no visible pixels

    # Step 3: KMeans clustering; the dominant color is the center of the
    # largest cluster (cluster 0 is not guaranteed to be the biggest one)
    kmeans = KMeans(n_clusters=min(k, len(rgb_pixels)), n_init='auto')
    kmeans.fit(rgb_pixels)
    largest_cluster = np.bincount(kmeans.labels_).argmax()
    r, g, b = map(int, kmeans.cluster_centers_[largest_cluster])
    return f"rgb({r}, {g}, {b})"


def get_dominant_color_from_base64(base64_image, k=3):
    """Compute the dominant color of a base64-encoded clothing-only image."""
    try:
        # Step 1: Decode base64 to bytes, stripping an optional data URL prefix
        if base64_image.startswith('data:image'):
            base64_data = base64_image.split(',')[1]
        else:
            base64_data = base64_image

        image_bytes = base64.b64decode(base64_data)

        # Step 2: Load the image and convert to RGBA
        image = Image.open(BytesIO(image_bytes)).convert("RGBA")
        image = image.resize((100, 100))  # Resize to speed up

        # Step 3: Keep only visible (non-transparent) pixels
        np_image = np.array(image)
        rgb_pixels = np_image[..., :3]  # Ignore alpha channel
        alpha = np_image[..., 3]
        rgb_pixels = rgb_pixels[alpha > 0]  # Keep only pixels where alpha > 0

        # Check that there are any visible pixels at all
        if len(rgb_pixels) == 0:
            return "rgb(0, 0, 0)"  # Fallback to black if no visible pixels

        # Step 4: KMeans clustering; take the center of the largest cluster
        kmeans = KMeans(n_clusters=min(k, len(rgb_pixels)), n_init='auto')
        kmeans.fit(rgb_pixels)
        largest_cluster = np.bincount(kmeans.labels_).argmax()
        r, g, b = map(int, kmeans.cluster_centers_[largest_cluster])
        return f"rgb({r}, {g}, {b})"

    except Exception as e:
        print(f"Error in get_dominant_color_from_base64: {e}")
        return "rgb(0, 0, 0)"  # Fallback to black on error


def remove_background(image_bytes: bytes) -> bytes:
    result_bytes = remove(image_bytes)

    # Save the image to disk for inspection
    os.makedirs("results", exist_ok=True)
    output_image = Image.open(BytesIO(result_bytes))
    file_name = f"{uuid.uuid4().hex[:8]}.png"
    output_path = os.path.join("results", file_name)
    output_image.save(output_path)

    print(f"✅ Saved background-removed image to: {output_path}")
    return result_bytes
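A self-contained sanity check for the dominant-color helper (a sketch; no model required). It builds a synthetic RGBA image whose visible pixels are mostly red, encodes it as a data URL, and expects the largest KMeans cluster to come out red:

import base64
from io import BytesIO

from PIL import Image
from process import get_dominant_color_from_base64

# Transparent canvas with red-majority, blue and green minority regions
img = Image.new("RGBA", (100, 100), (0, 0, 0, 0))
for x in range(10, 90):
    for y in range(10, 90):
        if y < 60:
            c = (200, 30, 30, 255)   # red majority
        elif y < 80:
            c = (30, 30, 200, 255)   # blue
        else:
            c = (30, 200, 30, 255)   # green
        img.putpixel((x, y), c)

buf = BytesIO()
img.save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

print(get_dominant_color_from_base64(data_url))  # expected ≈ "rgb(200, 30, 30)"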
requirements.txt ADDED
@@ -0,0 +1,10 @@
fastapi
uvicorn[standard]
pillow
numpy
transformers
torch
torchvision
scikit-learn
python-multipart
rembg  # imported by process.py (remove_background)