Munaf1987 commited on
Commit
a8dfdb8
·
verified ·
1 Parent(s): ff68c5b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +349 -46
main.py CHANGED
@@ -15,12 +15,16 @@ import numpy as np
15
  import cv2
16
  import onnxruntime as ort
17
  from PIL import Image
18
- from fastapi import FastAPI, File, UploadFile, Form, HTTPException
19
- from fastapi.responses import JSONResponse
20
  from fastapi.staticfiles import StaticFiles
21
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
22
  import gradio as gr
23
  import uvicorn
 
24
 
25
  # Import spaces directly without conditional
26
  try:
@@ -34,14 +38,99 @@ except ImportError:
34
  print("📦 main.py loaded. ZeroGPU debugging enabled.")
35
 
36
  ###############################################################################
37
- # 3. App Setup
38
  ###############################################################################
39
- API_KEY = os.getenv("API_KEY")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- app = FastAPI(docs_url="/api/docs", openapi_url="/api/openapi.json")
42
  app.add_middleware(
43
  CORSMiddleware,
44
  allow_origins=["*"],
 
45
  allow_methods=["*"],
46
  allow_headers=["*"]
47
  )
@@ -54,7 +143,30 @@ model_path = "BiRefNet-portrait-epoch_150.onnx"
54
  input_size = (1024, 1024)
55
 
56
  ###############################################################################
57
- # 4. Preprocess & Postprocess Utilities
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  ###############################################################################
59
  def preprocess_image(image):
60
  """Handle both bytes and numpy arrays"""
@@ -85,10 +197,13 @@ def apply_mask(original_img, mask_array, original_shape):
85
  return bgra
86
 
87
  ###############################################################################
88
- # 5. Core Processing Function
89
  ###############################################################################
90
- def process_image_core(image_data, use_gpu=False):
91
- """Core image processing logic"""
 
 
 
92
  try:
93
  providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
94
  session = ort.InferenceSession(model_path, providers=providers)
@@ -101,10 +216,14 @@ def process_image_core(image_data, use_gpu=False):
101
  output = session.run(None, {input_name: input_tensor})
102
  mask = output[0]
103
  result_img = apply_mask(original_img, mask, original_shape)
104
- return Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGRA2RGBA))
 
 
 
 
105
 
106
  ###############################################################################
107
- # 6. GPU Gradio Function (Direct Decoration)
108
  ###############################################################################
109
  def gradio_processor(image_np):
110
  """Process image for Gradio interface"""
@@ -114,7 +233,8 @@ def gradio_processor(image_np):
114
  raise ValueError("Failed to encode image")
115
  image_bytes = img_encoded.tobytes()
116
 
117
- return process_image_core(image_bytes, use_gpu=USE_SPACES_GPU)
 
118
 
119
  # Apply GPU decorator directly
120
  if USE_SPACES_GPU:
@@ -127,59 +247,242 @@ else:
127
  print("🧠 gradio_processor has GPU metadata:", getattr(gradio_processor, "_spaces_gpu", None))
128
 
129
  ###############################################################################
130
- # 7. Gradio Interface
131
  ###############################################################################
132
- interface = gr.Interface(
133
- fn=gradio_processor,
134
- inputs=gr.Image(type="numpy", label="Upload Image"),
135
- outputs=gr.Image(type="pil", label="Processed Image"),
136
- title="🧠 Background Removal (GPU)",
137
- description="Upload an image to remove its background using ZeroGPU-powered ONNX.",
138
- flagging_options=None
139
- )
140
 
141
- interface.api_name = "remove_background_gpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- ###############################################################################
144
- # 8. FastAPI REST API (CPU Inference)
145
- ###############################################################################
146
- @app.post("/api/remove-background")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  async def remove_background_api(
148
- api_key: str = Form(...),
149
- image: UploadFile = File(...)
 
 
150
  ):
151
- if api_key != API_KEY:
152
- raise HTTPException(status_code=401, detail="Invalid API key")
153
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  image_data = await image.read()
156
- result_img = process_image_core(image_data, use_gpu=False)
 
 
 
 
 
157
 
158
- result_filename = f"{uuid.uuid4()}.png"
 
159
  output_path = f"{TMP_FOLDER}/{result_filename}"
160
- result_img.save(output_path, "PNG")
 
 
 
 
 
 
 
 
 
 
161
 
 
162
  with open(output_path, "rb") as img_file:
