Spaces:
Runtime error
Runtime error
# Imports
import csv
import os
import subprocess
from collections import Counter
from pathlib import Path

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from detector import Detector
# Model artefacts handed to every Detector instance: network weights,
# network config and the class-name list.
chosen_weights = "cfg/vessels_tyv4.weights"
chosen_config_file = "cfg/vessels_tyv4.cfg"
chosen_names = "cfg/vessel.names"

# The names file holds one class name per line (surrounding whitespace
# stripped); the analysis functions look names up as labels[class_id].
with open(chosen_names, "r") as f:
    classes = list(map(str.strip, f))
labels = np.array(classes, dtype=str)
# Function for analyzing images
def analyze_image(selected_image, chosen_conf_thresh, chosen_nms_thresh):
    """Run the vessel detector on a single image and export the detections.

    Args:
        selected_image: Image supplied by the Gradio ``gr.Image`` input, or
            ``None`` when nothing was uploaded. (Presumably a numpy array in
            the layout ``Detector.detect`` expects -- confirm against the
            Gradio version and ``detector.py``.)
        chosen_conf_thresh: Confidence threshold forwarded to ``Detector``.
        chosen_nms_thresh: Non-maximum-suppression threshold forwarded to
            ``Detector``.

    Returns:
        A 3-tuple ``(img_det, csv_path, preview)``: the image returned by
        ``Detector.detect``, the path ``"Classes.csv"`` containing one row per
        detection (class, the four box values, score), and the first ten of
        those rows for the preview table.

    Raises:
        RuntimeError: If no image was supplied.
    """
    # Validate first so that a missing input does not delete the previous
    # run's output file.
    if selected_image is None:
        raise RuntimeError("No image found!")
    # Delete existing output files
    if os.path.exists("Classes.csv"):
        os.remove("Classes.csv")
    print("Starting image scan")
    detector = Detector(
        weights=str(chosen_weights),
        config_file=str(chosen_config_file),
        classes_file=chosen_names,
        conf_thresh=chosen_conf_thresh,
        nms_thresh=chosen_nms_thresh,
    )
    img_det, classes_id, scores, boxes = detector.detect(selected_image)
    # One row per detection. The box values are written in the order
    # Detector returns them, which the CSV header labels X, Y, Width, Height.
    tags = []
    for class_id, score, box in zip(classes_id, scores, boxes):
        tags.append([str(labels[class_id]), str(box[0]), str(box[1]), str(box[2]), str(box[3]), str(score)])
    print("Image scan finished succefully.")
    # Save tags in a csv file; newline="" is required by the csv module so
    # it controls line endings itself (avoids blank rows on Windows).
    with open("Classes.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Class", "X", "Y", "Width", "Height", "Score"])
        writer.writerows(tags)
    return img_det, "Classes.csv", tags[:10]
# Function for analyzing video
def _objects_per_frame_figure(counts_per_class, first_frame, stop_frame):
    """Plot, for each class, the number of detections in every frame.

    Args:
        counts_per_class: Mapping of class name -> ``Counter`` of
            ``frame_id -> detections in that frame``. Iteration order sets
            the legend order.
        first_frame: First frame id on the x axis (inclusive).
        stop_frame: End of the x axis (exclusive).

    Returns:
        The matplotlib ``Figure``. It is already closed in pyplot, so repeated
        requests do not pile up open figures, but it can still be rendered.
    """
    plt.switch_backend("agg")
    fig, ax = plt.subplots()
    frames = range(first_frame, stop_frame)
    for class_name, per_frame in counts_per_class.items():
        # Counter yields 0 for frames in which the class was not detected.
        ax.plot(list(frames), [per_frame[i] for i in frames], label=class_name)
    ax.set_title("Objects per frame")
    ax.set_ylabel("Objects")
    ax.set_xlabel("Frame")
    if counts_per_class:  # a legend with no labelled lines only emits a warning
        ax.legend()
    plt.close(fig)
    return fig


