DHPR commited on
Commit
5f58bd1
1 Parent(s): 35571e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -25
app.py CHANGED
@@ -2,21 +2,30 @@
2
  import os
3
  import json
4
  import numpy as np
 
5
  from pathlib import Path
6
  from pprint import pprint
7
  from omegaconf import OmegaConf
8
  from PIL import Image, ImageDraw
9
  import streamlit as st
 
10
  # %%
 
 
 
 
 
 
 
 
 
11
 
12
- os.environ['ROOT'] = os.path.dirname(os.path.realpath(__file__))
13
- print("os.environ['ROOT'] :",os.environ['ROOT'])
14
  # %%
15
  class ImageRetriever:
16
 
17
  def __init__(self, root_path, anno_path):
18
  self.root_path = Path(root_path)
19
- self.anno = json.load(open(anno_path))
20
 
21
  def key2img_path(self, key):
22
  file_paths = [
@@ -31,21 +40,18 @@ class ImageRetriever:
31
  self.root_path / f"{key}.png",
32
  self.root_path / f"{key}.jpg",
33
  ]
34
- print("file_paths!!!!!!!!", file_paths)
35
  for file_path in file_paths:
36
  if file_path.exists():
37
  return file_path
38
 
39
 
40
- def key2img(self, key, draw_bbox=True):
41
  file_path = self.key2img_path(key)
42
 
43
- print("file_path!!@@@@", key, file_path)
44
-
45
  image = Image.open(file_path)
 
46
  if draw_bbox:
47
- meta = self.anno[key]['details'][-1]
48
- bboxes = [meta['bounding_box'].get(str(box_idx + 1), None) for box_idx in range(3)]
49
  image = self.hide_region(image, bboxes)
50
  return image
51
 
@@ -88,11 +94,9 @@ class ImageRetriever:
88
  image = Image.alpha_composite(image, overlay)
89
  return image
90
 
91
- def retrive_data(file_index, mode='direct'):
92
- split = 'val'
93
- mode = mode.lower()
94
- main_dataset = json.load(open(os.path.join(os.environ['ROOT'], f"data/anno_{split}_{mode}.json")))
95
- temp_data = main_dataset[list(main_dataset.keys())[file_index]]['details'][-1]
96
 
97
  message_dict = {}
98
 
@@ -108,29 +112,46 @@ def retrive_data(file_index, mode='direct'):
108
  message_dict['Entity #3'] = temp_data['Entity #3']
109
 
110
  img_retriever = ImageRetriever(
111
- root_path=os.path.join(os.environ['ROOT'], ''),
112
- anno_path=os.path.join(os.environ['ROOT'], f'data/anno_{split}_{mode}.json'),
113
  )
114
- img = img_retriever.key2img(list(main_dataset.keys())[file_index])
115
  img = img.resize((img.width // 2, img.height // 2))
116
 
117
  return img, message_dict
118
 
 
 
 
119
  # %%
120
  if __name__ == '__main__':
121
  st.title("DHPR: Driving Hazard Prediction and Reasoning")
122
- st.subheader("Data Visualization")
123
 
124
- option = st.selectbox(
125
- 'Select the hazard type',
126
- ('Direct', 'Indirect'))
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- st.write('You selected:', option)
129
 
130
- image_index = st.slider('Please Select The Image Index', 0, 999, 0)
131
- st.write("You select the data index of ", image_index," for visualization from the validation set")
 
 
 
 
132
 
133
- img, message_dict = retrive_data(image_index, option)
134
 
135
  st.write("---")
136
 
 
2
  import os
3
  import json
4
  import numpy as np
5
+ import matplotlib.pyplot as plt
6
  from pathlib import Path
7
  from pprint import pprint
8
  from omegaconf import OmegaConf
9
  from PIL import Image, ImageDraw
10
  import streamlit as st
11
+ import random
12
  # %%
13
+ os.environ['ROOT'] = os.path.dirname(os.path.realpath(__file__))#'/mnt/Documents/traffic_var_server/visualization'
14
+ # print("os.environ['ROOT'] :",os.environ['ROOT'])
15
+ # %%
16
+
17
+ def get_list_folder(PATH):
18
+ return [name for name in os.listdir(PATH) if os.path.isdir(os.path.join(PATH, name))]
19
+
20
+ def get_file_only(PATH):
21
+ return [f for f in os.listdir(PATH) if os.path.isfile(os.path.join(PATH, f))]
22
 
 
 
23
  # %%
24
  class ImageRetriever:
25
 
26
  def __init__(self, root_path, anno_path):
27
  self.root_path = Path(root_path)
28
+ self.anno_path = Path(anno_path)
29
 
30
  def key2img_path(self, key):
31
  file_paths = [
 
40
  self.root_path / f"{key}.png",
41
  self.root_path / f"{key}.jpg",
42
  ]
 
43
  for file_path in file_paths:
44
  if file_path.exists():
45
  return file_path
46
 
47
 
48
+ def key2img(self, key, temp_data, draw_bbox=True):
49
  file_path = self.key2img_path(key)
50
 
 
 
51
  image = Image.open(file_path)
52
+
53
  if draw_bbox:
54
+ bboxes = [temp_data['bounding_box'].get(str(box_idx + 1), None) for box_idx in range(3)]
 
55
  image = self.hide_region(image, bboxes)
56
  return image
57
 
 
94
  image = Image.alpha_composite(image, overlay)
95
  return image
96
 
97
+ def retrive_data(temp_data, img_key, mode='direct'):
98
+
99
+ # temp_data = main_dataset[list(main_dataset.keys())[file_index]]['details'][-1]
 
 
100
 
101
  message_dict = {}
102
 
 
112
  message_dict['Entity #3'] = temp_data['Entity #3']
113
 
114
  img_retriever = ImageRetriever(
115
+ root_path=os.path.join(os.environ['ROOT'], ''),
116
+ anno_path=os.path.join(os.environ['ROOT'], f'data/anno_{split}_{mode}.json'),
117
  )
118
+ img = img_retriever.key2img(img_key, temp_data)
119
  img = img.resize((img.width // 2, img.height // 2))
120
 
121
  return img, message_dict
122
 
123
+ # %%
124
+
125
+
126
  # %%
127
  if __name__ == '__main__':
128
  st.title("DHPR: Driving Hazard Prediction and Reasoning")
 
129
 
130
+ img_path = os.path.join(os.environ['ROOT'], 'img/')
131
+ img_path_list = get_file_only(img_path)
132
+
133
+ split = 'val'
134
+ rand_index = 0
135
+ main_direct_dataset = json.load(open(os.path.join(os.environ['ROOT'], f"data/anno_{'val'}_{'direct'}.json")))
136
+ main_indirect_dataset = json.load(open(os.path.join(os.environ['ROOT'], f"data/anno_{'val'}_{'indirect'}.json")))
137
+
138
+ if st.button('Random Data Sample'):
139
+ rand_index = random.randint(0, len(get_file_only(img_path)))
140
+ else:
141
+ pass
142
+
143
+ st.subheader("Data Visualization")
144
 
145
+ img_key = img_path_list[rand_index].split('.')[0]
146
 
147
+ if img_key in main_direct_dataset.keys():
148
+ temp_data = main_direct_dataset[img_key]['details'][-1]
149
+ elif img_key in main_indirect_dataset.keys():
150
+ temp_data = main_indirect_dataset[img_key]['details'][-1]
151
+ else:
152
+ pass
153
 
154
+ img, message_dict = retrive_data(temp_data, img_key)
155
 
156
  st.write("---")
157