# Imports
import cv2
from detector import Detector
from pathlib import Path
import gradio as gr
import os
import numpy as np
import csv
import matplotlib.pyplot as plt
import pandas as pd

# Choose weights, names and config file
chosen_weights = "cfg/vessels_tyv4.weights"
chosen_config_file = "cfg/vessels_tyv4.cfg"
chosen_names = "cfg/vessel.names"

with open(chosen_names, "r") as f:
    classes = [line.strip() for line in f.readlines()]
labels = np.array(classes, dtype=str)
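
# Note on the Detector class: it is imported from detector.py, which is not shown
# in this file. Based on how it is used below, it is assumed to wrap a Darknet/YOLO
# model and to expose:
#   Detector(weights, config_file, classes_file, conf_thresh, nms_thresh)
#   detect(image) -> (annotated_image, class_ids, scores, boxes)
# where each box is [x, y, width, height]. This is an assumption inferred from the
# calls in analyze_image/analyze_video, not a documented API.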

# Function for analyzing images
def analyze_image(selected_image, chosen_conf_thresh, chosen_nms_thresh):
    # Delete existing output files
    if os.path.exists("Classes.csv"):
        os.remove("Classes.csv")
    if selected_image is None:
        raise RuntimeError("No image found!")
    print("Starting image scan")
    # Initialize the detector
    detector = Detector(weights=str(chosen_weights),
                        config_file=str(chosen_config_file),
                        classes_file=chosen_names,
                        conf_thresh=chosen_conf_thresh,
                        nms_thresh=chosen_nms_thresh)
    # Detect objects in the image
    img_det, classes_id, scores, boxes = detector.detect(selected_image)
    class_names = []
    for _class in classes_id:
        class_names.append(labels[_class])
    tags = []
    for i in range(len(class_names)):
        tags.append([str(class_names[i]), str(boxes[i][0]), str(boxes[i][1]),
                     str(boxes[i][2]), str(boxes[i][3]), str(scores[i])])
    print("Image scan finished successfully.")
    # Save tags in a csv file (newline="" avoids blank rows on Windows)
    with open("Classes.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Class", "X", "Y", "Width", "Height", "Score"])
        writer.writerows(tags)
    df = pd.DataFrame(tags, columns=["Class", "X", "Y", "Width", "Height", "Score"])
    return img_det, "Classes.csv", df.head(10)
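
# Example usage outside Gradio (assumes examples/vessels.png exists locally, as
# referenced in image_examples further below):
#   annotated, csv_path, preview = analyze_image(cv2.imread("examples/vessels.png"), 0.25, 0.40)
#   cv2.imwrite("annotated.png", annotated)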

# Function for analyzing video
def analyze_video(selected_video, chosen_conf_thresh, chosen_nms_thresh, start_sec, duration):
    # Delete existing output files
    if os.path.exists("demo_film.mp4"):
        os.remove("demo_film.mp4")
    if os.path.exists("output.mp4"):
        os.remove("output.mp4")
    if os.path.exists("Classes.csv"):
        os.remove("Classes.csv")
    if selected_video is None:
        raise RuntimeError("No video found!")
    print("Starting video scan")
    # Capture the video input
    video = cv2.VideoCapture(selected_video)
    ret, frame = video.read()
    if not ret:  # Check that the stream is readable
        raise RuntimeError("Cannot read video stream!")
    # Rewind to the first frame, since the readability check above consumed it
    video.set(cv2.CAP_PROP_POS_FRAMES, 0)
    # Calculate start and end frame
    total_frames = round(video.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = video.get(cv2.CAP_PROP_FPS)
    start_frame = round(start_sec * fps)
    end_frame = round(start_frame + (duration * fps))
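    # Example: at 30 fps with start_sec = 2 and duration = 4, the scan covers
    # frames 60 through 180 (inclusive), i.e. roughly 4 seconds of footage.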
    # Initialize the detector
    detector = Detector(weights=str(chosen_weights),
                        config_file=str(chosen_config_file),
                        classes_file=chosen_names,
                        conf_thresh=chosen_conf_thresh,
                        nms_thresh=chosen_nms_thresh)
    frame_id = 0
    vid_out = None
    save_file_name = "demo_film.mp4"
    tags = []
    unique_objects = []
    if start_frame > total_frames:
        raise RuntimeError("Start second is out of bounds!")
    while True:
        # Read the next frame
        ret, frame = video.read()
        if not ret:  # Error or end of stream check
            break
        if frame is None:
            continue
        if start_frame <= frame_id <= end_frame:
            class_names = []
            # Detect objects in the frame
            img_det, classes_id, scores, boxes = detector.detect(frame)
            for _class in classes_id:
                class_names.append(labels[_class])
                if labels[_class] not in unique_objects:
                    unique_objects.append(labels[_class])
            for i in range(len(class_names)):
                tags.append([str(class_names[i]), str(boxes[i][0]), str(boxes[i][1]),
                             str(boxes[i][2]), str(boxes[i][3]), str(scores[i]), str(frame_id)])
            # Initialize the video writer on the first processed frame
            if frame_id == start_frame:
                Height, Width = img_det.shape[:2]
                # Clamp the output fps to a sane range; MP4V is re-encoded to H.264 further below
                fps = video.get(cv2.CAP_PROP_FPS) if 15 < video.get(cv2.CAP_PROP_FPS) <= 30 else 15
                vid_out = cv2.VideoWriter(save_file_name, cv2.VideoWriter_fourcc(*"MP4V"), fps, (Width, Height))
            vid_out.write(img_det)
        if frame_id > end_frame:
            break
        frame_id += 1
    # Release videos
    video.release()
    if vid_out is not None:
        vid_out.release()
    # Save tags in a csv file (newline="" avoids blank rows on Windows)
    with open("Classes.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Class", "X", "Y", "Width", "Height", "Score", "Frame"])
        writer.writerows(tags)
    if end_frame > total_frames:
        end_frame = total_frames
    # Create graph
    plt.switch_backend("agg")
    fig = plt.figure()
    df = pd.DataFrame(tags, columns=["Class", "X", "Y", "Width", "Height", "Score", "Frame"])
    # For every different object found, check how many times it appears in each frame
    for unique_object in unique_objects:
        object_array = df[df["Class"] == unique_object]
        obj_per_frame = []
        for i in range(start_frame, end_frame + 1):
            temp_array = object_array[object_array["Frame"] == str(i)]
            obj_per_frame.append(temp_array.shape[0])
        # Plot line graph for every individual object found
        plt.plot(list(range(start_frame, end_frame + 1)), obj_per_frame, label=unique_object)
    plt.title("Objects per frame")
    plt.ylabel("Objects")
    plt.xlabel("Frame")
    plt.legend()
    print("Video scan finished successfully.")
    # Re-encode from MP4V to H.264 so that the video can be displayed in the browser
    os.system("ffmpeg -i demo_film.mp4 -vcodec libx264 -f mp4 output.mp4")
    return "output.mp4", fig, "Classes.csv", df.head(10)

# Gradio interfaces take as mandatory parameters a function, the input component(s) and the output component(s)
# Demo is hosted on http://localhost:7860/
# Examples
image_examples = [
    ["examples/vessels.png", 0.20, 0.40],
    ["examples/boat.png", 0.25, 0.40],
]
video_examples = [["examples/vessels.mp4", 0.25, 0.40, 0, 10]]
# Image interface
image_interface = gr.Interface(fn=analyze_image,
                               inputs=[gr.Image(label="Image"),
                                       gr.Slider(0, 1, value=0.25, label="Confidence Threshold"),
                                       gr.Slider(0, 1, value=0.40, label="Non-Maxima Suppression Threshold")],
                               outputs=[gr.Image(label="Image"),
                                        gr.File(label="All classes"),
                                        gr.Dataframe(label="First ten classes",
                                                     headers=["Class", "X", "Y", "Width", "Height", "Score"])],
                               allow_flagging="never",
                               cache_examples=False,
                               examples=image_examples)
# Video interface
video_interface = gr.Interface(fn=analyze_video,
                               inputs=[gr.Video(label="Video"),
                                       gr.Slider(0, 1, value=0.25, label="Confidence Threshold"),
                                       gr.Slider(0, 1, value=0.40, label="Non-Maxima Suppression Threshold"),
                                       gr.Slider(0, 60, value=0, label="Start Second", step=1),
                                       gr.Slider(1, 10, value=4, label="Duration", step=1)],
                               outputs=[gr.Video(label="Video"),
                                        gr.Plot(label="Objects per frame"),
                                        gr.File(label="All classes"),
                                        gr.Dataframe(label="First ten classes",
                                                     headers=["Class", "X", "Y", "Width", "Height", "Score", "Frame"])],
                               allow_flagging="never",
                               cache_examples=False,
                               examples=video_examples)
gr.TabbedInterface(
    [video_interface, image_interface],
    ["Scan Videos", "Scan Images"]
).launch()
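
# On a shared host you may need to bind the server explicitly, e.g. (standard
# launch() keyword argument, shown here as an assumption about the target environment):
#   gr.TabbedInterface([video_interface, image_interface],
#                      ["Scan Videos", "Scan Images"]).launch(server_name="0.0.0.0")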