aus10powell committed on
Commit
8b9234c
β€’
1 Parent(s): 2d7e7e3

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +133 -152
inference.py CHANGED
@@ -1,168 +1,149 @@
1
- # This code is written at BigVision LLC. It is based on the OpenCV project. It is subject to the license terms in the LICENSE file found in this distribution and at http://opencv.org/license.html
 
2
 
3
- # Usage example: python3 object_detection_yolo.py --video=run.mp4 --device 'cpu'
4
- # python3 object_detection_yolo.py --video=run.mp4 --device 'gpu'
5
- # python3 object_detection_yolo.py --image=bird.jpg --device 'cpu'
6
- # python3 object_detection_yolo.py --image=bird.jpg --device 'gpu'
7
- # python3 test2.py --image=/Users/apowell/Downloads/HerringInTrap.JPG --device 'cpu'
8
- # python3 test2.py --video=/Users/apowell/Downloads/sampleFull.avi --device 'cpu'
9
- # sampleFull.avi
10
 
11
- import cv2 as cv
 
 
 
12
  import argparse
13
  import sys
14
- import numpy as np
15
  import os.path
16
- import os
17
- import matplotlib
18
- import streamlit as st
19
-
20
- matplotlib.use("Agg")
21
- from inference_utils import *
22
- from PIL import Image, ImageOps
23
  import logging
24
 
25
- # Custom
26
- from centroidtracker import CentroidTracker
27
-
28
- # Set default static images for testing while working locally
29
- DEFAULT_IMAGE = "/Users/apowell/Downloads/HerringInTrap.JPG"
30
- DEFAULT_VIDEO = "/Users/apowell/Downloads/sampleFull.avi"
31
- YOUTUBE = "https://www.youtube.com/watch?v=CbB7vl_HUbU&ab_channel=AustinPowell"
32
-
33
-
34
def main(input_file=None, is_image=False, device="cpu"):
    """
    Run main inference script. Returns annotated frames from inference and counts of fish.

    Args:
    - input_file: image or video input; either a cv.VideoCapture object
      (e.g. passed from Streamlit) or a local file path string
    - is_image: Binary denoting single image
    - device: CPU or GPU processing
      NOTE(review): `device` is accepted but never used in the body — confirm
      whether a cv.dnn backend/target selection was intended.

    Returns:
    - saved_frames: list of annotated frames as uint8 numpy arrays
    - count_list: per-frame fish counts from postprocess()
    - timestamps: per-frame capture timestamps in milliseconds
    """
    ## Initialize the parameters
    # Confidence threshold below which detections are discarded
    conf_threshold = 0.5
    # Non-maximum suppression threshold (maximum bounding box)
    nms_threshold = 0.05
    input_width = 416  # Width of network's input image
    input_height = 416  # Height of network's input image

    # Generic name assignment for output file
    outputFile = "yolo2_out_py.mp4"
    # Load class name
    classes = "Herring"
    # Give the configuration and weight files for the model and load the network using them.
    modelConfiguration = "herring.cfg"
    modelWeights = "herring_final.weights"

    # Centroid tracker to Id specific objects (NOTE: This is temporary and not fully tested)
    tracker = CentroidTracker(maxDisappeared=80, maxDistance=90)

    # Process inputs
    if (
        type(input_file) == cv.VideoCapture
    ):  # Video objects passed from something like Streamlit
        cap = input_file
    elif type(input_file) == str:  # For local uploads
        cap = cv.VideoCapture(input_file)
        logging.info("INFO: Loading file locally: {}".format(input_file))
    else:
        # Unsupported input type: abort with a descriptive message
        sys.exit(
            "Input file is of type {} and not solved for.".format(type(input_file))
        )

    # Load the Darknet/YOLO network from config + weights
    net = cv.dnn.readNetFromDarknet(modelConfiguration, modelWeights)

    # Get the video writer initialized to save the output video
    # NOTE(review): FPS is hard-coded to 30 even though the source FPS is read
    # below; the writer is also never released — confirm both are intended.
    vid_writer = cv.VideoWriter(
        outputFile,
        cv.VideoWriter_fourcc("M", "J", "P", "G"),
        30,
        (
            round(cap.get(cv.CAP_PROP_FRAME_WIDTH)),
            round(cap.get(cv.CAP_PROP_FRAME_HEIGHT)),
        ),
    )

    total_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
    video_fps = cap.get(cv.CAP_PROP_FPS)
    logging.info(
        "INFO: Starting inference process on video frames. Total: {}, fps: {}".format(
            total_frames, video_fps
        )
    )
    timestamps = [cap.get(cv.CAP_PROP_POS_MSEC)]  # Timestamp for frame
    calc_timestamps = [0.0]  # Relative timestamps to first timestamp
    saved_frames = []  # Save CV2 frames
    count_list = []  # Per-frame counts returned by postprocess()
    while cap.isOpened():

        # Get frame from the video
        hasFrame, frame = cap.read()

        # Stop the program if reached end of video
        if not hasFrame:
            print("Done processing !!!")
            print("Output file is stored as ", outputFile)
            # Release device
            cap.release()
            break

        # Create a 4D blob from a frame (scaled to 1/255, resized to network input size)
        blob = cv.dnn.blobFromImage(
            frame, 1 / 255, (input_width, input_height), [0, 0, 0], 1, crop=False
        )

        # Sets the input to the network
        net.setInput(blob)

        # Runs the forward pass to get output of the output layers
        outs = net.forward(get_outputs_names(net=net))

        # Remove the bounding boxes with low confidence; annotates `frame`
        # in place and returns the fish count for this frame
        counts = postprocess(
            frame=frame,
            outs=outs,
            tracker=tracker,
            conf_threshold=conf_threshold,
            nms_threshold=nms_threshold,
            classes=classes,
        )
        count_list.append(counts)

        # Put efficiency information. The function getPerfProfile returns the
        # overall time for inference(t) and the timings for each of the layers(in layersTimes)
        t, _ = net.getPerfProfile()
        label = "Inference time: %.2f ms" % (t * 1000.0 / cv.getTickFrequency())
        cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255))

        # Save frame
        saved_frames.append(frame.astype(np.uint8))

        # Write the frame with the detection boxes
        if is_image:
            cv.imwrite(outputFile, frame.astype(np.uint8))
        else:
            vid_writer.write(frame.astype(np.uint8))

        # Actual capture timestamp and the ideal (constant-rate) timestamp
        timestamps.append(cap.get(cv.CAP_PROP_POS_MSEC))
        calc_timestamps.append(calc_timestamps[-1] + 1000 / video_fps)
    # Calculate time difference for different timestamps
    time_diffs = [
        abs(ts - cts) for i, (ts, cts) in enumerate(zip(timestamps, calc_timestamps))
    ]

    # Persist per-frame counts and timing info
    # NOTE(review): hard-coded filename "your_file.csv" looks like a placeholder — confirm
    with open("your_file.csv", "w") as f:
        for i in range(len(count_list)):
            f.write(
                f"{count_list[i]}, {time_diffs[i+1]}, {timestamps[i]}, {calc_timestamps[i]}\n"
            )

    return saved_frames, count_list, timestamps
