from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import os
import cv2
import numpy as np
from tqdm import tqdm
import tempfile

# Function to load the model
def load_model(repo_id):
    """Fetch a Hugging Face Hub repo snapshot and build a YOLO detector from it.

    Args:
        repo_id: Hugging Face Hub repository id that contains the exported
            ``best_int8_openvino_model`` directory.

    Returns:
        A ``YOLO`` instance loaded from that directory with ``task="detect"``.
    """
    model_dir = os.path.join(snapshot_download(repo_id), "best_int8_openvino_model")
    return YOLO(model_dir, task="detect")

# Function to process an image
def predict_image(pilimg, conf_threshold, iou_threshold):
    """Run detection on a single image and return it with boxes drawn.

    Args:
        pilimg: The uploaded image as a PIL image (the UI uses
            ``gr.Image(type="pil")``), or ``None`` when nothing was uploaded.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold forwarded to ``detection_model.predict``.

    Returns:
        The annotated image as an RGB PIL image.

    Raises:
        gr.Error: if no image was supplied or inference/plotting fails. Raising
            (instead of returning a message string) lets Gradio show the error
            to the user; a string cannot be rendered by a ``gr.Image`` output.
    """
    if pilimg is None:
        raise gr.Error("Please upload an image first.")
    try:
        result = detection_model.predict(pilimg, conf=conf_threshold, iou=iou_threshold)
        # plot() yields a BGR numpy array; reversing the channel axis gives RGB.
        img_bgr = result[0].plot()
        return Image.fromarray(img_bgr[..., ::-1])
    except Exception as e:
        raise gr.Error(f"Error processing image: {e}") from e

# Function to process a video
def predict_video(video_file, conf_threshold, iou_threshold, start_time, end_time):
    """Run detection on every frame of a video clip and write an annotated copy.

    Args:
        video_file: Path of the uploaded video as passed by ``gr.Video``, or
            ``None`` when nothing was uploaded.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold forwarded to ``detection_model.predict``.
        start_time: Clip start in seconds; falsy (0/None) means the beginning.
        end_time: Clip end in seconds; falsy (0/None) means the end of the video.

    Returns:
        Path of a temporary ``.mp4`` file holding the annotated clip. The file
        is created with ``delete=False`` and is not removed by this function.

    Raises:
        gr.Error: if no video was given, it cannot be opened, the selected time
            range is empty, or no frame could be decoded. Any partially written
            output file is removed in that case.
    """
    if not video_file:
        raise gr.Error("Please upload a video first.")

    fallback_fps = 30.0  # used when the container does not report a frame rate
    cap = cv2.VideoCapture(video_file)
    writer = None
    output_path = None
    succeeded = False
    try:
        if not cap.isOpened():
            raise gr.Error("Error: Unable to open the video file.")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep the fractional rate (e.g. 29.97): truncating to int makes the
        # output drift out of sync. `not fps > 0` also rejects 0 and NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = fallback_fps

        start_frame = max(0, int(start_time * fps)) if start_time else 0
        if end_time:
            end_frame = int(end_time * fps)
            if total_frames > 0:
                end_frame = min(end_frame, total_frames)
        else:
            # A non-positive frame count means "unknown": read until EOF.
            end_frame = total_frames if total_frames > 0 else None
        if end_frame is not None and end_frame <= start_frame:
            raise gr.Error("The selected time range contains no frames.")

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        # Only reserve a unique file name; close our handle immediately so it
        # does not leak and the VideoWriter can open the path itself.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        frames_written = 0
        progress_total = None if end_frame is None else end_frame - start_frame
        with tqdm(total=progress_total, desc="Processing Video") as pbar:
            while end_frame is None or start_frame + frames_written < end_frame:
                ok, frame = cap.read()
                if not ok:
                    break

                # Pass the native-resolution frame: Ultralytics resizes/letterboxes
                # to the model input size itself, so a manual 640x640 squash only
                # distorts the aspect ratio and the drawn boxes/labels.
                # verbose=False keeps per-frame logs from garbling the tqdm bar.
                result = detection_model.predict(
                    frame, conf=conf_threshold, iou=iou_threshold, verbose=False
                )
                annotated = result[0].plot()

                # Size the writer from the first real frame: VideoWriter silently
                # drops frames whose size differs from the one it was opened
                # with, and reported width/height may not match decoded frames.
                if writer is None:
                    height, width = annotated.shape[:2]
                    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
                writer.write(annotated)
                frames_written += 1
                pbar.update(1)

        if frames_written == 0:
            raise gr.Error("No frames could be read from the selected time range.")
        succeeded = True
    finally:
        cap.release()
        if writer is not None:
            writer.release()
        if not succeeded and output_path is not None:
            try:
                os.remove(output_path)
            except OSError:
                pass  # best-effort cleanup; never mask the original error
    return output_path

# Load the detector once at import time; predict_image and predict_video
# read this module-level `detection_model`.
REPO_ID = "Ganrong/107project"
detection_model = load_model(REPO_ID)


def _threshold_sliders():
    """Create the (confidence, IoU) slider pair shared by both tabs.

    Must be called inside an active Gradio layout context; the sliders are
    placed wherever the call happens.
    """
    confidence = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold"
    )
    iou = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.6, step=0.05, label="IoU Threshold"
    )
    return confidence, iou


# Gradio UI: one tab for still images, one for video clips.
with gr.Blocks() as demo:
    gr.Markdown("## Pangolin and Axolotl Detection")

    with gr.Tab("Image Input"):
        img_input = gr.Image(type="pil", label="Upload an Image")
        conf_slider_img, iou_slider_img = _threshold_sliders()
        img_output = gr.Image(type="pil", label="Processed Image")
        img_submit = gr.Button("Process Image")
        img_submit.click(
            fn=predict_image,
            inputs=[img_input, conf_slider_img, iou_slider_img],
            outputs=img_output,
        )

    with gr.Tab("Video Input"):
        video_input = gr.Video(label="Upload a Video")
        conf_slider_video, iou_slider_video = _threshold_sliders()
        start_time = gr.Number(value=0, label="Start Time (seconds)")
        end_time = gr.Number(value=0, label="End Time (seconds, 0 for full video)")
        video_output = gr.Video(label="Processed Video")
        video_submit = gr.Button("Process Video")
        video_submit.click(
            fn=predict_video,
            inputs=[
                video_input,
                conf_slider_video,
                iou_slider_video,
                start_time,
                end_time,
            ],
            outputs=video_output,
        )

demo.launch(share=True)