# shark_detection / app.py
# Author: Alexander Fengler — commit 588ce8d ("add dashboard"), 3.84 kB
import subprocess
import os
if os.getenv('SYSTEM') == 'spaces':
    import sys

    # When running as a Hugging Face Space, install the pinned CUDA 11.3 torch
    # build, the OpenMMLab packages and helpers at start-up, before the
    # third-party imports that follow. openmim (which provides the `mim`
    # command) is installed first because the later `mim` calls rely on it.
    _SPACE_INSTALL_COMMANDS = (
        'pip install -U openmim',
        'pip install python-dotenv',
        'pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113',
        'mim install mmcv>=2.0.0',
        'mim install mmengine',
        'mim install mmdet',
        'pip install opencv-python-headless==4.5.5.64',
        'pip install git+https://github.com/cocodataset/panopticapi.git',
    )
    for _command in _SPACE_INSTALL_COMMANDS:
        # Commands run without a shell (argument list), so "mmcv>=2.0.0" is
        # passed to mim verbatim rather than parsed as a redirection.
        _status = subprocess.call(_command.split())
        if _status != 0:
            # Installs stay best-effort (start-up continues), but a failure is
            # now reported instead of its exit status being discarded.
            print(f"WARNING: {_command!r} exited with status {_status}", file=sys.stderr)
import gradio as gr
from huggingface_hub import snapshot_download
import cv2
import dotenv
dotenv.load_dotenv()
import numpy as np
import gradio as gr
import glob
from inference import inference_frame,inference_frame_serial
from inference import inference_frame_par_ready
from inference import process_frame
from inference import classes
from inference import class_sizes_lower
from metrics import process_results_for_plot
from metrics import prediction_dashboard
import os
import pathlib
import multiprocessing as mp
from time import time
# Hugging Face dataset repository that hosts the example videos.
REPO_ID = 'SharkSpace/videos_examples'

# Fetch the example videos into ./videos_example at import time; the access
# token is taken from the SHARK_MODEL environment variable (None if unset).
snapshot_download(
    repo_id=REPO_ID,
    token=os.environ.get('SHARK_MODEL'),
    repo_type='dataset',
    local_dir='videos_example',
)

# Gradio theme applied to the whole interface.
theme = gr.themes.Soft(primary_hue="sky", neutral_hue="slate")
def process_video(input_video, out_fps='auto', skip_frames=7, output_path="output.mp4"):
    """Run shark detection on a video and stream the results to the UI.

    Every ``skip_frames``-th frame of ``input_video`` is passed through
    ``inference_frame_serial``. The annotated frame is appended to an output
    video and streamed together with the original frame and a prediction
    dashboard.

    Args:
        input_video: Source accepted by ``cv2.VideoCapture`` (e.g. a file
            path) for the video to process.
        out_fps: Frame rate of the written video.
            ``'auto'`` (default) uses the source frame rate divided by
            ``skip_frames``. An ``int`` is used as given. Any other value
            falls back to the source frame rate. The result is clamped to
            at least 1.
        skip_frames: Process one frame out of every ``skip_frames``; must be
            at least 1.
        output_path: File the annotated video is written to.

    Yields:
        For each processed frame, a tuple
        ``(annotated_frame, original_frame_rgb, None, dashboard)``.
        After the input is exhausted and the output file has been
        finalised, one last tuple ``(None, None, output_path, None)``.

    Raises:
        ValueError: If ``input_video`` is None, ``skip_frames`` is less than
            1, or the video cannot be opened.
    """
    if input_video is None:
        raise ValueError("No input video was provided.")
    if skip_frames < 1:
        raise ValueError(f"skip_frames must be >= 1, got {skip_frames!r}")

    cap = cv2.VideoCapture(input_video)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video {input_video!r}")

        source_fps = int(cap.get(cv2.CAP_PROP_FPS))
        if out_fps == 'auto':
            fps = int(source_fps / skip_frames)
        elif isinstance(out_fps, int) and not isinstance(out_fps, bool):
            fps = out_fps
        else:
            fps = source_fps
        # The 'auto' division rounds down to 0 when the source rate is below
        # skip_frames, and cap.get may report 0 when the rate is unknown;
        # the writer needs a positive frame rate.
        fps = max(1, fps)

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
        try:
            frame_index = 0
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                if frame_index % skip_frames == 0:
                    display_frame, result = inference_frame_serial(frame)
                    # NOTE(review): cv2.VideoWriter expects BGR frames while
                    # Gradio images are shown as RGB. The same converted
                    # frame is used for both, so one of them likely shows
                    # swapped red/blue channels. Which one depends on the
                    # channel order inference_frame_serial returns; confirm.
                    display_converted = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
                    video.write(display_converted)
                    top_pred = process_results_for_plot(predictions=result.numpy(),
                                                        classes=classes,
                                                        class_sizes=class_sizes_lower)
                    pred_dashbord = prediction_dashboard(top_pred=top_pred)
                    print(frame_index)  # progress output on stdout
                    yield display_converted, cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), None, pred_dashbord
                frame_index += 1
        finally:
            # Finalise the output file even if inference fails or the consumer
            # stops iterating the generator early.
            video.release()
    finally:
        cap.release()

    # Only hand the path out after the writer has been released, so the file
    # is complete when the UI loads it.
    yield None, None, output_path, None
with gr.Blocks(theme=theme) as demo:
    # Row 1: the uploaded video and, once processing finishes, the output video.
    with gr.Row():
        input_video = gr.Video(label="Input")
        output_video = gr.Video(label="Output Video")

    # Row 2: per-frame views streamed while process_video runs.
    with gr.Row():
        original_frames = gr.Image(label="Original Frame")
        dashboard = gr.Image(label="Dashboard")
        processed_frames = gr.Image(label="Shark Engine")

    # Row 3: clickable examples from the downloaded dataset (only .mp4 files
    # whose path contains "raw_videos"), plus the trigger button.
    with gr.Row():
        paths = sorted(pathlib.Path('videos_example/').rglob('*.mp4'))
        samples = [[path.as_posix()] for path in paths if 'raw_videos' in str(path)]
        examples = gr.Examples(samples, inputs=input_video)
        process_video_btn = gr.Button("Process Video")

    # Output components are listed in the order process_video yields its
    # tuple items.
    process_video_btn.click(
        fn=process_video,
        inputs=input_video,
        outputs=[processed_frames, original_frames, output_video, dashboard],
    )
# Enable the request queue (generator handlers such as process_video stream
# their intermediate results through it), then start the server.
demo.queue()
if os.getenv('SYSTEM') == 'spaces':
    # On Hugging Face Spaces the app is protected by a username/password
    # pair read from the environment. If either is missing, auth would
    # contain None and nobody could log in, so fail fast with a clear
    # message instead.
    _auth_username = os.environ.get('SHARK_USERNAME')
    _auth_password = os.environ.get('SHARK_PASSWORD')
    if not _auth_username or not _auth_password:
        raise RuntimeError(
            "SHARK_USERNAME and SHARK_PASSWORD must both be set when SYSTEM=spaces"
        )
    demo.launch(width='40%', auth=(_auth_username, _auth_password))
else:
    demo.launch()