amosfang commited on
Commit
6d85dc3
1 Parent(s): 92b0eee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -39
app.py CHANGED
@@ -16,31 +16,22 @@ PATH_TO_LABELS = 'data/label_map.pbtxt'
16
  category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
17
 
18
  def pil_image_as_numpy_array(pilimg):
19
-
20
  img_array = tf.keras.utils.img_to_array(pilimg)
21
- # img_array = np.expand_dims(img_array, axis=0)
22
  return img_array
23
-
24
 
25
  def load_model():
26
  model_dir = 'saved_model'
27
  detection_model = tf.saved_model.load(str(model_dir))
28
- return detection_model
29
-
30
 
31
  def predict(image_np):
32
-
33
  image_np = pil_image_as_numpy_array(image_np)
34
  image_np = np.expand_dims(image_np, axis=0)
35
-
36
  results = detection_model(image_np)
37
-
38
- # different object detection models have additional results
39
- result = {key:value.numpy() for key,value in results.items()}
40
-
41
  label_id_offset = 0
42
  image_np_with_detections = image_np.copy()
43
-
44
  viz_utils.visualize_boxes_and_labels_on_image_array(
45
  image_np_with_detections[0],
46
  result['detection_boxes'][0],
@@ -51,57 +42,47 @@ def predict(image_np):
51
  max_boxes_to_draw=200,
52
  min_score_thresh=.60,
53
  agnostic_mode=False,
54
- line_thickness=2)
55
-
56
  result_pil_img = tf.keras.utils.array_to_img(image_np_with_detections[0])
57
-
58
  return result_pil_img
59
 
60
-
61
  detection_model = load_model()
62
 
63
  # Specify paths to example images
64
- sample_images = [["test_1.jpg"],["test_9.jpg"],["test_6.jpg"],["test_7.jpg"],
65
- ["test_10.jpg"], ["test_11.jpg"],["test_8.jpg"]]
66
-
67
- tab1 = gr.Interface(fn=predict,
68
- inputs=gr.Image(label='Upload an expressway image', type="pil"),
69
- outputs=gr.Image(type="pil"),
70
- title='Blue and Yellow Taxi detection in live expressway traffic conditions (data.gov.sg)',
71
- examples = sample_images
72
- )
 
 
 
73
 
74
  def predict_on_video(video_in_filepath, video_out_filepath, detection_model, category_index):
 
75
  video_reader = cv2.VideoCapture(video_in_filepath)
76
-
77
  frame_h = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
78
  frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
79
  fps = video_reader.get(cv2.CAP_PROP_FPS)
80
-
81
  video_writer = cv2.VideoWriter(
82
  video_out_filepath,
83
  cv2.VideoWriter_fourcc(*'mp4v'),
84
  fps,
85
  (frame_w, frame_h)
86
  )
87
-
88
  label_id_offset = 0
89
-
90
  while True:
91
  ret, frame = video_reader.read()
92
-
93
  if not ret:
94
  break # Break the loop if the video is finished
95
-
96
- processed_frame = predict(frame, detection_model, category_index, label_id_offset)
97
-
98
- # Convert processed frame to numpy array
99
  processed_frame_np = np.array(processed_frame)
100
-
101
- # Write the frame to the output video
102
  video_writer.write(processed_frame_np)
103
-
104
- # Release video reader and writer
105
  video_reader.release()
106
  video_writer.release()
107
  cv2.destroyAllWindows()
@@ -110,7 +91,6 @@ def predict_on_video(video_in_filepath, video_out_filepath, detection_model, cat
110
  # Function to process a video
111
  def process_video(video_path):
112
  output_path = "output_video.mp4" # Output path for the processed video
113
- # Assuming you have detection_model and category_index defined
114
  predict_on_video(video_path, output_path, detection_model, category_index)
115
  return output_path
116
 
 
16
  category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
17
 
18
  def pil_image_as_numpy_array(pilimg):
 
19
  img_array = tf.keras.utils.img_to_array(pilimg)
 
20
  return img_array
 
21
 
22
  def load_model():
23
  model_dir = 'saved_model'
24
  detection_model = tf.saved_model.load(str(model_dir))
25
+ return detection_model
 
26
 
27
  def predict(image_np):
28
+ global detection_model # Declare as a global variable
29
  image_np = pil_image_as_numpy_array(image_np)
30
  image_np = np.expand_dims(image_np, axis=0)
 
31
  results = detection_model(image_np)
32
+ result = {key: value.numpy() for key, value in results.items()}
 
 
 
33
  label_id_offset = 0
34
  image_np_with_detections = image_np.copy()
 
35
  viz_utils.visualize_boxes_and_labels_on_image_array(
36
  image_np_with_detections[0],
37
  result['detection_boxes'][0],
 
42
  max_boxes_to_draw=200,
43
  min_score_thresh=.60,
44
  agnostic_mode=False,
45
+ line_thickness=2
46
+ )
47
  result_pil_img = tf.keras.utils.array_to_img(image_np_with_detections[0])
 
48
  return result_pil_img
49
 
 
50
  detection_model = load_model()
51
 
52
  # Specify paths to example images
53
+ sample_images = [
54
+ ["test_1.jpg"], ["test_9.jpg"], ["test_6.jpg"],
55
+ ["test_7.jpg"], ["test_10.jpg"], ["test_11.jpg"], ["test_8.jpg"]
56
+ ]
57
+
58
+ tab1 = gr.Interface(
59
+ fn=predict,
60
+ inputs=gr.Image(label='Upload an expressway image', type="pil"),
61
+ outputs=gr.Image(type="pil"),
62
+ title='Blue and Yellow Taxi detection in live expressway traffic conditions (data.gov.sg)',
63
+ examples=sample_images
64
+ )
65
 
66
  def predict_on_video(video_in_filepath, video_out_filepath, detection_model, category_index):
67
+ global detection_model # Declare as a global variable
68
  video_reader = cv2.VideoCapture(video_in_filepath)
 
69
  frame_h = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
70
  frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
71
  fps = video_reader.get(cv2.CAP_PROP_FPS)
 
72
  video_writer = cv2.VideoWriter(
73
  video_out_filepath,
74
  cv2.VideoWriter_fourcc(*'mp4v'),
75
  fps,
76
  (frame_w, frame_h)
77
  )
 
78
  label_id_offset = 0
 
79
  while True:
80
  ret, frame = video_reader.read()
 
81
  if not ret:
82
  break # Break the loop if the video is finished
83
+ processed_frame = predict(frame)
 
 
 
84
  processed_frame_np = np.array(processed_frame)
 
 
85
  video_writer.write(processed_frame_np)
 
 
86
  video_reader.release()
87
  video_writer.release()
88
  cv2.destroyAllWindows()
 
91
  # Function to process a video
92
  def process_video(video_path):
93
  output_path = "output_video.mp4" # Output path for the processed video
 
94
  predict_on_video(video_path, output_path, detection_model, category_index)
95
  return output_path
96