def _reencode_h264(src_path, dst_path):
    """Re-encode ``src_path`` to H.264 in ``dst_path`` so browsers can play it.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # Argument list (no shell) and check=True so a failed conversion raises
    # instead of silently returning a path to a file that does not exist.
    subprocess.run(
        ["ffmpeg", "-y", "-i", src_path, "-vcodec", "libx264", "-f", "mp4", dst_path],
        check=True,
    )


def analyze_video(selected_video, chosen_conf_thresh, chosen_nms_thresh, start_sec, duration):
    """Run the vessel detector over a time window of a video.

    Frames with 0-based ids in the half-open range ``[start_frame, end_frame)``
    are scanned. Here ``start_frame = round(start_sec * fps)`` and
    ``end_frame = round(start_frame + duration * fps)``, with ``fps`` as
    reported by OpenCV for the input.

    Args:
        selected_video: Path of the uploaded video (from ``gr.Video``), or
            ``None`` when nothing was uploaded.
        chosen_conf_thresh: Confidence threshold forwarded to ``Detector``.
        chosen_nms_thresh: Non-maximum-suppression threshold forwarded to
            ``Detector``.
        start_sec: Offset, in seconds, of the first frame to scan.
        duration: Length, in seconds, of the window to scan.

    Returns:
        ``("output.mp4", fig, "Classes.csv", preview)``:
        - the annotated clip re-encoded to H.264;
        - a per-class objects-per-frame chart;
        - a CSV with one row per detection (class, box values, score, frame);
        - the first ten of those rows.

    Raises:
        RuntimeError: If no video was supplied, the stream cannot be read,
            the start offset lies past the end of the video, or no frame in
            the window could be read.
        subprocess.CalledProcessError: If the ffmpeg re-encode fails.
    """
    if selected_video is None:
        raise RuntimeError("No video found!")
    # Delete existing output files so a failed run never serves stale results.
    for stale_path in ("demo_film.mp4", "output.mp4", "Classes.csv"):
        if os.path.exists(stale_path):
            os.remove(stale_path)
    print("Starting video scan")
    save_file_name = "demo_film.mp4"
    tags = []
    # class name -> Counter(frame_id -> detections). Insertion order is the
    # order in which classes were first seen, which sets the legend order.
    counts_per_class = {}
    video = cv2.VideoCapture(selected_video)
    vid_out = None
    try:
        ok, frame = video.read()
        if not ok:
            raise RuntimeError("Cannot read video stream!")
        total_frames = round(video.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = video.get(cv2.CAP_PROP_FPS)
        start_frame = round(start_sec * fps)
        end_frame = round(start_frame + (duration * fps))  # exclusive
        # Frame ids are 0-based, so total_frames itself is already past the
        # end. NOTE(review): the count is only trusted when positive, in case
        # a backend reports 0 for an unknown length.
        if 0 < total_frames <= start_frame:
            raise RuntimeError("Start second is out of bounds!")
        # Output frame rate: the source rate when it is in (15, 30], else 15.
        # NOTE(review): outside that range this changes the playback speed of
        # the annotated clip -- confirm that is intended.
        out_fps = fps if 15 < fps <= 30 else 15
        detector = Detector(
            weights=str(chosen_weights),
            config_file=str(chosen_config_file),
            classes_file=chosen_names,
            conf_thresh=chosen_conf_thresh,
            nms_thresh=chosen_nms_thresh,
        )
        # `frame` already holds frame 0 (read above for validation), so it is
        # processed before the next read; this keeps frame ids aligned.
        frame_id = 0
        while ok and frame_id < end_frame:
            if frame_id >= start_frame:
                img_det, classes_id, scores, boxes = detector.detect(frame)
                for class_id, score, box in zip(classes_id, scores, boxes):
                    class_name = str(labels[class_id])
                    tags.append([class_name, str(box[0]), str(box[1]), str(box[2]), str(box[3]), str(score), str(frame_id)])
                    counts_per_class.setdefault(class_name, Counter())[frame_id] += 1
                if vid_out is None:
                    # The writer is sized from the first annotated frame.
                    height, width = img_det.shape[:2]
                    vid_out = cv2.VideoWriter(save_file_name, cv2.VideoWriter_fourcc(*"mp4v"), out_fps, (width, height))
                vid_out.write(img_det)
            frame_id += 1
            ok, frame = video.read()
    finally:
        video.release()
        if vid_out is not None:
            vid_out.release()
    if vid_out is None:
        raise RuntimeError("No frames could be read in the selected range!")
    # newline="" is required by the csv module so it controls line endings.
    with open("Classes.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Class", "X", "Y", "Width", "Height", "Score", "Frame"])
        writer.writerows(tags)
    # After the loop, frame_id is one past the last frame that was processed.
    fig = _objects_per_frame_figure(counts_per_class, start_frame, frame_id)
    print("Video scan finished succefully.")
    # Re-encode to H.264 so the clip can be displayed in the browser.
    _reencode_h264(save_file_name, "output.mp4")
    return "output.mp4", fig, "Classes.csv", tags[:10]
# Gradio interfaces take three mandatory parameters: the function to run, the
# input component(s) and the output component(s).
# The demo is served on http://localhost:7860/ (Gradio's default port).

# Examples: each row gives values for the interface inputs, in order.
image_examples = [
    ["examples/horses.png", 0.20, 0.40],
    ["examples/basketball.png", 0.25, 0.40],
]
video_examples = [["examples/scene_from_series.mp4", 0.25, 0.40, 0, 10]]

# Image interface
image_interface = gr.Interface(
    fn=analyze_image,
    inputs=[
        gr.Image(label="Image"),
        gr.Slider(0, 1, value=0.25, label="Confidence Threshold"),
        gr.Slider(0, 1, value=0.40, label="Non Maxima Supression threshold"),
    ],
    outputs=[
        gr.Image(label="Image"),
        gr.File(label="All classes"),
        gr.Dataframe(label="Ten first classes", headers=["Class", "X", "Y", "Width", "Height", "Score"]),
    ],
    # allow_flagging takes "auto" / "manual" / "never". A bool is not one of
    # these values; depending on the Gradio version it is coerced with a
    # warning or rejected.
    allow_flagging="never",
    cache_examples=False,
    examples=image_examples,
)

# Video interface
video_interface = gr.Interface(
    fn=analyze_video,
    inputs=[
        gr.Video(label="Video"),
        gr.Slider(0, 1, value=0.25, label="Confidence Threshold"),
        gr.Slider(0, 1, value=0.40, label="Non Maxima Supression threshold"),
        gr.Slider(0, 60, value=0, label="Start Second", step=1),
        gr.Slider(1, 10, value=4, label="Duration", step=1),
    ],
    outputs=[
        gr.Video(label="Video"),
        gr.Plot(label="Objects per frame"),
        gr.File(label="All classes"),
        gr.Dataframe(label=" Ten first classes", headers=["Class", "X", "Y", "Width", "Height", "Score", "Frame"]),
    ],
    allow_flagging="never",
    cache_examples=False,
    examples=video_examples,
)

gr.TabbedInterface(
    [video_interface, image_interface],
    ["Scan Videos", "Scan Images"],
).launch()