File size: 6,959 Bytes
2143db5
abf96d3
 
58ff559
 
3ce0474
 
 
58ff559
 
 
 
2143db5
abf96d3
58ff559
 
abf96d3
58ff559
 
f827af4
 
 
 
 
 
 
 
 
58ff559
 
 
 
abf96d3
 
 
3ce0474
58ff559
 
 
 
7078129
 
58ff559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b58773
 
58ff559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abf96d3
58ff559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abf96d3
58ff559
 
 
 
 
 
 
 
 
 
 
3ce0474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58ff559
71a7cbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import streamlit as st
import os
from PIL import Image
import random
import cv2
import shutil
import sys
sys.dont_write_bytecode = True

from infer.yolov7.get_results import get_yolov7_result
from infer.yolov5.get_results import get_yolov5_result
from infer.yolov8.get_results import get_yolov8_result

# Data locations: images to browse and their YOLO-format label .txt files.
image_path = "image/"
txt_path = "label/"

# Re-seeding on every script run keeps the shuffled gallery order identical
# across Streamlit reruns. os.listdir() returns entries in arbitrary,
# filesystem-dependent order, so sort first; otherwise the seeded shuffle
# would still produce different galleries on different machines.
random.seed(0)
each_image_name = sorted(os.listdir(image_path))
random.shuffle(each_image_name)

# Previous class ordering, kept for reference only (not used).
# label_names = ['epiglottis', 'vocal cord', 'trachea', 'carina', 'right main bronchus', 'intermediate bronchus',
#                'right upper lobar bronchus', 'right middle lobar bronchus', 'right lower lobar bronchus', 'right superior segment bronchus',
#                'right basal bronchus', 'left main bronchus', 'left upper lobar bronchus', 'left division bronchus',
#                'left lingular bronchus', 'left lower bronchus', 'left superior segment', 'left basal bronchus']

# Class names indexed by class id: label files are decoded with
# label_names[class_id], and the list is also handed to the YOLO result
# helpers (presumably for the same id -> name mapping; confirm there).
# The order therefore must match the ids the dataset/models were built with.
label_names = ['Epiglottis', 'Vocal Fold', 'Trachea', 'Left Main Bronchus', 'Carina', 'Right Main Bronchus', 'Left Upper Lobar Bronchus',
               'Left Lower Bronchus', 'Right Upper Lobar Bronchus', 'Intermediate Bronchus', 'Right Lower Lobar Bronchus',
               'Left Division Bronchus', 'Left Lingular Bronchus', 'Left Superior Segment',
               'Left Basal Bronchus', 'Right Middle Lobar Bronchus', 'Right Basal Bronchus', 'Right Superior Segment Bronchus']

# Options for the model selectbox; the names must match those handled by inference().
model_list = ['YOLO-V8',
              'YOLO-V7',
              'YOLO-V5']

# Streamlit requires set_page_config to be the first st.* call of the script.
st.set_page_config(layout="wide")


def inference(image, model_name, conf_threshold, iou_threshold):
    """Run the detector named ``model_name`` on ``image``.

    Args:
        image: Image passed straight to the model helper (the caller supplies
            an unmodified OpenCV-loaded copy).
        model_name: One of the names in ``model_list``.
        conf_threshold: Confidence threshold forwarded to the helper.
        iou_threshold: IoU threshold forwarded to the helper.

    Returns:
        Whatever the model helper returns, unpacked by the caller as
        ``(result_image, result_list)``; ``(None, None)`` for YOLO-V5
        (currently disabled) and for any unknown model name.
    """
    if model_name == "YOLO-V8":
        return get_yolov8_result(image, conf_threshold, iou_threshold, label_names)
    if model_name == "YOLO-V7":
        return get_yolov7_result(image, conf_threshold, iou_threshold, label_names)
    # YOLO-V5 is intentionally disabled (get_yolov5_result is not called);
    # it falls through here together with unrecognised names.
    return None, None

