"""Streamlit demo: browse sample images, show their ground-truth YOLO boxes,
and compare them with the output of a selectable YOLO model.

Images are read from ``image/`` and YOLO-format label files from ``label/``
(same file stem, ``.txt`` extension).
"""
import os
import random
import shutil
import sys

import cv2
import streamlit as st
from PIL import Image

# Set before the infer.* imports below so those modules write no .pyc files.
sys.dont_write_bytecode = True

from infer.yolov7.get_results import get_yolov7_result  # noqa: E402
from infer.yolov5.get_results import get_yolov5_result  # noqa: E402
from infer.yolov8.get_results import get_yolov8_result  # noqa: E402

image_path = "image/"
txt_path = "label/"

# Extensions listed as selectable images (compared case-insensitively).
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")
# Thumbnail grid geometry: GRID_ROWS rows of GRID_COLUMNS images each.
GRID_COLUMNS = 5
GRID_ROWS = 2
# Session-state key holding the index of the thumbnail the user selected.
SELECTED_IMAGE_KEY = "selected_image_index"

# Streamlit re-executes this script on every interaction. Seeding keeps the
# shuffled order stable across reruns, and sorting first removes the
# filesystem-dependent os.listdir order so it is stable across machines too.
random.seed(0)
each_image_name = sorted(
    name for name in os.listdir(image_path)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)
random.shuffle(each_image_name)

# Class names indexed by the integer class id at the start of each label line;
# also handed to the inference helpers (presumably must match the models'
# class order -- confirm against the training config).
label_names = ['Epiglottis', 'Vocal Fold', 'Trachea', 'Left Main Bronchus', 'Carina', 'Right Main Bronchus',
               'Left Upper Lobar Bronchus', 'Left Lower Bronchus', 'Right Upper Lobar Bronchus', 'Intermediate Bronchus',
               'Right Lower Lobar Bronchus', 'Left Divsion Bronchus', 'Left Lingular Bronchus', 'Left Superior Segment',
               'Left Basal Bronchus', 'Right Middle Lobar Bronchus', 'Right Basal Bronchus', 'Right Superior Segment Bronchus']

model_list = ['YOLO-V8', 'YOLO-V7', 'YOLO-V5']

st.set_page_config(layout="wide")


def _stem(file_name):
    """Return *file_name* without its extension (used for captions/label paths)."""
    return os.path.splitext(file_name)[0]


def inference(image, model_name, conf_threshold, iou_threshold):
    """Run the named model on a BGR image read by ``cv2.imread``.

    Returns:
        ``(result_image, result_list)`` from the model helper, or
        ``(None, None)`` for models that are unavailable.
    """
    if model_name == "YOLO-V7":
        return get_yolov7_result(image, conf_threshold, iou_threshold, label_names)
    elif model_name == "YOLO-V5":
        # YOLO-V5 inference is currently disabled; the UI then shows
        # "Not implemented yet".
        # return get_yolov5_result(image, conf_threshold, iou_threshold, label_names)
        return None, None
    elif model_name == "YOLO-V8":
        return get_yolov8_result(image, conf_threshold, iou_threshold, label_names)
    else:
        return None, None


def _class_name(label_index):
    """Map a class id to its display name, tolerating ids outside label_names."""
    if 0 <= label_index < len(label_names):
        return label_names[label_index]
    return f"class {label_index}"


def _read_yolo_labels(label_file, img_w, img_h):
    """Read a YOLO label file and return its boxes in pixel coordinates.

    Each non-blank line is ``class_id x_center y_center width height``, where
    the four geometry values are fractions of the image width/height.

    Returns:
        list of ``(class_name, x_min, y_min, x_max, y_max)`` tuples.
    """
    boxes = []
    with open(label_file, "r") as f:
        for raw_line in f:
            fields = raw_line.split()  # any whitespace; blank lines -> []
            if not fields:
                continue
            cx, cy, bw, bh = (float(value) for value in fields[1:5])
            cx, bw = cx * img_w, bw * img_w
            cy, bh = cy * img_h, bh * img_h
            boxes.append((
                _class_name(int(fields[0])),
                int(cx - bw / 2),
                int(cy - bh / 2),
                int(cx + bw / 2),
                int(cy + bh / 2),
            ))
    return boxes


