yaghi27 commited on
Commit
431cc3e
·
1 Parent(s): 732b983

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +57 -20
main.py CHANGED
@@ -1,14 +1,25 @@
1
  import logging
2
- from fastapi import FastAPI, UploadFile, File
 
 
 
 
 
 
3
  from fastapi.responses import HTMLResponse, JSONResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.middleware.cors import CORSMiddleware
6
- import base64
7
  import mmcv
8
- from model.run_inference import infer_images
9
 
10
  logging.basicConfig(level=logging.INFO)
11
 
 
 
 
 
 
12
  app = FastAPI()
13
  app.add_middleware(
14
  CORSMiddleware,
@@ -19,6 +30,7 @@ app.add_middleware(
19
  )
20
  app.mount("/static", StaticFiles(directory="static"), name="static")
21
 
 
22
  @app.get("/", response_class=HTMLResponse)
23
  async def root():
24
  with open("static/index.html", "r", encoding="utf-8") as f:
@@ -26,26 +38,51 @@ async def root():
26
 
27
 
28
  @app.post("/infer")
29
- async def run_inference(images: list[UploadFile] = File(...)):
 
 
 
 
 
 
 
 
 
 
 
30
  img_paths = []
31
- for upload in images:
32
- data = await upload.read()
33
- # Drop any alpha channel, force 3-channel BGR
34
- bgr = mmcv.imfrombytes(data, flag="color")
35
- tmp = f"/tmp/{upload.filename}"
36
- mmcv.imwrite(bgr, tmp)
37
- img_paths.append(tmp)
38
 
39
  try:
40
- bev_paths = infer_images(img_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  except Exception as e:
42
  logging.exception("inference failed")
43
  return JSONResponse(status_code=500, content={"error": str(e)})
44
-
45
- output = []
46
- for p in bev_paths:
47
- with open(p, "rb") as f:
48
- b64 = base64.b64encode(f.read()).decode("utf-8")
49
- output.append({"bev_image": b64})
50
-
51
- return JSONResponse(content=output)
 
1
  import logging
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ import base64
6
+ from typing import List
7
+
8
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
9
  from fastapi.responses import HTMLResponse, JSONResponse
10
  from fastapi.staticfiles import StaticFiles
11
  from fastapi.middleware.cors import CORSMiddleware
12
+
13
  import mmcv
14
+ from model.run_inference import infer_images # <- update this to accept model=...
15
 
16
  logging.basicConfig(level=logging.INFO)
17
 
18
+ ALLOWED_MODELS = {
19
+ "regnetx4.0gf+dterd",
20
+ "regnetx4.0gf+petr",
21
+ }
22
+
23
  app = FastAPI()
24
  app.add_middleware(
25
  CORSMiddleware,
 
30
  )
31
  app.mount("/static", StaticFiles(directory="static"), name="static")
32
 
33
+
34
  @app.get("/", response_class=HTMLResponse)
35
  async def root():
36
  with open("static/index.html", "r", encoding="utf-8") as f:
 
38
 
39
 
40
  @app.post("/infer")
41
+ async def run_inference(
42
+ model: str = Form(...),
43
+ images: List[UploadFile] = File(...),
44
+ ):
45
+ model = model.strip().lower()
46
+ if model not in ALLOWED_MODELS:
47
+ raise HTTPException(status_code=400, detail=f"Invalid model '{model}'. Allowed: {sorted(ALLOWED_MODELS)}")
48
+
49
+ if len(images) != 6:
50
+ raise HTTPException(status_code=400, detail=f"Expected 6 images, received {len(images)}")
51
+
52
+ tmpdir = tempfile.mkdtemp(prefix="bev_infer_")
53
  img_paths = []
 
 
 
 
 
 
 
54
 
55
  try:
56
+ for idx, upload in enumerate(images):
57
+ data = await upload.read()
58
+ bgr = mmcv.imfrombytes(data, flag="color")
59
+ if bgr is None:
60
+ raise HTTPException(status_code=400, detail=f"File '{upload.filename}' is not a valid image.")
61
+
62
+ out_path = os.path.join(tmpdir, f"cam_{idx}.png")
63
+ mmcv.imwrite(bgr, out_path)
64
+ img_paths.append(out_path)
65
+
66
+ logging.info("Starting inference with model=%s on %d images", model, len(img_paths))
67
+
68
+ bev_paths = infer_images(img_paths, model=model)
69
+
70
+ output = []
71
+ for p in bev_paths:
72
+ with open(p, "rb") as f:
73
+ b64 = base64.b64encode(f.read()).decode("utf-8")
74
+ output.append({"bev_image": b64})
75
+
76
+ return JSONResponse(content=output)
77
+
78
+ except HTTPException:
79
+
80
+ raise
81
  except Exception as e:
82
  logging.exception("inference failed")
83
  return JSONResponse(status_code=500, content={"error": str(e)})
84
+ finally:
85
+ try:
86
+ shutil.rmtree(tmpdir)
87
+ except Exception:
88
+ logging.warning("Failed to clean tmpdir %s", tmpdir)