|
import os |
|
|
|
from app import settings as config |
|
from app import utils |
|
from app.auth.jwt import get_current_user |
|
from app.model.schema import PredictResponse |
|
from app.model.services import model_predict |
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, status |
|
|
|
# Router for the model endpoints: every route below is mounted under the
# "/model" path prefix and grouped under the "Model" tag in the OpenAPI docs.
router = APIRouter(
    prefix="/model",
    tags=["Model"],
)
|
|
|
|
|
def _write_atomically(path, data):
    """Write ``data`` to ``path`` so readers never observe a partial file.

    The bytes go to a uniquely named temp file in the same directory, which
    is then moved over ``path`` with ``os.replace``. A failed write therefore
    leaves nothing at ``path``, and a concurrent reader sees either no file or
    the complete one. The temp file is removed if anything goes wrong.
    """
    # Same directory as the target so os.replace stays a same-filesystem rename.
    tmp_path = f"{path}.{os.getpid()}.{os.urandom(8).hex()}.part"
    try:
        # "xb" refuses to clobber an existing file and, unlike tempfile.mkstemp,
        # creates it with the same umask-derived permissions as a plain open().
        with open(tmp_path, "xb") as tmp_file:
            tmp_file.write(data)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


@router.post("/predict")
async def predict(file: UploadFile, current_user=Depends(get_current_user)):
    """Store an uploaded file and return the model's prediction for it.

    The upload is saved under ``config.UPLOAD_FOLDER`` using the name
    returned by ``utils.get_file_hash`` (presumably a content hash, so
    identical uploads share one stored file). It is only written if that
    path does not exist yet. ``model_predict`` is then awaited on the stored
    path.

    Args:
        file: The uploaded file; its filename must pass
            ``utils.allowed_file``.
        current_user: Resolved by the ``get_current_user`` dependency and not
            used in the body. Presumably it exists to require an
            authenticated caller; confirm in ``app.auth.jwt``.

    Returns:
        PredictResponse with ``success=True``, the ``prediction`` and
        ``score`` returned by ``model_predict``, and ``image_file_name`` set
        to the stored file name.

    Raises:
        HTTPException: 400 if no file or filename was supplied, or the file
            type is not allowed.
    """
    # A missing filename (UploadFile.filename may be None) is rejected as
    # unsupported instead of being passed on to allowed_file.
    if not file or not file.filename or not utils.allowed_file(file.filename):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="File type is not supported.",
        )

    new_filename = await utils.get_file_hash(file)
    file_path = os.path.join(config.UPLOAD_FOLDER, new_filename)

    if not os.path.exists(file_path):
        # get_file_hash may have consumed the stream; rewind so the full
        # content is saved rather than whatever is left after hashing.
        await file.seek(0)
        content = await file.read()
        # Atomic write: concurrent requests for the same upload never run the
        # model on a half-written file, and a failed write is not cached.
        _write_atomically(file_path, content)

    # Leave the upload stream rewound, as before.
    await file.seek(0)

    prediction, score = await model_predict(file_path)

    return PredictResponse(
        success=True,
        prediction=prediction,
        score=score,
        image_file_name=new_filename,
    )
|
|