|
from typing import Optional |
|
|
|
from fastapi import APIRouter |
|
from fastapi import FastAPI |
|
from schemas import ClassificationResult |
|
from utils import load_image |
|
from utils import load_model |
|
|
|
|
|
|
|
|
|
# Load the classifier once at import time; it is reused as the default
# `model` argument of the request handler defined further down.
model = load_model()

# Application object; the OpenAPI schema is served at /openapi.json.
app = FastAPI(
    title="MosAl",
    description="Obtain classification predictions for mosquito image",
    version="0.1.0",
    openapi_url="/openapi.json",
)

# Router that collects the endpoints before they are attached to `app`.
api_router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@api_router.get("/classify/{image_name}", status_code=200, response_model=ClassificationResult)
async def predict_image(image_name: str, model=model):
    """Classify the image identified by ``image_name``.

    Loads the image with ``load_image``, runs ``model.predict`` on it and
    returns the predicted label together with the probability of that label.

    Args:
        image_name: Value of the ``{image_name}`` path segment; passed
            unchanged to ``load_image``.
        model: Object exposing ``predict(img)`` that returns a
            ``(prediction, pred_idx, probs)`` triple. Defaults to the
            module-level model.

    Returns:
        ``{"prediction": ..., "score": ...}`` when the prediction is truthy,
        with the score as a built-in ``float`` rounded to 3 decimals;
        otherwise ``{"message": [0]}``.
    """
    # NOTE(review): FastAPI builds request parameters from this signature, so
    # `model` is presumably exposed to clients as a request parameter --
    # confirm, and consider injecting it via `Depends` instead.
    # NOTE(review): `load_image` and `model.predict` are plain (non-awaited)
    # calls inside an `async def`; if they are slow they will hold up other
    # requests -- confirm whether that matters for this deployment.
    img = load_image(image_name)
    prediction, pred_idx, probs = model.predict(img)

    if not prediction:
        # NOTE(review): this payload shares no keys with the success payload;
        # confirm that `ClassificationResult` accepts it, otherwise response
        # validation will reject it.
        return {"message": [0]}

    # Convert to a built-in float *before* rounding. Rounding the NumPy scalar
    # directly keeps it a NumPy type, and for float32 (TODO confirm dtype) the
    # rounded value widens to something like 0.9869999885559082 once it is
    # serialised, defeating the 3-decimal rounding.
    score = round(float(probs.numpy()[pred_idx]), 3)
    return {"prediction": prediction, "score": score}
|
|
|
|
|
|
|
# Attach the routes collected on `api_router` to the application.
app.include_router(api_router)


if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    # Development server settings: listen on every interface, port 7860.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "debug",
    }
    uvicorn.run(app, **server_options)
|
|