Spaces:
Sleeping
Sleeping
| # app.py | |
| import io | |
| import logging | |
| import traceback | |
| import time | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Request | |
| from fastapi.responses import StreamingResponse, HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from starlette.responses import RedirectResponse | |
| from PIL import Image | |
| import numpy as np | |
| import depth_texture_mask # make sure depth_texture_mask.py is at repo root | |
# Log through uvicorn's "uvicorn.error" logger (presumably so this module's
# messages share uvicorn's handlers and format).
logger = logging.getLogger("uvicorn.error")

app = FastAPI(title="Depth & Structural Masking API with UI")

# Serve CSS/JS/images from the "static" directory under the /static URL prefix.
# Both "static" and "templates" are relative paths, so they are resolved
# against the process's working directory, not this file's location.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# Jinja2 templates (index.html is rendered from here).
templates = Jinja2Templates(directory="templates")
async def startup_event():
    """Initialise the MiDaS depth model once via ``depth_texture_mask.init_midas()``.

    Initialisation is best-effort: any exception is logged together with its
    traceback and then suppressed, so this coroutine never propagates a
    failure to its caller.
    """
    # NOTE(review): no @app.on_event("startup") / lifespan registration is
    # visible in this copy of the file — confirm this hook is actually wired up.
    try:
        logger.info("Initializing MiDaS model...")
        # Heavy, one-time initialisation of the model.
        depth_texture_mask.init_midas()
        logger.info("MiDaS initialized.")
    except Exception as exc:  # deliberate: a failed model load must not abort startup
        logger.exception("Error initializing MiDaS: %s", exc)
async def index(request: Request):
    """Render the web UI from ``templates/index.html``.

    The incoming request is passed into the template context under the
    ``"request"`` key.
    """
    # NOTE(review): no route decorator is visible in this copy; redirect_ui()
    # sends clients to "/", which suggests this handler serves that path — confirm.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def pil_image_from_uploadfile(upload_file: UploadFile) -> Image.Image:
    """Read an uploaded file completely and decode it as an RGB PIL image.

    The upload's underlying file object is always closed, even when reading
    fails. The intermediate decoded image is closed as well; only the
    converted copy is returned.

    Args:
        upload_file: The uploaded file. Its ``.file`` object is read to the end
            and then closed.

    Returns:
        A fully loaded ``Image.Image`` in ``"RGB"`` mode.

    Raises:
        OSError: Pillow raises ``UnidentifiedImageError`` (an ``OSError``
            subclass) for unrecognised data, and ``OSError`` for truncated data.
        PIL.Image.DecompressionBombError: for images exceeding Pillow's
            decompression-bomb limit.
    """
    try:
        contents = upload_file.file.read()
    finally:
        # Release the upload's file handle even if read() raised.
        upload_file.file.close()

    # convert() fully loads the pixels and returns a new image object, so the
    # source image (and its in-memory buffer) can be closed immediately.
    with Image.open(io.BytesIO(contents)) as src:
        return src.convert("RGB")
def numpy_from_pil(pil_img: Image.Image) -> np.ndarray:
    """Expose the pixel data of *pil_img* as a NumPy array.

    This is a thin wrapper around ``np.asarray``; no dtype or layout
    conversion happens here.
    """
    pixels = np.asarray(pil_img)
    return pixels