def _draw_labeled_box(img, text, x_min, y_min, x_max, y_max, color=(0, 255, 0)):
    """Draw a box plus a filled name tag on *img* (modified in place)."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color, 2)
    (text_w, text_h), _ = cv2.getTextSize(text, font, 0.5, 2)
    tag_top = y_min - text_h - 10
    if tag_top < 0:
        # No room above the box (it touches the top edge): put the tag just
        # inside the box so it stays visible.
        tag_top = max(y_min, 0)
    tag_bottom = tag_top + text_h + 10
    cv2.rectangle(img, (x_min, tag_top), (x_min + text_w, tag_bottom), color, cv2.FILLED)
    cv2.putText(img, text, (x_min, tag_top + text_h + 5), font, 0.5, (0, 0, 0), 1)


def image_on_click(image_index):
    """Render info, ground-truth labels and model output for one image.

    Writes into the ``body1_col2`` / ``body1_col3`` columns and reads the
    ``selected_model`` / ``conf_threshold`` / ``iou_threshold`` widget values;
    all are module globals created further down, so call this only after
    they exist.
    """
    image_name = each_image_name[image_index]
    image_file = os.path.join(image_path, image_name)
    label_file = os.path.join(txt_path, _stem(image_name) + ".txt")

    with body1_col2:
        st.header("Image Information")
        bgr_image = cv2.imread(image_file)
        if bgr_image is None:
            st.error(f"Could not read image: {image_file}")
            return
        img_h, img_w = bgr_image.shape[:2]
        with Image.open(image_file) as pil_image:
            st.write("Image Width: ", pil_image.width)
            st.write("Image Height: ", pil_image.height)

        if os.path.exists(label_file):
            boxes = _read_yolo_labels(label_file, img_w, img_h)
        else:
            st.warning(f"No label file found: {label_file}")
            boxes = []

        # Draw on a copy so the model receives the untouched image.
        annotated = bgr_image.copy()
        for name, x_min, y_min, x_max, y_max in boxes:
            _draw_labeled_box(annotated, name, x_min, y_min, x_max, y_max)
        st.write("Label:" + str([box[0] for box in boxes]))
        # OpenCV images are BGR; reverse the channel axis for display as RGB.
        st.image(annotated[..., ::-1], _stem(image_name) + " label image")

    with body1_col3:
        st.header("Inference Result")
        result_image, result_list = inference(bgr_image, selected_model, conf_threshold, iou_threshold)
        if result_list is not None:
            for each_list in result_list:
                # Entry index 0 is shown as the label, index 2 as the confidence.
                st.markdown(f'Label: {each_list[0]}   Conf: {each_list[2]:.3f}', unsafe_allow_html=True)
        if result_image is not None:
            st.image(result_image, _stem(image_name) + " result image")
        else:
            st.warning("Not implemented yet")


def _select_image(image_index):
    """Button callback: remember the chosen thumbnail across reruns.

    Only state is updated here; rendering happens in the main script flow so
    it uses the current widget values and survives later reruns.
    """
    st.session_state[SELECTED_IMAGE_KEY] = image_index


def _render_grid_row(start):
    """Render one row of thumbnails followed by a row of 'Select' buttons."""
    row_names = each_image_name[start:start + GRID_COLUMNS]
    for col, name in zip(st.columns(GRID_COLUMNS), row_names):
        with col:
            with Image.open(os.path.join(image_path, name)) as thumb:
                st.image(thumb, _stem(name))
    indices = range(start, start + len(row_names))
    for index, col in zip(indices, st.columns(GRID_COLUMNS)):
        with col:
            st.button('Select', key=index, use_container_width=True,
                      on_click=_select_image, args=(index,))


body1 = st.container()
with body1:
    body1_col1, body1_col2, body1_col3 = st.columns([2, 1, 1])
    with body1_col1:
        st.header("Select an image")
        for row in range(GRID_ROWS):
            _render_grid_row(row * GRID_COLUMNS)

component_col1, component_col2, component_col3 = st.columns(3)
with component_col1:
    selected_model = st.selectbox('Select the inference model', model_list)
with component_col2:
    conf_threshold = st.slider('Select the confidence threshold', 0.0, 1.0, 0.50)
with component_col3:
    iou_threshold = st.slider('Select the IOU threshold', 0.0, 1.0, 0.01)

# Render the current selection after the widgets are created, so the chosen
# model/thresholds apply and the result persists across reruns.
_selected_index = st.session_state.get(SELECTED_IMAGE_KEY)
if _selected_index is not None and _selected_index < len(each_image_name):
    image_on_click(_selected_index)

body2 = st.container()
with body2:
    st.markdown(""" """, unsafe_allow_html=True)