# birdcount / app.py
# Author: pyresearch
# Last commit: d27251e "Update app.py" (verified)
import cv2
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from typing import Generator
from ultralytics import YOLO
import numpy as np
app = FastAPI()

# Object detector: YOLOv8 "large" weights.
model = YOLO("yolov8l.pt")

# Source clip and a capture handle opened on it at import time.
video_path = "demo.mp4"
cap = cv2.VideoCapture(video_path)

# Mutable module state: latest number of birds seen, and whether the
# MultiTracker currently holds live trackers.
bird_count = 0
tracker_initialized = False

# Create an (initially empty) MultiTracker. Depending on the OpenCV build the
# factory lives under ``cv2.legacy`` or directly under ``cv2``; if neither
# exists (no tracking module), tracking is disabled by leaving it as None.
try:
    tracker_namespace = cv2.legacy if hasattr(cv2, 'legacy') else cv2
    trackers = tracker_namespace.MultiTracker_create()
except AttributeError:
    trackers = None
    tracker_initialized = False
# COCO class index for "bird"; the stock yolov8*.pt weights are COCO-trained.
BIRD_CLASS_ID = 14


def _create_multitracker():
    """Return a new, empty OpenCV MultiTracker, preferring the ``cv2.legacy`` API.

    Raises:
        AttributeError: if this OpenCV build exposes neither
            ``cv2.legacy.MultiTracker_create`` nor ``cv2.MultiTracker_create``.
    """
    if hasattr(cv2, 'legacy'):
        return cv2.legacy.MultiTracker_create()
    return cv2.MultiTracker_create()


def _create_csrt_tracker():
    """Return a new CSRT single-object tracker, preferring the ``cv2.legacy`` API."""
    if hasattr(cv2, 'legacy'):
        return cv2.legacy.TrackerCSRT_create()
    return cv2.TrackerCSRT_create()


def _detect_bird_boxes(frame):
    """Run the YOLO model on *frame* and return bird boxes as ``(x, y, w, h)`` ints.

    Boxes that are not fully inside the frame, or that have zero width or
    height, are dropped: they cannot be used to initialise a CSRT tracker.
    """
    frame_height, frame_width = frame.shape[:2]
    detections = model(frame)[0].boxes.data.cpu().numpy()
    boxes = []
    # Each detection row is unpacked as x1, y1, x2, y2, confidence, class id.
    for x1, y1, x2, y2, _confidence, class_id in detections:
        if int(class_id) != BIRD_CLASS_ID:
            continue
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        if 0 <= x1 < x2 <= frame_width and 0 <= y1 < y2 <= frame_height:
            boxes.append((x1, y1, x2 - x1, y2 - y1))
    return boxes


def _draw_boxes(frame, boxes):
    """Draw a labelled green rectangle on *frame* (in place) for each ``(x, y, w, h)``."""
    for box in boxes:
        x, y, w, h = (int(v) for v in box)
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
        cv2.putText(frame, 'bird', (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)


def process_video() -> Generator[bytes, None, None]:
    """Yield the annotated demo video as multipart MJPEG parts.

    Every call opens its own capture of ``video_path``, so each stream (a
    page reload, a second viewer) plays the clip from the start instead of
    reading an already-exhausted, released shared handle. The capture is
    released in ``finally``, so it is freed even if the generator is closed
    before the clip ends (e.g. the client disconnects).

    Birds are found with the YOLO model and then followed with CSRT trackers
    on later frames. Detection runs again whenever tracking fails, no bird
    was found, or the tracking API is unavailable. The module-level
    ``bird_count``, ``trackers`` and ``tracker_initialized`` are updated as
    frames are processed.

    Yields:
        ``--frame`` multipart parts, each wrapping one JPEG-encoded frame.
    """
    global bird_count, tracker_initialized, trackers
    # Any tracker state left over belongs to a different stream's frames.
    tracker_initialized = False
    capture = cv2.VideoCapture(video_path)
    try:
        while capture.isOpened():
            success, frame = capture.read()
            if not success:
                break
            if not tracker_initialized:
                boxes = _detect_bird_boxes(frame)
                try:
                    trackers = _create_multitracker()
                    for bbox in boxes:
                        trackers.add(_create_csrt_tracker(), frame, bbox)
                except AttributeError:
                    # No tracking API in this OpenCV build: detect on every frame.
                    trackers = None
                    tracker_initialized = False
                else:
                    # With no trackers added, update() has nothing that can fail,
                    # so only switch to tracking once there is something to track.
                    tracker_initialized = bool(boxes)
                bird_count = len(boxes)
                _draw_boxes(frame, boxes)
            else:
                tracked, boxes = trackers.update(frame)
                if tracked:
                    bird_count = len(boxes)
                    _draw_boxes(frame, boxes)
                else:
                    tracker_initialized = False
            encoded, buffer = cv2.imencode('.jpg', frame)
            if not encoded:
                # Skip frames that fail to encode rather than emitting garbage.
                continue
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
    finally:
        capture.release()
# Jinja2 templates are loaded from the local "templates" directory.
templates = Jinja2Templates(directory="templates")


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render ``index.html`` with the current value of ``bird_count``.

    The count is read once, at render time; it is whatever value the module
    global holds at that moment.
    """
    page_context = {"request": request, "bird_count": bird_count}
    return templates.TemplateResponse("index.html", page_context)
@app.get("/video_feed")
async def video_feed():
    """Stream the annotated video as a ``multipart/x-mixed-replace`` response.

    The ``boundary=frame`` parameter matches the ``--frame`` delimiter that
    ``process_video`` writes before each JPEG part.
    """
    frame_parts = process_video()
    return StreamingResponse(
        frame_parts,
        media_type='multipart/x-mixed-replace; boundary=frame',
    )
if __name__ == "__main__":
    # Run the app directly with uvicorn, listening on all interfaces, port 7860.
    from uvicorn import run

    run(app, host="0.0.0.0", port=7860)