import os
import logging
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Dict
from fastapi import Depends, FastAPI, File, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from src.api.ModelService import ModelServiceAST
from pydantic import BaseModel, validator
LOG_SAVE_DIR = Path(__file__).parent / "logs"
if not os.path.exists(LOG_SAVE_DIR):
    os.makedirs(LOG_SAVE_DIR)
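# Build the model registry at import time: one ModelServiceAST per variant,
# keyed by the name clients pass as model_name.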
ml_models = {
    "Accuracy": ModelServiceAST(model_type="accuracy"),
    "Speed": ModelServiceAST(model_type="speed"),
}
app = FastAPI()
# Define the allowed file formats
ALLOWED_FILE_FORMATS = ["wav"]
# Configure logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Create a rotating file handler to save logs to a file
handler = RotatingFileHandler(f"{LOG_SAVE_DIR}/app.log", maxBytes=100000, backupCount=5)
handler.setLevel(logging.DEBUG)
# Define the log format
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
# Add the handler to the logger
logger.addHandler(handler)
class InvalidFileTypeError(Exception):
    def __init__(self):
        self.message = "Only wav files are supported"
        super().__init__(self.message)


class InvalidModelError(Exception):
    def __init__(self):
        self.message = "Selected model doesn't exist"
        super().__init__(self.message)


class MissingFileError(Exception):
    def __init__(self):
        self.message = "File cannot be None"
        super().__init__(self.message)
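# Each custom error above is translated into a 400 JSON response by the
# matching @app.exception_handler registered below.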
class PredictionRequest(BaseModel):
    model_name: str

    @validator("model_name")
    @classmethod
    def valid_model(cls, v):
        # Reject requests that name a model not present in the registry.
        if v not in ml_models:
            raise InvalidModelError
        return v


class PredictionResult(BaseModel):
    # Maps each uploaded filename to the model's output dictionary.
    prediction: Dict[str, Dict[str, int]]
@app.exception_handler(RequestValidationError)
def validation_exception_handler(request, ex):
    logger.error(f"Request validation error: {ex}")
    return JSONResponse(content={"error": "Bad Request", "detail": ex.errors()}, status_code=400)


@app.exception_handler(InvalidFileTypeError)
def filetype_exception_handler(request, ex):
    logger.error(f"Invalid file type error: {ex}")
    return JSONResponse(content={"error": "Bad Request", "detail": ex.message}, status_code=400)


@app.exception_handler(InvalidModelError)
def model_exception_handler(request, ex):
    logger.error(f"Invalid model error: {ex}")
    return JSONResponse(content={"error": "Bad Request", "detail": ex.message}, status_code=400)


@app.exception_handler(MissingFileError)
def handle_missing_file_error(request, ex):
    logger.error(f"Missing file error: {ex}")
    return JSONResponse(content={"error": "Bad Request", "detail": ex.message}, status_code=400)


@app.exception_handler(Exception)
def handle_exceptions(request, ex):
    logger.exception(f"Internal server error: {ex}")
    # If an exception occurs during processing, return a JSON response with an error message
    return JSONResponse(content={"error": "Internal Server Error", "detail": str(ex)}, status_code=500)
@app.get("/")
def root():
logger.info("Received request to root endpoint")
return {"message": "Welcome to my API. Go to /docs to view the documentation."}
@app.get("/health-check")
def health_check():
"""
Health check endpoint to verify if the API is running.
"""
logger.info("Health check endpoint was hit")
return {"status": "API is running"}
@app.post("/predict")
def predict(request: PredictionRequest = Depends(), file: UploadFile = File(...)) -> PredictionResult: # noqa
if not file:
raise MissingFileError
if file.filename.split(".")[-1].lower() not in ALLOWED_FILE_FORMATS:
raise InvalidFileTypeError
logger.info(f"Prediction request received: {request}")
output = ml_models[request.model_name].get_prediction(file.file)
logger.info(f"Prediction result: {output}")
prediction_result = PredictionResult(prediction={file.filename: output})
return prediction_result
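# --- Usage sketch (illustrative, not part of the service) ---
# model_name arrives as a query parameter because PredictionRequest is
# injected via Depends(); the audio file is sent as multipart form data.
# Hypothetical local request:
#   curl -X POST "http://localhost:8000/predict?model_name=Accuracy" \
#        -F "file=@sample.wav"
# A minimal sketch for running the app directly, assuming uvicorn is installed;
# host and port values here are placeholders:
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)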