Mr7Explorer commited on
Commit
5c59863
·
verified ·
1 Parent(s): 80abaa6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -8,20 +8,19 @@ import cv2
8
  import io
9
  from datetime import datetime, timedelta
10
  from collections import defaultdict
11
- import os
12
 
13
- app = FastAPI(title="Backdrop Studio API", version="2.0.0")
14
 
15
  app.add_middleware(
16
  CORSMiddleware,
17
- allow_origins=["*"], # Restrict for production!
18
  allow_credentials=True,
19
  allow_methods=["*"],
20
  allow_headers=["*"],
21
  )
22
 
23
- MODEL_PATH = "models/modnet.onnx"
24
- MODEL_WIDTH = 512 # Official MODNet ONNX input size (for best speed, use 512; 768 or 1024 for higher res if model supports)
25
  MODEL_HEIGHT = 512
26
 
27
  print("🔄 Loading MODNet ONNX model...")
@@ -47,17 +46,15 @@ def preprocess_image(image: Image.Image, target_size=(MODEL_WIDTH, MODEL_HEIGHT)
47
  image = image.convert('RGB')
48
  orig_width, orig_height = image.size
49
  image_resized = image.resize(target_size, Image.LANCZOS)
50
- img_array = np.array(image_resized).astype(np.float32) / 255.0 # shape (512, 512, 3)
51
- img_array = np.transpose(img_array, (2, 0, 1)) # (3, 512, 512)
52
- img_array = np.expand_dims(img_array, axis=0) # (1, 3, 512, 512)
53
  return img_array, (orig_width, orig_height)
54
 
55
  def postprocess_mask(mask: np.ndarray, original_size):
56
- # MODNet returns (1,1,H,W) float in [0,1]
57
- mask = mask[0, 0] # (H,W)
58
  mask = (mask * 255).round().astype(np.uint8)
59
  mask = cv2.resize(mask, original_size, interpolation=cv2.INTER_LINEAR)
60
- # Optional: Apply threshold to get crisp mask
61
  mask = np.where(mask > 127, 255, 0).astype(np.uint8)
62
  return mask
63
 
@@ -73,7 +70,7 @@ def remove_background(image: Image.Image):
73
 
74
  @app.get("/")
75
  async def root():
76
- return {"status": "healthy", "service": "Backdrop Studio MODNet API", "version": "2.0.0"}
77
 
78
  @app.get("/quota/{user_id}")
79
  async def get_quota(user_id: str):
@@ -137,5 +134,4 @@ async def remove_background_endpoint(
137
 
138
  if __name__ == "__main__":
139
  import uvicorn
140
- port = int(os.environ.get("PORT", 8080))
141
- uvicorn.run(app, host="0.0.0.0", port=port)
 
8
  import io
9
  from datetime import datetime, timedelta
10
  from collections import defaultdict
 
11
 
12
+ app = FastAPI(title="MODNet API", version="1.0.0")
13
 
14
  app.add_middleware(
15
  CORSMiddleware,
16
+ allow_origins=["*"],
17
  allow_credentials=True,
18
  allow_methods=["*"],
19
  allow_headers=["*"],
20
  )
21
 
22
+ MODEL_PATH = "modnet.onnx" # model ONNX is in root folder!
23
+ MODEL_WIDTH = 512
24
  MODEL_HEIGHT = 512
25
 
26
  print("🔄 Loading MODNet ONNX model...")
 
46
  image = image.convert('RGB')
47
  orig_width, orig_height = image.size
48
  image_resized = image.resize(target_size, Image.LANCZOS)
49
+ img_array = np.array(image_resized).astype(np.float32) / 255.0
50
+ img_array = np.transpose(img_array, (2, 0, 1))
51
+ img_array = np.expand_dims(img_array, axis=0)
52
  return img_array, (orig_width, orig_height)
53
 
54
  def postprocess_mask(mask: np.ndarray, original_size):
55
+ mask = mask[0, 0]
 
56
  mask = (mask * 255).round().astype(np.uint8)
57
  mask = cv2.resize(mask, original_size, interpolation=cv2.INTER_LINEAR)
 
58
  mask = np.where(mask > 127, 255, 0).astype(np.uint8)
59
  return mask
60
 
 
70
 
71
  @app.get("/")
72
  async def root():
73
+ return {"status": "healthy", "service": "MODNet API", "version": "1.0.0"}
74
 
75
  @app.get("/quota/{user_id}")
76
  async def get_quota(user_id: str):
 
134
 
135
  if __name__ == "__main__":
136
  import uvicorn
137
+ uvicorn.run(app, host="0.0.0.0", port=7860)