163
  base64_image = base64.b64encode(img_file.read()).decode("utf-8")
164
 
165
- return JSONResponse(content={
166
- "status": "success",
167
- "image_code": f"data:image/png;base64,{base64_image}"
168
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  except Exception as e:
171
- return JSONResponse(content={
172
- "status": "failure",
173
- "message": str(e)
174
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  ###############################################################################
177
- # 9. Mount Gradio at Root Path (UI)
178
  ###############################################################################
179
- app = gr.mount_gradio_app(app, interface, path="/", ssr_mode=False)
180
 
181
  ###############################################################################
182
- # 10. Run Local Dev Server
183
  ###############################################################################
184
  if __name__ == "__main__":
185
- uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
 
 
 
 
 
 
 
 
 
 
15
  import cv2
16
  import onnxruntime as ort
17
  from PIL import Image
18
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query, Depends
19
+ from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
20
  from fastapi.staticfiles import StaticFiles
21
  from fastapi.middleware.cors import CORSMiddleware
22
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
23
+ from pydantic import BaseModel, Field
24
+ from typing import Optional, List, Union
25
  import gradio as gr
26
  import uvicorn
27
+ from datetime import datetime
28
 
29
  # Import spaces directly without conditional
30
  try:
 
38
  print("📦 main.py loaded. ZeroGPU debugging enabled.")
39
 
40
  ###############################################################################
41
+ # 3. Pydantic Models for API Documentation
42
  ###############################################################################
43
+ class BackgroundRemovalRequest(BaseModel):
44
+ """Request model for background removal"""
45
+ image_format: Optional[str] = Field(default="PNG", description="Output image format (PNG, JPEG)")
46
+ quality: Optional[int] = Field(default=95, ge=1, le=100, description="Output quality for JPEG (1-100)")
47
+
48
+ class BackgroundRemovalResponse(BaseModel):
49
+ """Response model for successful background removal"""
50
+ status: str = Field(..., description="Status of the operation")
51
+ image_code: str = Field(..., description="Base64 encoded image with data URI prefix")
52
+ processing_time: float = Field(..., description="Processing time in seconds")
53
+ original_size: List[int] = Field(..., description="Original image dimensions [width, height]")
54
+ output_format: str = Field(..., description="Output image format")
55
+
56
+ class ErrorResponse(BaseModel):
57
+ """Error response model"""
58
+ status: str = Field(..., description="Status of the operation")
59
+ message: str = Field(..., description="Error message")
60
+ error_code: Optional[str] = Field(None, description="Specific error code")
61
+
62
+ class HealthResponse(BaseModel):
63
+ """Health check response"""
64
+ status: str = Field(..., description="Service status")
65
+ timestamp: str = Field(..., description="Current timestamp")
66
+ version: str = Field(..., description="API version")
67
+ gpu_available: bool = Field(..., description="Whether GPU is available")
68
+ model_loaded: bool = Field(..., description="Whether the model is loaded")
69
+
70
+ ###############################################################################
71
+ # 4. App Setup with Enhanced Documentation
72
+ ###############################################################################
73
+ API_KEY = os.getenv("API_KEY", "demo-key-change-in-production")
74
+
75
+ app = FastAPI(
76
+ title="🧠 Background Removal API",
77
+ description="""
78
+ # Background Removal API with Gradio Interface
79
+
80
+ This API provides advanced background removal capabilities using ONNX models with optional GPU acceleration.
81
+
82
+ ## Features
83
+ - 🖼️ High-quality background removal
84
+ - ⚡ GPU acceleration (when available)
85
+ - 🎨 Multiple output formats (PNG, JPEG)
86
+ - 📱 Gradio web interface
87
+ - 🔒 API key authentication
88
+ - 📊 Real-time processing metrics
89
+
90
+ ## Usage
91
+ 1. **Web Interface**: Visit the root path `/` for the interactive Gradio interface
92
+ 2. **REST API**: Use `/api/remove-background` endpoint for programmatic access
93
+ 3. **Documentation**: Visit `/api/docs` for this interactive documentation
94
+
95
+ ## Authentication
96
+ - API endpoints require an API key provided via form data or header
97
+ - Set `API_KEY` environment variable for production use
98
+
99
+ ## Gradio Integration
100
+ According to [Gradio's documentation](https://www.gradio.app/guides/querying-gradio-apps-with-curl),
101
+ the Gradio interface automatically exposes REST API endpoints that can be accessed via cURL:
102
+
103
+ ```bash
104
+ # Make prediction
105
+ curl -X POST {your-url}/call/remove_background_gpu \\
106
+ -H "Content-Type: application/json" \\
107
+ -d '{"data": [{"path": "https://example.com/image.jpg"}]}'
108
+
109
+ # Get result
110
+ curl -N {your-url}/call/remove_background_gpu/{event_id}
111
+ ```
112
+ """,
113
+ version="2.0.0",
114
+ docs_url="/api/docs",
115
+ redoc_url="/api/redoc",
116
+ openapi_url="/api/openapi.json",
117
+ contact={
118
+ "name": "Background Removal API",
119
+ "url": "https://github.com/yourusername/background-removal-api",
120
+ },
121
+ license_info={
122
+ "name": "MIT",
123
+ "url": "https://opensource.org/licenses/MIT",
124
+ },
125
+ )
126
+
127
+ # Security
128
+ security = HTTPBearer()
129
 
 
130
  app.add_middleware(
131
  CORSMiddleware,
132
  allow_origins=["*"],
133
+ allow_credentials=True,
134
  allow_methods=["*"],
135
  allow_headers=["*"]
136
  )
 
143
  input_size = (1024, 1024)
144
 
145
  ###############################################################################
146
+ # 5. Authentication Helper
147
+ ###############################################################################
148
+ async def verify_api_key(api_key: str = Form(...)):
149
+ """Verify API key from form data"""
150
+ if api_key != API_KEY:
151
+ raise HTTPException(
152
+ status_code=401,
153
+ detail="Invalid API key",
154
+ headers={"WWW-Authenticate": "Bearer"},
155
+ )
156
+ return api_key
157
+
158
+ async def verify_api_key_header(credentials: HTTPAuthorizationCredentials = Depends(security)):
159
+ """Verify API key from Authorization header"""
160
+ if credentials.credentials != API_KEY:
161
+ raise HTTPException(
162
+ status_code=401,
163
+ detail="Invalid API key",
164
+ headers={"WWW-Authenticate": "Bearer"},
165
+ )
166
+ return credentials.credentials
167
+
168
+ ###############################################################################
169
+ # 6. Preprocess & Postprocess Utilities
170
  ###############################################################################
171
  def preprocess_image(image):
172
  """Handle both bytes and numpy arrays"""
 
197
  return bgra
198
 
199
  ###############################################################################
200
+ # 7. Core Processing Function
201
  ###############################################################################
202
+ def process_image_core(image_data, use_gpu=False, output_format="PNG", quality=95):
203
+ """Core image processing logic with enhanced options"""
204
+ import time
205
+ start_time = time.time()
206
+
207
  try:
208
  providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
209
  session = ort.InferenceSession(model_path, providers=providers)
 
216
  output = session.run(None, {input_name: input_tensor})
217
  mask = output[0]
218
  result_img = apply_mask(original_img, mask, original_shape)
219
+ result_pil = Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGRA2RGBA))
220
+
221
+ processing_time = time.time() - start_time
222
+
223
+ return result_pil, processing_time, original_shape
224
 
225
  ###############################################################################
226
+ # 8. GPU Gradio Function (Direct Decoration)
227
  ###############################################################################
228
  def gradio_processor(image_np):
229
  """Process image for Gradio interface"""
 
233
  raise ValueError("Failed to encode image")
234
  image_bytes = img_encoded.tobytes()
235
 
236
+ result, _, _ = process_image_core(image_bytes, use_gpu=USE_SPACES_GPU)
237
+ return result
238
 
239
  # Apply GPU decorator directly
240
  if USE_SPACES_GPU:
 
247
  print("🧠 gradio_processor has GPU metadata:", getattr(gradio_processor, "_spaces_gpu", None))
248
 
249
  ###############################################################################
250
+ # 9. FastAPI Endpoints
251
  ###############################################################################
 
 
 
 
 
 
 
 
252
 
253
+ @app.get("/", response_class=HTMLResponse, include_in_schema=False)
254
+ async def root():
255
+ """Root endpoint redirects to Gradio interface"""
256
+ return """
257
+ <!DOCTYPE html>
258
+ <html>
259
+ <head>
260
+ <title>Background Removal Service</title>
261
+ <meta http-equiv="refresh" content="0; url=/gradio">
262
+ </head>
263
+ <body>
264
+ <p>Redirecting to Gradio interface...</p>
265
+ <p><a href="/gradio">Click here if not redirected automatically</a></p>
266
+ </body>
267
+ </html>
268
+ """
269
 
270
+ @app.get("/api/health", response_model=HealthResponse, tags=["Health"])
271
+ async def health_check():
272
+ """
273
+ Health check endpoint to verify service status
274
+
275
+ Returns:
276
+ HealthResponse: Current service status and information
277
+ """
278
+ try:
279
+ # Test model loading
280
+ session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
281
+ model_loaded = True
282
+ except:
283
+ model_loaded = False
284
+
285
+ return HealthResponse(
286
+ status="healthy",
287
+ timestamp=datetime.now().isoformat(),
288
+ version="2.0.0",
289
+ gpu_available=USE_SPACES_GPU,
290
+ model_loaded=model_loaded
291
+ )
292
+
293
+ @app.post(
294
+ "/api/remove-background",
295
+ response_model=BackgroundRemovalResponse,
296
+ responses={
297
+ 401: {"model": ErrorResponse, "description": "Invalid API key"},
298
+ 400: {"model": ErrorResponse, "description": "Invalid input"},
299
+ 500: {"model": ErrorResponse, "description": "Processing error"}
300
+ },
301
+ tags=["Background Removal"]
302
+ )
303
  async def remove_background_api(
304
+ image: UploadFile = File(..., description="Image file to process"),
305
+ api_key: str = Depends(verify_api_key),
306
+ output_format: str = Form(default="PNG", description="Output format: PNG or JPEG"),
307
+ quality: int = Form(default=95, ge=1, le=100, description="JPEG quality (1-100, ignored for PNG)")
308
  ):
309
+ """
310
+ Remove background from uploaded image
311
+
312
+ This endpoint processes an uploaded image and returns the result with background removed.
313
+
314
+ **Parameters:**
315
+ - **image**: Image file to process (supports common formats: JPEG, PNG, WebP, etc.)
316
+ - **api_key**: Authentication key
317
+ - **output_format**: Output image format (PNG recommended for transparency)
318
+ - **quality**: JPEG compression quality (1-100, only used for JPEG output)
319
+
320
+ **Returns:**
321
+ - Base64 encoded image with transparent background
322
+ - Processing time and metadata
323
+
324
+ **Example Usage:**
325
+ ```bash
326
+ curl -X POST "http://localhost:7860/api/remove-background" \\
327
+ -F "api_key=demo-key-change-in-production" \\
328
+ -F "image=@your-image.jpg" \\
329
+ -F "output_format=PNG"
330
+ ```
331
+ """
332
  try:
333
+ # Validate image file
334
+ if not image.content_type.startswith('image/'):
335
+ raise HTTPException(
336
+ status_code=400,
337
+ detail=f"Invalid file type: {image.content_type}. Please upload an image file."
338
+ )
339
+
340
+ # Validate output format
341
+ if output_format.upper() not in ["PNG", "JPEG", "JPG"]:
342
+ raise HTTPException(
343
+ status_code=400,
344
+ detail="Invalid output format. Use 'PNG' or 'JPEG'."
345
+ )
346
+
347
+ # Process image
348
  image_data = await image.read()
349
+ result_img, processing_time, original_shape = process_image_core(
350
+ image_data,
351
+ use_gpu=False, # CPU for API endpoint
352
+ output_format=output_format.upper(),
353
+ quality=quality
354
+ )
355
 
356
+ # Save result
357
+ result_filename = f"{uuid.uuid4()}.{output_format.lower()}"
358
  output_path = f"{TMP_FOLDER}/{result_filename}"
359
+
360
+ if output_format.upper() == "PNG":
361
+ result_img.save(output_path, "PNG")
362
+ else:
363
+ # Convert RGBA to RGB for JPEG
364
+ if result_img.mode == 'RGBA':
365
+ rgb_img = Image.new('RGB', result_img.size, (255, 255, 255))
366
+ rgb_img.paste(result_img, mask=result_img.split()[-1])
367
+ rgb_img.save(output_path, "JPEG", quality=quality)
368
+ else:
369
+ result_img.save(output_path, "JPEG", quality=quality)
370
 
371
+ # Encode to base64
372
  with open(output_path, "rb") as img_file:
373
  base64_image = base64.b64encode(img_file.read()).decode("utf-8")
374
 
375
+ # Clean up temporary file
376
+ os.remove(output_path)
377
+
378
+ return BackgroundRemovalResponse(
379
+ status="success",
380
+ image_code=f"data:image/{output_format.lower()};base64,{base64_image}",
381
+ processing_time=round(processing_time, 3),
382
+ original_size=[original_shape[1], original_shape[0]], # width, height
383
+ output_format=output_format.upper()
384
+ )
385
+
386
+ except HTTPException:
387
+ raise
388
+ except Exception as e:
389
+ raise HTTPException(
390
+ status_code=500,
391
+ detail=f"Processing error: {str(e)}"
392
+ )
393
+
394
+ @app.post(
395
+ "/api/remove-background-file",
396
+ response_class=FileResponse,
397
+ tags=["Background Removal"]
398
+ )
399
+ async def remove_background_file(
400
+ image: UploadFile = File(..., description="Image file to process"),
401
+ api_key: str = Depends(verify_api_key),
402
+ output_format: str = Form(default="PNG", description="Output format: PNG or JPEG")
403
+ ):
404
+ """
405
+ Remove background and return image file directly
406
+
407
+ Similar to `/remove-background` but returns the processed image file directly
408
+ instead of base64 encoded data.
409
+ """
410
+ try:
411
+ if not image.content_type.startswith('image/'):
412
+ raise HTTPException(status_code=400, detail="Invalid file type")
413
+
414
+ image_data = await image.read()
415
+ result_img, _, _ = process_image_core(image_data, use_gpu=False)
416
+
417
+ result_filename = f"processed_{uuid.uuid4()}.{output_format.lower()}"
418
+ output_path = f"{TMP_FOLDER}/{result_filename}"
419
+
420
+ if output_format.upper() == "PNG":
421
+ result_img.save(output_path, "PNG")
422
+ else:
423
+ if result_img.mode == 'RGBA':
424
+ rgb_img = Image.new('RGB', result_img.size, (255, 255, 255))
425
+ rgb_img.paste(result_img, mask=result_img.split()[-1])
426
+ rgb_img.save(output_path, "JPEG", quality=95)
427
+ else:
428
+ result_img.save(output_path, "JPEG", quality=95)
429
+
430
+ return FileResponse(
431
+ output_path,
432
+ media_type=f"image/{output_format.lower()}",
433
+ filename=result_filename,
434
+ background=lambda: os.remove(output_path) # Clean up after response
435
+ )
436
 
437
  except Exception as e:
438
+ raise HTTPException(status_code=500, detail=str(e))
439
+
440
+ ###############################################################################
441
+ # 10. Gradio Interface
442
+ ###############################################################################
443
+ interface = gr.Interface(
444
+ fn=gradio_processor,
445
+ inputs=gr.Image(type="numpy", label="Upload Image"),
446
+ outputs=gr.Image(type="pil", label="Processed Image"),
447
+ title="🧠 Background Removal Service",
448
+ description="""
449
+ ## Upload an image to remove its background
450
+
451
+ - **GPU Acceleration**: Uses ZeroGPU when available
452
+ - **High Quality**: ONNX model for precise background removal
453
+ - **Fast Processing**: Optimized for real-time use
454
+
455
+ ### API Access
456
+ This interface is also available via REST API:
457
+ - **Documentation**: [/api/docs](/api/docs)
458
+ - **Health Check**: [/api/health](/api/health)
459
+ - **Background Removal**: POST `/api/remove-background`
460
+ """,
461
+ examples=[
462
+ # Add example images if you have them
463
+ ],
464
+ flagging_options=None,
465
+ allow_flagging="never"
466
+ )
467
+
468
+ interface.api_name = "remove_background_gpu"
469
 
470
  ###############################################################################
471
+ # 11. Mount Gradio App
472
  ###############################################################################
473
+ app = gr.mount_gradio_app(app, interface, path="/gradio", ssr_mode=False)
474
 
475
  ###############################################################################
476
+ # 12. Run Local Dev Server
477
  ###############################################################################
478
  if __name__ == "__main__":
479
+ print("🚀 Starting Background Removal Service")
480
+ print(f"📖 API Documentation: http://localhost:7860/api/docs")
481
+ print(f"🎨 Gradio Interface: http://localhost:7860/gradio")
482
+ print(f"❤️ Health Check: http://localhost:7860/api/health")
483
+ uvicorn.run(
484
+ "main:app",
485
+ host="0.0.0.0",
486
+ port=7860,
487
+ reload=False # Disable reload for production
488
+ )