Spaces:
Running
Running
# Import Fast API | |
from fastapi import FastAPI, Request, UploadFile, File | |
from fastapi.templating import Jinja2Templates | |
from fastapi.responses import StreamingResponse | |
# Import bytes | |
from io import BytesIO | |
import os | |
# Import logging | |
import logging | |
# Import utilities | |
from src.utils.utils import IMAGE_FORMATS | |
# Import machine learning | |
from src.predict import predict | |
from ultralytics import YOLO | |
from huggingface_hub import hf_hub_download | |
# Initialize the FastAPI application
app = FastAPI()
# Jinja2 templates are loaded from the "templates" directory (relative to the
# process working directory)
templates = Jinja2Templates(directory="templates")
# Module-level logger
logger = logging.getLogger(__name__)
logger.info("Loading YOLO model...")
# Download the YOLO weights from the Hugging Face Hub; hf_hub_download returns
# the local filesystem path of the (cached) file.
model_path = hf_hub_download(
    repo_id="arnabdhar/YOLOv8-Face-Detection", filename="model.pt"
)
# Load the YOLO model once at import time so every request reuses it.
model = YOLO(model_path)
# Index route (GET /). The decorator registers the handler with the app;
# without it the landing page is never served.
@app.get("/")
async def root(request: Request):
    """Render the landing page.

    Args:
        request: The incoming request. It is passed to the template under the
            ``"request"`` key, which Starlette's ``TemplateResponse`` expects.

    Returns:
        The rendered ``index.html`` template response.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
# Upload images route
# NOTE(review): the route decorator (e.g. ``@app.post(...)``) is missing from
# this copy of the file, and its path cannot be recovered from here. Restore it
# so the endpoint is registered.
def predict_image(file: UploadFile = File(...)):
    """Run ``predict`` on an uploaded image and stream back the result as JPEG.

    The upload is validated by extension, written to a private temporary file
    (never to a path built from the client-supplied filename), and passed to
    ``predict`` together with the module-level ``model``. The returned image
    is re-encoded as JPEG.

    Args:
        file: The uploaded image. Its filename must end with one of
            ``IMAGE_FORMATS``.

    Returns:
        A ``StreamingResponse`` containing the JPEG on success. If the upload
        has an unsupported extension or cannot be read, returns a one-element
        list holding the error message. That list serializes to the same JSON
        array the previous set literal produced.
    """
    import tempfile

    # UploadFile.filename may be None; treat that as an unsupported format.
    filename = file.filename or ""
    try:
        # Validate before touching the disk so rejected uploads leave no file.
        if not filename.endswith(IMAGE_FORMATS):
            raise ValueError("Invalid image format")
        contents = file.file.read()
    except Exception as e:
        # Preserve the original error payload: a JSON array with the message.
        return [f"{e}"]
    finally:
        file.file.close()

    # Keep the client's extension on the temp file in case the suffix matters
    # downstream. The random temp name prevents path traversal via a crafted
    # filename and collisions between concurrent uploads of the same name.
    suffix = os.path.splitext(os.path.basename(filename))[1]
    fd, image = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(contents)
        # NOTE(review): assumes predict() accepts any image file path (it
        # previously received the bare client filename) — confirm.
        results = predict(model, image)
        # TODO: extract the extension from the image and use it to save the image
        img_bytes = BytesIO()
        results.save(img_bytes, "JPEG")
        img_bytes.seek(0)
    finally:
        # Always delete the temporary upload, even if prediction fails.
        os.remove(image)
    return StreamingResponse(content=img_bytes, media_type="image/jpeg")