|  |  | 
					
						
						|  | import os | 
					
						
						|  | import json | 
					
						
						|  | import numpy as np | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | from pprint import pprint | 
					
						
						|  | from omegaconf import OmegaConf | 
					
						
						|  | from PIL import Image, ImageDraw | 
					
						
						|  | import streamlit as st | 
					
						
						|  | import random | 
					
						
						|  |  | 
					
						
						|  | os.environ['ROOT'] = os.path.dirname(os.path.realpath(__file__)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_list_folder(PATH): | 
					
						
						|  | return [name for name in os.listdir(PATH) if os.path.isdir(os.path.join(PATH, name))] | 
					
						
						|  |  | 
					
						
						|  | def get_file_only(PATH): | 
					
						
						|  | return [f for f in os.listdir(PATH) if os.path.isfile(os.path.join(PATH, f))] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ImageRetriever: | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, root_path, anno_path): | 
					
						
						|  | self.root_path = Path(root_path) | 
					
						
						|  | self.anno_path = Path(anno_path) | 
					
						
						|  |  | 
					
						
						|  | def key2img_path(self, key): | 
					
						
						|  | file_paths = [ | 
					
						
						|  | self.root_path / f"var_images/{key}.jpg", | 
					
						
						|  | self.root_path / f"var_images/{key}.png", | 
					
						
						|  | self.root_path / f"images/{key}.jpg", | 
					
						
						|  | self.root_path / f"img/train/{key.split('_')[0]}/{key}.png", | 
					
						
						|  | self.root_path / f"img/val/{key.split('_')[0]}/{key}.png", | 
					
						
						|  | self.root_path / f"img/test/{key.split('_')[0]}/{key}.png", | 
					
						
						|  | self.root_path / f"img/{key}.png", | 
					
						
						|  | self.root_path / f"img/{key}.jpg", | 
					
						
						|  | self.root_path / f"{key}.png", | 
					
						
						|  | self.root_path / f"{key}.jpg", | 
					
						
						|  | ] | 
					
						
						|  | for file_path in file_paths: | 
					
						
						|  | if file_path.exists(): | 
					
						
						|  | return file_path | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def key2img(self, key, temp_data, draw_bbox=True): | 
					
						
						|  | file_path = self.key2img_path(key) | 
					
						
						|  |  | 
					
						
						|  | image = Image.open(file_path) | 
					
						
						|  |  | 
					
						
						|  | if draw_bbox: | 
					
						
						|  | bboxes = [temp_data['bounding_box'].get(str(box_idx + 1), None) for box_idx in range(3)] | 
					
						
						|  | image = self.hide_region(image, bboxes) | 
					
						
						|  | return image | 
					
						
						|  |  | 
					
						
						|  | def hide_region(self, image, bboxes): | 
					
						
						|  | self.hide_true_bbox = 2 | 
					
						
						|  |  | 
					
						
						|  | image = image.convert('RGBA') | 
					
						
						|  |  | 
					
						
						|  | if self.hide_true_bbox == 1: | 
					
						
						|  | draw = ImageDraw.Draw(image, 'RGBA') | 
					
						
						|  |  | 
					
						
						|  | if self.hide_true_bbox in [2, 5, 7, 8, 9]: | 
					
						
						|  | overlay = Image.new('RGBA', image.size, '#00000000') | 
					
						
						|  | draw = ImageDraw.Draw(overlay, 'RGBA') | 
					
						
						|  |  | 
					
						
						|  | if self.hide_true_bbox == 3 or self.hide_true_bbox == 6: | 
					
						
						|  | overlay = Image.new('RGBA', image.size, '#7B7575ff') | 
					
						
						|  | draw = ImageDraw.Draw(overlay, 'RGBA') | 
					
						
						|  |  | 
					
						
						|  | color_fill_list = ['#ff05cd3c', '#00F1E83c', '#F2D4003c'] | 
					
						
						|  |  | 
					
						
						|  | for idx, bbox in enumerate(bboxes): | 
					
						
						|  | if bbox == None: | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | color_fill = color_fill_list[idx] | 
					
						
						|  | x, y = bbox['left'], bbox['top'] | 
					
						
						|  |  | 
					
						
						|  | if self.hide_true_bbox == 1: | 
					
						
						|  | draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill='#7B7575') | 
					
						
						|  | elif self.hide_true_bbox in [2, 5, 7, 8, 9]: | 
					
						
						|  | draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill=color_fill, outline='#05ff37ff', | 
					
						
						|  | width=3) | 
					
						
						|  | elif self.hide_true_bbox == 3: | 
					
						
						|  | draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill='#00000000') | 
					
						
						|  | elif self.hide_true_bbox == 6: | 
					
						
						|  | draw.rectangle([(x, y), (x + bbox['width'], y + bbox['height'])], fill=color_fill) | 
					
						
						|  |  | 
					
						
						|  | if self.hide_true_bbox in [2, 3, 5, 6, 7, 8, 9]: | 
					
						
						|  | image = Image.alpha_composite(image, overlay) | 
					
						
						|  | return image | 
					
						
						|  |  | 
					
						
						|  | def retrive_data(temp_data, img_key, mode='direct'): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | message_dict = {} | 
					
						
						|  |  | 
					
						
						|  | message_dict['img'] = temp_data['img'] | 
					
						
						|  | message_dict['plausible_speed'] = temp_data['plausible_speed'] | 
					
						
						|  | message_dict['bounding_box'] = temp_data['bounding_box'] | 
					
						
						|  | try: | 
					
						
						|  | message_dict['hazard'] = temp_data['hazard'] | 
					
						
						|  | except: | 
					
						
						|  | message_dict['hazard'] = temp_data['rationale'] | 
					
						
						|  | message_dict['Entity #1'] = temp_data['Entity #1'] | 
					
						
						|  | message_dict['Entity #2'] = temp_data['Entity #2'] | 
					
						
						|  | message_dict['Entity #3'] = temp_data['Entity #3'] | 
					
						
						|  |  | 
					
						
						|  | img_retriever = ImageRetriever( | 
					
						
						|  | root_path=os.path.join(os.environ['ROOT'], ''), | 
					
						
						|  | anno_path=os.path.join(os.environ['ROOT'], f'data/anno_{split}_{mode}.json'), | 
					
						
						|  | ) | 
					
						
						|  | img = img_retriever.key2img(img_key, temp_data) | 
					
						
						|  | img = img.resize((img.width // 2, img.height // 2)) | 
					
						
						|  |  | 
					
						
						|  | return img, message_dict | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == '__main__': | 
					
						
						|  | st.title("DHPR: Driving Hazard Prediction and Reasoning") | 
					
						
						|  |  | 
					
						
						|  | img_path = os.path.join(os.environ['ROOT'], 'img/') | 
					
						
						|  | img_path_list = get_file_only(img_path) | 
					
						
						|  |  | 
					
						
						|  | split = 'val' | 
					
						
						|  | rand_index = 0 | 
					
						
						|  | main_direct_dataset = json.load(open(os.path.join(os.environ['ROOT'], f"data/anno_{'val'}_{'direct'}.json"))) | 
					
						
						|  | main_indirect_dataset = json.load(open(os.path.join(os.environ['ROOT'], f"data/anno_{'val'}_{'indirect'}.json"))) | 
					
						
						|  |  | 
					
						
						|  | if st.button('Random Data Sample'): | 
					
						
						|  | rand_index = random.randint(0, len(get_file_only(img_path))) | 
					
						
						|  | else: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | st.subheader("Data Visualization") | 
					
						
						|  |  | 
					
						
						|  | img_key = img_path_list[rand_index].split('.')[0] | 
					
						
						|  |  | 
					
						
						|  | if img_key in main_direct_dataset.keys(): | 
					
						
						|  | temp_data = main_direct_dataset[img_key]['details'][-1] | 
					
						
						|  | elif img_key in main_indirect_dataset.keys(): | 
					
						
						|  | temp_data = main_indirect_dataset[img_key]['details'][-1] | 
					
						
						|  | else: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | img, message_dict =  retrive_data(temp_data, img_key) | 
					
						
						|  |  | 
					
						
						|  | st.write("---") | 
					
						
						|  |  | 
					
						
						|  | st.image(img) | 
					
						
						|  | st.subheader("Annotation Details") | 
					
						
						|  | st.json(message_dict) | 
					
						
						|  | st.write('---') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  |