File size: 4,174 Bytes
fdc1efd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import logging
import os
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Dict

from fastapi import Depends, FastAPI, File, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel, validator

from src.api.ModelService import ModelServiceAST

# Directory holding the rotating log files, located next to this module.
# ``exist_ok=True`` replaces a separate "does it exist?" check, which could race
# with another process creating the directory between the check and the create.
LOG_SAVE_DIR = Path(__file__).parent / "logs"
LOG_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Registry of available models, keyed by the ``model_name`` clients send.
# Both models are constructed eagerly when this module is imported.
ml_models = {
    "Accuracy": ModelServiceAST(model_type="accuracy"),
    "Speed": ModelServiceAST(model_type="speed"),
}

# FastAPI application instance; the exception handlers and routes below register on it.
app = FastAPI()

# File extensions (lower-case, without the leading dot) accepted by ``/predict``.
# NOTE(review): no maximum upload size is defined or enforced in this module.
ALLOWED_FILE_FORMATS = ["wav"]

# Module logger; set to DEBUG so every message reaches the file handler below.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Rotating file handler: rolls over at 100000 bytes and keeps 5 backups.
# An explicit UTF-8 encoding prevents non-ASCII content (e.g. uploaded
# filenames) from failing under a locale-dependent default encoding.
handler = RotatingFileHandler(
    LOG_SAVE_DIR / "app.log",
    maxBytes=100000,
    backupCount=5,
    encoding="utf-8",
)
handler.setLevel(logging.DEBUG)

# Line format: "<timestamp> - <LEVEL> - <message>".
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)

# Attach the file handler to the module logger.
logger.addHandler(handler)


class InvalidFileTypeError(Exception):
    """Raised by ``/predict`` when the upload's extension is not allowed."""

    def __init__(self, message: str = "Only wav files are supported"):
        # The message is an optional argument (rather than no argument at all)
        # so pickle/copy, which rebuild exceptions via ``cls(*self.args)``,
        # can reconstruct it; ``raise InvalidFileTypeError`` still works.
        self.message = message
        super().__init__(self.message)


class InvalidModelError(Exception):
    """Raised when a requested ``model_name`` is not a key of ``ml_models``."""

    def __init__(self, message: str = "Selected model doesn't exist"):
        # Optional message keeps ``raise InvalidModelError`` working while
        # letting pickle/copy rebuild the exception from ``self.args``.
        self.message = message
        super().__init__(self.message)


class MissingFileError(Exception):
    """Raised by ``/predict`` when no uploaded file is available."""

    def __init__(self, message: str = "File cannot be None"):
        # Optional message keeps ``raise MissingFileError`` working while
        # letting pickle/copy rebuild the exception from ``self.args``.
        self.message = message
        super().__init__(self.message)


class PredictionRequest(BaseModel):
    """Parameters of a ``/predict`` call.

    ``model_name`` must be one of the keys registered in ``ml_models``.
    """

    # Name of the model to run; must be a key of ``ml_models``.
    model_name: str

    @validator("model_name")
    @classmethod
    def valid_model(cls, v):
        """Return ``v`` unchanged when it names a registered model.

        Raises:
            InvalidModelError: If ``v`` is not a key of ``ml_models``.
        """
        if v in ml_models:
            return v
        raise InvalidModelError


class PredictionResult(BaseModel):
    """Response body of ``/predict``: uploaded filename -> model output."""

    # Outer key is the uploaded filename; the inner mapping is whatever the
    # model service returns.
    # NOTE(review): the inner values are declared ``int``; if the model returns
    # floats they may be coerced or rejected by pydantic — confirm the type.
    prediction: Dict[str, Dict[str, int]]


@app.exception_handler(RequestValidationError)
def validation_exception_handler(request, ex):
    """Return a 400 JSON response for requests that fail FastAPI validation.

    ``ex.errors()`` can hold values ``json.dumps`` cannot serialize (with
    pydantic v2, e.g. an exception object inside an error's ``ctx``).
    Encoding it with ``jsonable_encoder`` first, as FastAPI's own handler does,
    keeps this handler from crashing into a 500 while reporting a 400.
    """
    logger.error("Request validation error: %s", ex)
    detail = jsonable_encoder(ex.errors())
    return JSONResponse(content={"error": "Bad Request", "detail": detail}, status_code=400)


@app.exception_handler(InvalidFileTypeError)
def filetype_exception_handler(request, ex):
    """Log an :class:`InvalidFileTypeError` and answer with HTTP 400."""
    logger.error("Invalid file type error: %s", ex)
    body = {"error": "Bad Request", "detail": ex.message}
    return JSONResponse(status_code=400, content=body)


@app.exception_handler(InvalidModelError)
def model_exception_handler(request, ex):
    """Log an :class:`InvalidModelError` and answer with HTTP 400."""
    logger.error("Invalid model error: %s", ex)
    body = {"error": "Bad Request", "detail": ex.message}
    return JSONResponse(status_code=400, content=body)


@app.exception_handler(MissingFileError)
def handle_missing_file_error(request, ex):
    """Log a :class:`MissingFileError` and answer with HTTP 400."""
    logger.error("Missing file error: %s", ex)
    body = {"error": "Bad Request", "detail": ex.message}
    return JSONResponse(status_code=400, content=body)


@app.exception_handler(Exception)
def handle_exceptions(request, ex):
    """Catch-all handler: log the failure with its traceback and return a 500 JSON response."""
    # This is a sync (``def``) handler, which Starlette runs via a threadpool;
    # ``sys.exc_info()`` is empty in that worker thread, so ``logger.exception``
    # would log "NoneType: None" instead of the traceback. Pass the exception
    # instance explicitly so the real traceback is recorded.
    logger.error("Internal server error: %s", ex, exc_info=ex)
    # NOTE(review): ``str(ex)`` exposes internal error details to clients;
    # consider replacing it with a generic message.
    return JSONResponse(content={"error": "Internal Server Error", "detail": str(ex)}, status_code=500)


@app.get("/")
def root():
    """Landing endpoint that points callers at the interactive docs."""
    logger.info("Received request to root endpoint")
    welcome = "Welcome to my API. Go to /docs to view the documentation."
    return {"message": welcome}


@app.get("/health-check")
def health_check():
    """
    Report that the API process is up and answering requests.
    """
    logger.info("Health check endpoint was hit")
    payload = dict(status="API is running")
    return payload


@app.post("/predict")
def predict(request: PredictionRequest = Depends(), file: UploadFile = File(...)) -> PredictionResult:  # noqa
    """Run the selected model on an uploaded WAV file.

    Args:
        request: Validated request parameters; ``request.model_name`` selects
            an entry of ``ml_models``.
        file: The uploaded audio file.

    Returns:
        A :class:`PredictionResult` mapping the uploaded filename to the
        model's output.

    Raises:
        MissingFileError: If no file is available.
        InvalidFileTypeError: If the filename is missing, has no extension, or
            its extension is not in ``ALLOWED_FILE_FORMATS``.
    """
    if not file:
        raise MissingFileError

    # ``filename`` can be None, and a bare name such as "wav" contains no dot
    # at all; ``rpartition`` tells us whether an extension is actually present.
    filename = file.filename or ""
    _, dot, extension = filename.rpartition(".")
    if not dot or extension.lower() not in ALLOWED_FILE_FORMATS:
        raise InvalidFileTypeError

    logger.info("Prediction request received: %s", request)
    output = ml_models[request.model_name].get_prediction(file.file)
    logger.info("Prediction result: %s", output)
    return PredictionResult(prediction={filename: output})