# face-detector / main.py
# Last commit be14aa6 by osrokas: "new stuff" (file size 2.16 kB)
# Standard library
import logging
import os
import tempfile
from io import BytesIO

# Third-party: web framework
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.templating import Jinja2Templates

# Third-party: machine learning
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Local: prediction helper and utilities
from src.predict import predict
from src.utils.utils import IMAGE_FORMATS
# Initialize the FastAPI application.
app = FastAPI()
# Jinja2 templates live in the relative "templates" directory, so they are
# resolved against the process working directory.
templates = Jinja2Templates(directory="templates")
# Module-level logger for this file.
logger = logging.getLogger(__name__)

# NOTE: only visible if logging is configured at INFO level or lower.
logger.info("Loading YOLO model...")
# Download the YOLOv8 face-detection weights from the Hugging Face Hub;
# hf_hub_download returns the local filesystem path of the (cached) file.
model_path = hf_hub_download(
    repo_id="arnabdhar/YOLOv8-Face-Detection", filename="model.pt"
)
# Load the model once, at import time, so every request reuses the same
# instance instead of reloading the weights per call.
model = YOLO(model_path)
# Landing page route.
@app.get("/")
async def root(request: Request):
    """Render ``index.html`` from the templates directory for the landing page.

    The incoming request is passed into the template context under the
    ``"request"`` key, as the ``TemplateResponse(name, context)`` form expects.
    """
    return templates.TemplateResponse("index.html", {"request": request})
# Face-detection endpoint: accepts an uploaded image, returns the model output.
@app.post("/predict-img")
def predict_image(file: UploadFile = File(...)):
    """Run face detection on an uploaded image and stream the result as JPEG.

    ``predict`` is called with a file path, so the upload is first written
    to a private temporary file. The path is never derived from the
    client-supplied filename, and the file is always removed afterwards,
    even if prediction or encoding fails.

    Args:
        file: The multipart upload. Its filename must end with one of
            ``IMAGE_FORMATS``.

    Returns:
        On success, a ``StreamingResponse`` with ``image/jpeg`` content.
        If reading or validating the upload fails, a one-element list
        holding the error message. This serializes to the same JSON as the
        one-element set literal previously returned.
    """
    # Newer Starlette versions may report a missing filename as None.
    filename = file.filename or ""
    try:
        contents = file.file.read()
        # Validate before anything touches the disk. NOTE(review): the check
        # is case-sensitive; whether e.g. "photo.JPG" passes depends on the
        # contents of IMAGE_FORMATS.
        if not filename.endswith(IMAGE_FORMATS):
            raise ValueError("Invalid image format")
    except Exception as e:
        # Request boundary: report the problem to the client instead of a 500.
        logger.warning("Rejected upload %r: %s", filename, e)
        return [str(e)]
    finally:
        file.file.close()

    # Keep only the extension from the client's filename (presumably the image
    # loader inside predict uses it to detect the format). tempfile picks the
    # actual name and directory, which prevents path traversal and collisions
    # between concurrent uploads that share a filename.
    suffix = os.path.splitext(os.path.basename(filename))[1]
    fd, image_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(contents)
        results = predict(model, image_path)
        # TODO: derive the output format from the input extension instead of
        # always encoding as JPEG.
        # Assumes predict returns an image object with a PIL-style
        # save(fp, format) method.
        img_bytes = BytesIO()
        results.save(img_bytes, "JPEG")
        img_bytes.seek(0)
    finally:
        # Runs even when predict/save raise, so temp files never accumulate.
        os.remove(image_path)
    return StreamingResponse(content=img_bytes, media_type="image/jpeg")