def pil_from_mask_array(mask: np.ndarray) -> Image.Image:
    """Convert a mask array into an 8-bit grayscale (mode ``"L"``) PIL image.

    Value handling:
      * bool: ``True`` -> 255, ``False`` -> 0.
      * floating point: NaN is treated as 0. If the largest value is <= 1.0,
        the data is assumed to be normalised to [0, 1] and is scaled by 255.
        The result is then clipped to [0, 255] and truncated to uint8.
      * anything else (integers): clipped to [0, 255] and cast to uint8.

    Shape handling:
      * ``(H, W)``: used as-is.
      * ``(H, W, 1)``: the singleton channel is dropped.
      * ``(H, W, 3)`` / ``(H, W, 4)``: converted to luma with ITU-R 601-2
        weights, rounded to the nearest integer. A fourth (alpha) channel is
        ignored.

    Args:
        mask: Mask values. Any array-like accepted by ``np.asarray`` works.

    Returns:
        A mode ``"L"`` ``Image.Image``.

    Raises:
        ValueError: if the array is not 2-D after the channel handling above.
    """
    arr = np.asarray(mask)

    if arr.dtype == np.bool_:
        # A plain cast yields 0/1, which renders as an essentially black image.
        arr = arr.astype("uint8") * 255
    elif np.issubdtype(arr.dtype, np.floating):
        # NaN would poison max() and cast to an undefined byte; treat it as background.
        arr = np.where(np.isnan(arr), 0.0, arr)
        # The size guard avoids ValueError from max() on an empty array.
        if arr.size and arr.max() <= 1.0:
            arr = arr * 255.0
        # Clip before casting: out-of-range floats (e.g. negatives) give
        # undefined results when cast to uint8.
        arr = np.clip(arr, 0, 255).astype("uint8")
    else:
        arr = np.clip(arr, 0, 255).astype("uint8")

    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[..., 0]
    elif arr.ndim == 3 and arr.shape[2] in (3, 4):
        luma = 0.2989 * arr[..., 0] + 0.5870 * arr[..., 1] + 0.1140 * arr[..., 2]
        # Round instead of truncating so pure white (255, 255, 255) stays 255.
        arr = np.clip(np.rint(luma), 0, 255).astype("uint8")

    if arr.ndim != 2:
        raise ValueError(
            f"Unsupported mask shape {arr.shape}; expected (H, W) or (H, W, C) with C in (1, 3, 4)."
        )

    # Pillow maps a 2-D uint8 array to mode "L" on its own; the explicit
    # mode= argument is deprecated in recent Pillow releases.
    return Image.fromarray(arr)
async def generate_mask_endpoint(file: UploadFile = File(...)):
    """
    Accept an image file and return a PNG mask.
    Adds a response header 'X-Inference-Time-ms' with inference time in milliseconds.

    Error responses:
      * 415 if the upload's declared content type is missing or not ``image/*``.
      * 400 if the uploaded bytes cannot be decoded as an image.
      * 500 if mask generation returns ``None`` or anything else fails. The
        underlying error is logged with its traceback but is not echoed back
        to the client.
    """
    # NOTE(review): no route decorator is visible in this copy; the log message
    # below refers to "/mask/" — confirm the endpoint is registered there.
    # NOTE(review): decoding and the model call are synchronous and run on the
    # event loop, which blocks other requests meanwhile. If depth_texture_mask
    # is safe to call from worker threads, consider offloading them
    # (e.g. starlette.concurrency.run_in_threadpool).
    try:
        # UploadFile.content_type can be None (no Content-Type on the part);
        # treat that as unsupported instead of crashing with AttributeError.
        if not (file.content_type or "").startswith("image/"):
            raise HTTPException(status_code=415, detail="Unsupported file type.")

        # read + convert; undecodable uploads are a client error, not a server error.
        try:
            pil_img = pil_image_from_uploadfile(file)
        except (OSError, Image.DecompressionBombError) as exc:
            raise HTTPException(
                status_code=400, detail="Could not decode the uploaded image."
            ) from exc
        input_np = numpy_from_pil(pil_img)

        # Time only the model call, not decoding or PNG encoding.
        start = time.perf_counter()
        mask = depth_texture_mask.generate_texture_depth_mask(input_np, mask_only=True)
        infer_ms = int((time.perf_counter() - start) * 1000)

        if mask is None:
            raise HTTPException(status_code=500, detail="Mask generation failed.")

        buf = io.BytesIO()
        pil_from_mask_array(mask).save(buf, format="PNG")
        buf.seek(0)
        headers = {"X-Inference-Time-ms": str(infer_ms)}
        return StreamingResponse(buf, media_type="image/png", headers=headers)
    except HTTPException:
        # Already carries the intended status/detail.
        raise
    except Exception as exc:
        # Log the full traceback server-side; do not leak internals to the client.
        logger.exception("Error in /mask/: %s", exc)
        raise HTTPException(
            status_code=500, detail="Internal error while generating the mask."
        ) from exc
async def redirect_ui():
    """Send the client to the web UI served at the site root ("/")."""
    # NOTE(review): no route decorator is visible in this copy — confirm which
    # path this handler is registered on.
    return RedirectResponse(url="/")