def _read_yolo_labels(label_file, img_w, img_h):
    """Parse a YOLO-style label file into integer pixel-space boxes.

    Each non-blank line holds ``class x_center y_center width height``, where
    the four box values are fractions of the image width/height.

    Args:
        label_file: Path of the ``.txt`` label file.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        List of ``(class_index, x_min, y_min, x_max, y_max)`` tuples.

    Raises:
        FileNotFoundError: If ``label_file`` does not exist.
        ValueError / IndexError: If a non-blank line is malformed.
    """
    boxes = []
    with open(label_file, "r") as f:
        for raw_line in f:
            # split() (rather than split(" ")) tolerates repeated spaces, tabs
            # and blank/trailing lines, which used to crash on int('').
            fields = raw_line.split()
            if not fields:
                continue
            class_index = int(fields[0])
            x_center = float(fields[1]) * img_w
            y_center = float(fields[2]) * img_h
            width = float(fields[3]) * img_w
            height = float(fields[4]) * img_h
            boxes.append((
                class_index,
                int(x_center - width / 2),
                int(y_center - height / 2),
                int(x_center + width / 2),
                int(y_center + height / 2),
            ))
    return boxes


def _draw_labeled_box(canvas, box, text, color=(0, 255, 0)):
    """Draw ``box`` plus a filled caption tag containing ``text`` onto ``canvas`` in place."""
    x_min, y_min, x_max, y_max = box
    cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), color, 2)

    (text_w, text_h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
    tag_top = y_min - text_h - 10
    if tag_top < 0:
        # Box touches the top edge: place the tag just inside the box instead
        # of drawing it (and its text) outside the image.
        tag_top, tag_bottom = y_min, y_min + text_h + 10
    else:
        tag_bottom = y_min
    cv2.rectangle(canvas, (x_min, tag_top), (x_min + text_w, tag_bottom), color, cv2.FILLED)
    cv2.putText(canvas, text, (x_min, tag_top + text_h + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                (0, 0, 0), 1)


def image_on_click(image_index):
    """Show ground truth and model predictions for gallery image ``image_index``.

    Used as the ``on_click`` callback of the gallery "Select" buttons. The
    image's label-file boxes are drawn in ``body1_col2``. The currently selected
    model's output is shown in ``body1_col3``.

    Reads module-level state: ``each_image_name``, ``image_path``,
    ``txt_path``, ``label_names``, ``body1_col2``, ``body1_col3``,
    ``selected_model``, ``conf_threshold`` and ``iou_threshold``.
    """
    with body1_col2:
        st.header("Image Information")
        image_name = each_image_name[image_index]
        image_file = os.path.join(image_path, image_name)
        # splitext instead of .replace(".png", ...) so other extensions and
        # names containing ".png" elsewhere map to the right label file.
        stem = os.path.splitext(image_name)[0]

        cv2_image = cv2.imread(image_file)
        if cv2_image is None:
            # cv2.imread signals unreadable/missing files by returning None.
            st.error(f"Could not read image: {image_name}")
            return
        cv2_image_copy = cv2_image.copy()  # untouched copy for the model
        cv2_h, cv2_w = cv2_image.shape[:2]

        # Context manager so PIL releases its (lazily read) file handle.
        with Image.open(image_file) as image:
            st.write("Image Width: ", image.width)
            st.write("Image Height: ", image.height)

        label_file = os.path.join(txt_path, stem + ".txt")
        try:
            boxes = _read_yolo_labels(label_file, cv2_w, cv2_h)
        except FileNotFoundError:
            st.warning(f"No label file found for {image_name}")
            boxes = []

        temp_label_list = []
        for class_index, *box in boxes:
            # Reject unknown ids; a negative id would otherwise silently pick
            # a name from the end of label_names.
            if not 0 <= class_index < len(label_names):
                st.warning(f"Unknown class index {class_index} in {label_file}")
                continue
            label_name = label_names[class_index]
            _draw_labeled_box(cv2_image, box, label_name)
            temp_label_list.append(label_name)
        st.write("Label:" + str(temp_label_list))

        # OpenCV loads BGR; reverse the channel axis so st.image shows RGB.
        st.image(cv2_image[..., ::-1], stem + " label image")

        with body1_col3:
            st.header("Inference Result")
            result_image, result_list = inference(cv2_image_copy, selected_model, conf_threshold, iou_threshold)
            if result_list is not None:
                # Each entry is shown as (label, colour, confidence); the colour
                # is reversed before use as rgb(...), presumably because the
                # helpers return BGR -- confirm in get_*_result.
                for each_list in result_list:
                    st.markdown(f'Label:  <span style="color:rgb{each_list[1][::-1]}">{each_list[0]}</span> &nbsp; Conf:  <span style="color:red">{"{:.3f}".format(each_list[2])}</span>', unsafe_allow_html=True)
            if result_image is not None:
                st.image(result_image, stem + " result image")
            else:
                st.warning("Not implemented yet")

# Page layout: a wide gallery column on the left, then one column for the
# selected image's ground-truth labels and one for the model's predictions
# (image_on_click writes into body1_col2 / body1_col3).
body1 = st.container()
with body1:
    body1_col1, body1_col2, body1_col3 = st.columns([2, 1, 1])
    with body1_col1:
        st.header("Select an image")
        # Thumbnail gallery: GALLERY_ROWS rows of GALLERY_COLS images, each
        # row followed by a row of "Select" buttons wired to image_on_click.
        GALLERY_ROWS, GALLERY_COLS = 2, 5
        for row in range(GALLERY_ROWS):
            first = row * GALLERY_COLS
            # Clamp to the number of available images so a folder with fewer
            # than GALLERY_ROWS * GALLERY_COLS images no longer raises IndexError.
            row_indices = range(first, min(first + GALLERY_COLS, len(each_image_name)))

            image_cols = st.columns(GALLERY_COLS)
            for i, col in zip(row_indices, image_cols):
                with col:
                    # Context manager releases the PIL file handle after
                    # st.image has rendered the thumbnail.
                    with Image.open(os.path.join(image_path, each_image_name[i])) as image:
                        st.image(image, os.path.splitext(each_image_name[i])[0])

            button_cols = st.columns(GALLERY_COLS)
            for i, col in zip(row_indices, button_cols):
                with col:
                    st.button('Select', key=i, use_container_width=True, on_click=image_on_click, args=(i,))

        # Inference controls; read as globals by image_on_click.
        component_col1, component_col2, component_col3 = st.columns(3)
        with component_col1:
            selected_model = st.selectbox('Select the inference model', model_list)
        with component_col2:
            conf_threshold = st.slider('Select the confidence threshold', 0.0, 1.0, 0.50)
        with component_col3:
            iou_threshold = st.slider('Select the IOU threshold', 0.0, 1.0, 0.01)

# Raw HTML/CSS for a footer pinned to the bottom of the viewport, carrying the
# paper citation and the copyright notice.
_FOOTER_HTML = """
        <style>
        .footer {
            position: fixed;
            left: 0;
            bottom: 0;
            width: 100%;
            text-align: center;
        }
        </style>
        <div class="footer">
            <p>Our paper: <a href="#">Enhanced Object Detection in Pediatric Bronchoscopy Images using YOLO-based Algorithms with CBAM Attention Mechanism</a></p>
            <p>Author: Jianqi Yan, Copyright &copy; 2024, Quanbao Technologies Co. Ltd </p>
        </div>
        """

body2 = st.container()
with body2:
    # unsafe_allow_html is required so the <style> and <div> are rendered, not escaped.
    st.markdown(_FOOTER_HTML, unsafe_allow_html=True)