File size: 2,264 Bytes
be14aa6
 
 
794e3ef
be14aa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794e3ef
be14aa6
 
 
 
794e3ef
 
be14aa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794e3ef
 
 
 
 
 
be14aa6
794e3ef
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Import standard library
import base64
import logging
import os
import tempfile
from io import BytesIO

# Import Fast API
from fastapi import FastAPI, Request, UploadFile, File
from fastapi.templating import Jinja2Templates

# Import machine learning
from ultralytics import YOLO
from huggingface_hub import hf_hub_download

# Import project modules
from src.utils.utils import IMAGE_FORMATS
from src.predict import predict


# Initialize the FastAPI application
app = FastAPI()

# Jinja2 templates used to render the HTML pages (loaded from "templates")
templates = Jinja2Templates(directory="templates")

# Module-level logger; handlers/levels are configured elsewhere (if at all)
logger = logging.getLogger(__name__)


logger.info("Loading YOLO model...")

# Fetch the YOLOv8 face-detection weights from the Hugging Face Hub.
# This runs at import time, so the module only finishes importing once the
# weights file is available locally.
model_path = hf_hub_download(
    repo_id="arnabdhar/YOLOv8-Face-Detection", filename="model.pt"
)


# Load the YOLO model once at startup; the prediction endpoint reuses this
# single instance for every request instead of reloading the weights.
model = YOLO(model_path)


# Landing page
@app.get("/")
async def root(request: Request):
    """Serve the index page by rendering ``index.html``.

    The template context carries only the request itself (no ``img`` entry),
    so the page is rendered without a prediction result.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)


def _write_temp_image(contents, suffix):
    """Write ``contents`` to a new, uniquely named temporary file.

    Args:
        contents: Raw bytes of the uploaded image.
        suffix: Extension to keep on the temp file (e.g. ``".jpg"``), so the
            path handed to ``predict`` still carries the upload's extension.

    Returns:
        The path of the created file. The caller must delete it.
    """
    # mkstemp creates the file atomically with a unique name, so concurrent
    # requests can never collide on the same path.
    fd, path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(contents)
    except BaseException:
        # Don't leave a half-written file behind if the write fails.
        os.remove(path)
        raise
    return path


def _encode_jpeg_base64(image):
    """Serialize ``image`` as JPEG and return the bytes as a base64 ``str``.

    ``image`` is whatever ``predict`` returns; it must expose a
    ``save(fp, format)`` method (presumably a PIL image -- confirm).
    """
    buffer = BytesIO()
    image.save(buffer, "JPEG")
    # getvalue() returns the whole buffer regardless of position; no seek needed.
    return base64.b64encode(buffer.getvalue()).decode()


# Upload images route
@app.post("/predict-img")
def predict_image(request: Request, file: UploadFile = File(...)):
    """Run the face-detection model on an uploaded image and render the result.

    The upload is validated by filename extension against ``IMAGE_FORMATS``.
    It is then written to a uniquely named temporary file, passed to
    ``predict`` together with the module-level ``model``, and deleted again.
    The file is deleted even if prediction fails.

    Args:
        request: The incoming request, forwarded to the template context.
        file: The uploaded image.

    Returns:
        ``index.html`` rendered with the base64-encoded JPEG result under the
        ``img`` key. If the upload is rejected or cannot be read, a
        one-element list holding the error message is returned instead.
    """
    try:
        filename = file.filename or ""

        # Validate before touching the disk so rejected uploads leave no files.
        if not filename.endswith(IMAGE_FORMATS):
            raise ValueError("Invalid image format")

        # Never use the client-supplied name as a path: it may contain "../"
        # components, and identical names from concurrent uploads would
        # overwrite each other. Only its extension is preserved.
        image = _write_temp_image(
            file.file.read(), os.path.splitext(filename)[1]
        )

    except Exception as e:
        # Returned to the client as a one-element JSON array with the message.
        return [f"{e}"]

    finally:
        file.file.close()

    try:
        results = predict(model, image)

        # TODO: use the upload's extension/format instead of always JPEG.
        img_bytes = _encode_jpeg_base64(results)
    finally:
        # Always clean up the temp file, even when prediction/encoding raises.
        try:
            os.remove(image)
        except OSError as e:
            logger.error("Error deleting image: %s", e)

    return templates.TemplateResponse(
        "index.html", {"request": request, "img": img_bytes}
    )