162
-
163
 
 
164
if __name__ == "__main__":

    # Script below to enable running pure inference from command line
    # NOTE(review): hard-coded local path — replace with an argparse argument
    # for portability; confirm the file exists on the target machine.
    file_path = "/Users/apowell/Downloads/2_2018-04-27_15-50-53.mp4"
    saved_frames, counts, timestamps = main(input_file=file_path)
 
1
+ # Import necessary libraries
2
+ import matplotlib
3
 
4
+ # Use Agg backend for Matplotlib
5
+ matplotlib.use("Agg")
 
 
 
 
 
6
 
7
+ # Libraries for the app
8
+ import streamlit as st
9
+ import time
10
+ import io
11
  import argparse
12
  import sys
 
13
  import os.path
14
+ import subprocess
15
+ import tempfile
 
 
 
 
 
16
  import logging
17
 
18
+ # Visualization libraries
19
+ import altair as alt
20
+ import av
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Machine Learning and Image Processing libraries
23
+ import numpy as np
24
+ import pandas as pd
25
+ import cv2 as cv
26
+ from PIL import Image, ImageOps
27
+ from tqdm import tqdm
28
+
29
+ # Custom modules
30
+ import inference
31
+ from app_utils import *
32
+
33
@st.cache_data
def load_video(video_url):
    """Read a video file from local disk and return its raw bytes.

    Cached by Streamlit so repeated reruns do not re-read the file.

    Args:
    - video_url: path to the video file on disk

    Returns:
    - bytes content of the file
    """
    # Fix: use a context manager so the file handle is always closed
    # (the previous version leaked the open handle).
    with open(video_url, "rb") as f:
        return f.read()
