# Imports
import cv2
from detector import Detector
from pathlib import Path
import gradio as gr
import os
import numpy as np
import csv
import matplotlib.pyplot as plt
import pandas as pd
# Choose weights, names and config file
chosen_weights = "cfg/vessels_tyv4.weights"
chosen_config_file = "cfg/vessels_tyv4.cfg"
chosen_names = "cfg/vessel.names"
with open(chosen_names, "r") as f:
    classes = [line.strip() for line in f.readlines()]
labels = np.array(classes, dtype=str)
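# labels maps a numeric class id returned by the detector to its human-readable name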
# Function for analyzing images
def analyze_image(selected_image, chosen_conf_thresh, chosen_nms_thresh):
    # Delete existing output files
    if os.path.exists("Classes.csv"):
        os.remove("Classes.csv")
    if selected_image is None:
        raise RuntimeError("No image found!")
    print("Starting image scan")
    # Initialize the detector
    detector = Detector(weights = str(chosen_weights), config_file = str(chosen_config_file),
                        classes_file = chosen_names, conf_thresh = chosen_conf_thresh,
                        nms_thresh = chosen_nms_thresh)
    # Detect objects in the image
    img_det, classes_id, scores, boxes = detector.detect(selected_image)
    class_names = []
    for _class in classes_id:
        class_names.append(labels[_class])
    tags = []
    for i in range(len(class_names)):
        tags.append([str(class_names[i]), str(boxes[i][0]), str(boxes[i][1]),
                     str(boxes[i][2]), str(boxes[i][3]), str(scores[i])])
    print("Image scan finished successfully.")
    # Save tags in a CSV file
    with open("Classes.csv", "w", newline="") as f:
        write = csv.writer(f)
        write.writerow(["Class", "X", "Y", "Width", "Height", "Score"])
        write.writerows(tags)
    return img_det, "Classes.csv", tags[:10]
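# Standalone usage sketch (outside Gradio), assuming Detector.detect accepts an image
# as a numpy array, e.g. one loaded with cv2.imread:
#   img = cv2.imread("examples/horses.png")
#   annotated, csv_path, first_rows = analyze_image(img, 0.25, 0.40)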
# Function for analyzing video
def analyze_video(selected_video, chosen_conf_thresh, chosen_nms_thresh, start_sec, duration):
    # Delete existing output files
    if os.path.exists("demo_film.mp4"):
        os.remove("demo_film.mp4")
    if os.path.exists("output.mp4"):
        os.remove("output.mp4")
    if os.path.exists("Classes.csv"):
        os.remove("Classes.csv")
    if selected_video is None:
        raise RuntimeError("No video found!")
    print("Starting video scan")
    # Capture the video input
    video = cv2.VideoCapture(selected_video)
    ret, frame = video.read()
    if not ret:  # Fail early if the video cannot be read
        raise RuntimeError("Cannot read video stream!")
    # Rewind so the probe read above does not skip the first frame
    video.set(cv2.CAP_PROP_POS_FRAMES, 0)
    # Calculate start and end frame
    total_frames = round(video.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = video.get(cv2.CAP_PROP_FPS)
    start_frame = round(start_sec * fps)
    end_frame = round(start_frame + (duration * fps))
    # Initialize the detector
    detector = Detector(weights = str(chosen_weights), config_file = str(chosen_config_file),
                        classes_file = chosen_names, conf_thresh = chosen_conf_thresh,
                        nms_thresh = chosen_nms_thresh)
    frame_id = 0
    vid_out = []
    save_file_name = "demo_film.mp4"
    tags = []
    unique_objects = []
    if start_frame > total_frames:
        raise RuntimeError("Start second is out of bounds!")
    while True:
        # Read the next frame
        ret, frame = video.read()
        if not ret:  # Error or end of stream check
            break
        if frame is None:
            continue
        if start_frame <= frame_id <= end_frame:
            class_names = []
            # Detect objects in the frame
            img_det, classes_id, scores, boxes = detector.detect(frame)
            for _class in classes_id:
                class_names.append(labels[_class])
                if unique_objects.count(labels[_class]) == 0:
                    unique_objects.append(labels[_class])
            for i in range(len(class_names)):
                tags.append([str(class_names[i]), str(boxes[i][0]), str(boxes[i][1]),
                             str(boxes[i][2]), str(boxes[i][3]), str(scores[i]), str(frame_id)])
            # Initialize the video writer on the first processed frame
            if frame_id == start_frame:
                Height, Width = img_det.shape[:2]
                # Fall back to 15 fps when the reported frame rate is implausible
                fps = video.get(cv2.CAP_PROP_FPS) if 15 < video.get(cv2.CAP_PROP_FPS) <= 30 else 15
                vid_out = cv2.VideoWriter(save_file_name, cv2.VideoWriter_fourcc(*"MP4V"), fps, (Width, Height))
            vid_out.write(img_det)
        if frame_id > end_frame:
            break
        frame_id += 1
    # Release the input and output videos
    video.release()
    vid_out.release()
with open("Classes.csv", "w") as f:
write = csv.writer(f)
write.writerow(["Class", "X", "Y", "Width", "Height", "Score","Frame"])
write.writerows(tags)
f.close()
plt.switch_backend("agg")
if end_frame > total_frames:
end_frame = total_frames
fig = plt.figure()
df = pd.DataFrame(tags, columns = ["Class", "X", "Y", "Width", "Height", "Score", "Frame"])
for unique_object in unique_objects:
object_array = df[df["Class"].str.fullmatch(unique_object)==True]
obj_per_frame = []
for i in range(start_frame, end_frame + 1):
temp_array = []
temp_array = object_array[object_array["Frame"].astype("str").str.fullmatch(str(i))==True]
rows = temp_array.shape[0]
obj_per_frame.append(rows)
plt.plot(list(range(start_frame, end_frame+1)), obj_per_frame, label = unique_object)
plt.title("Objects per frame")
plt.ylabel("Objects")
plt.xlabel("Frame")
plt.legend()
print("Video scan finished succefully.")
# Changes video fourcc to h264 so that it can be displayed in the browser
os.system("ffmpeg -i demo_film.mp4 -vcodec libx264 -f mp4 output.mp4")
return "output.mp4", fig, "Classes.csv", tags[:10]
# Gradio interfaces take as mandatory parameters an input function, the input type(s) and the output type(s)
# Demo is hosted on http://localhost:7860/
# Examples
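# Each example row follows the order of the interface inputs defined below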
image_examples = [
    ["examples/horses.png", 0.20, 0.40],
    ["examples/basketball.png", 0.25, 0.40],
]
video_examples = [["examples/scene_from_series.mp4", 0.25, 0.40, 0, 10]]
# Image interface
image_interface = gr.Interface(fn = analyze_image,
                               inputs = [gr.Image(label = "Image"),
                                         gr.Slider(0, 1, value = 0.25, label = "Confidence Threshold"),
                                         gr.Slider(0, 1, value = 0.40, label = "Non-Maxima Suppression Threshold")],
                               outputs = [gr.Image(label = "Image"),
                                          gr.File(label = "All classes"),
                                          gr.Dataframe(label = "First ten classes", headers = ["Class", "X", "Y", "Width", "Height", "Score"])],
                               allow_flagging = False,
                               cache_examples = False,
                               examples = image_examples)
# Video interface
video_interface = gr.Interface(fn = analyze_video,
                               inputs = [gr.Video(label = "Video"),
                                         gr.Slider(0, 1, value = 0.25, label = "Confidence Threshold"),
                                         gr.Slider(0, 1, value = 0.40, label = "Non-Maxima Suppression Threshold"),
                                         gr.Slider(0, 60, value = 0, label = "Start Second", step = 1),
                                         gr.Slider(1, 10, value = 4, label = "Duration", step = 1)],
                               outputs = [gr.Video(label = "Video"),
                                          gr.Plot(label = "Objects per frame"),
                                          gr.File(label = "All classes"),
                                          gr.Dataframe(label = "First ten classes", headers = ["Class", "X", "Y", "Width", "Height", "Score", "Frame"])],
                               allow_flagging = False,
                               cache_examples = False,
                               examples = video_examples)
gr.TabbedInterface(
    [video_interface, image_interface],
    ["Scan Videos", "Scan Images"]
).launch()
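# launch() serves the app on http://localhost:7860 by default; launch(share = True) would
# additionally create a temporary public link if one is needed.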