37
+
38
@st.cache_data
def load_historical(fpath):
    """Load the historical herring-count CSV into a pandas DataFrame.

    Cached by Streamlit so the CSV is parsed only once per file path.
    """
    historical_df = pd.read_csv(fpath)
    return historical_df
41
+
42
+ st.set_page_config(layout="wide")
43
+
44
# Define the main function to run the Streamlit app
def run_app():
    """Render the fish-counter Streamlit app.

    Layout: two columns — historical chart + fishery map on the left,
    example video + upload/inference workflow on the right. Uploaded
    videos are run through `inference.main` and the annotated result is
    offered for download alongside a count-vs-time chart.
    """
    # Set Streamlit options
    st.set_option("deprecation.showfileUploaderEncoding", False)

    # App title and description
    st.title("MIT Count Fish Counter")
    st.text("Upload a video file to detect and count fish")

    # Example video URL or file path (replace with actual video URL or file path)
    video_url = "yolo2_out_py.mp4"
    video_bytes = load_video(video_url)

    # Load historical herring
    df_historical_herring = load_historical(fpath="herring_count_all.csv")

    # NOTE(review): these tabs are created but no content is ever placed in
    # them (everything renders in the columns below) — confirm if intended.
    tab1, map_tab = st.tabs(["📈 Chart", "Map of Fishery Locations"])

    # Create two columns for layout
    col1, col2 = st.columns(2)

    ## Col1 #########################################
    with col1:
        ## Initial visualizations
        # Plot historical data
        st.altair_chart(
            plot_historical_data(df_historical_herring),
            use_container_width=True,
        )

        # Display map of fishery locations
        # (random jitter around a fixed lat/lon — placeholder points)
        st.subheader("Map of Fishery Locations")
        st.map(
            pd.DataFrame(
                np.random.randn(5, 2) / [50, 50] + [42.41, -71.38],
                columns=["lat", "lon"],
            )
        )

    ## Col2 #########################################
    with col2:
        # Display example processed video
        st.subheader("Example of processed video")
        st.video(video_bytes)
        st.subheader("Upload your own video...")

        # Initialize accepted file types for upload
        img_types = ["jpg", "png", "jpeg"]
        video_types = ["mp4", "avi"]

        # Allow user to upload an image or video file
        uploaded_file = st.file_uploader(
            "Select an image or video file...", type=img_types + video_types
        )

        # Display the uploaded file
        if uploaded_file is not None:
            # The MIME subtype (e.g. "image/png" -> "png") decides the branch
            if str(uploaded_file.type).split("/")[-1] in img_types:
                # Display uploaded image
                image = Image.open(uploaded_file)
                st.image(image, caption="Uploaded image", use_column_width=True)

                # TBD: Inference code to run and display for single image

            elif str(uploaded_file.type).split("/")[-1] in video_types:
                # Display uploaded video
                st.video(uploaded_file)

                # Convert streamlit video object to OpenCV format to run inferences
                tfile = tempfile.NamedTemporaryFile(delete=False)
                tfile.write(uploaded_file.read())
                vf = cv.VideoCapture(tfile.name)

                # Run inference on the uploaded video
                with st.spinner("Running inference..."):
                    frames, counts, timestamps = inference.main(vf)
                    logging.info("INFO: Completed running inference on frames")
                    st.balloons()

                # Convert OpenCV Numpy frames in-memory to IO Bytes for streamlit
                streamlit_video_file = frames_to_video(frames=frames, fps=11)

                # Show processed video and provide download button
                st.video(streamlit_video_file)
                st.download_button(
                    label="Download processed video",
                    data=streamlit_video_file,
                    # Fix: "mp4" is not a valid MIME type; use the proper
                    # type/subtype form so browsers handle the download correctly.
                    mime="video/mp4",
                    file_name="processed_video.mp4",
                )

                # Create dataframe for fish counts and timestamps
                # (timestamps[0] is the pre-first-frame timestamp, so skip it
                # to align one timestamp per count)
                df_counts_time = pd.DataFrame(
                    data={"fish_count": counts, "timestamps": timestamps[1:]}
                )

                # Display fish count vs. timestamp chart
                st.altair_chart(
                    plot_count_date(dataframe=df_counts_time),
                    use_container_width=True,
                )

        else:
            st.write("No file uploaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
# Run the app if the script is executed directly
# (Streamlit executes the whole module, so this guard fires under
# `streamlit run` as well as plain `python`.)
if __name__ == "__main__":
    run_app()