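"""
Streamlit helper utilities for running YOLO object detection and segmentation
on uploaded images, uploaded videos, and a local webcam stream, with optional
saving of the annotated result frames.
"""
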
from shared import upload_records
from ultralytics import YOLO
import streamlit as st
import cv2
from PIL import Image
import numpy as np
import tempfile
import datetime
import os
import io
import time

def _display_detected_frames(conf, model, st_frame, image, save_path, task_type):
    """
    Display the detected objects on a video frame using the YOLO model.
    :param conf (float): Confidence threshold for object detection.
    :param model (YOLO): An instance of the YOLO class containing the YOLO model.
    :param st_frame (Streamlit object): A Streamlit object to display the detected video.
    :param image (numpy array): A numpy array representing the video frame.
    :param save_path (str): The path to save the results.
    :param task_type (str): The type of task, either 'detection' or 'segmentation'.
    :return: None
    """
    # Ensure the image is a 3-channel (BGR) color image
    if image.ndim == 2 or image.shape[2] == 1:  # grayscale / single-channel
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:  # 4-channel RGBA image
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)

    # Resize the frame to a fixed 640x480 working size before inference
    image_resized = cv2.resize(image, (640, 480))

    # Perform object detection or segmentation using the YOLO model
    results = model.predict(image_resized, conf=conf)

    # results[0].plot() draws bounding boxes for detection and masks for segmentation,
    # so both task types are rendered the same way
    result_image = results[0].plot()

    # Convert from BGR to RGB for Streamlit display
    result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)

    # Scale the result so it fits within 550x450 while preserving its aspect ratio,
    # then center it on a fixed 750x500 white canvas for display
    h, w = result_image_rgb.shape[:2]
    scale_factor = min(550 / w, 450 / h)
    new_w, new_h = int(w * scale_factor), int(h * scale_factor)
    result_image_resized = cv2.resize(result_image_rgb, (new_w, new_h))

    # Pad the scaled image onto the white 750x500 background, centered
    padded_image = np.full((500, 750, 3), 255, dtype=np.uint8)
    start_x = (750 - new_w) // 2
    start_y = (500 - new_h) // 2
    padded_image[start_y:start_y + new_h, start_x:start_x + new_w, :] = result_image_resized

    # Display the frame with detections or segmentations in the Streamlit app
    st_frame.image(
        padded_image,  # Directly use RGB image for display
        caption='Inference result',
        use_column_width=True
    )

    # If a save path is provided, save the annotated frame to disk
    if save_path:
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{task_type}_frame_{timestamp}.png"
        save_path_full = os.path.join(save_path, filename)
        # cv2.imwrite expects BGR, so convert the RGB result back before saving
        cv2.imwrite(save_path_full, cv2.cvtColor(result_image_resized, cv2.COLOR_RGB2BGR))
        st.write(f"Saved result to: {save_path_full}")

@st.cache_resource
def load_model(model_path):
    """
    Loads a YOLO object detection or segmentation model from the specified model_path.
    Parameters:
        model_path (str): The path to the YOLO model file.
    Returns:
        A YOLO object detection or segmentation model.
    """
    model = YOLO(model_path)
    return model

def infer_uploaded_image(conf, model, save_path, task_type):
    """
    Execute inference for uploaded images in batch.
    :param conf: Confidence of YOLO model
    :param model: An instance of the YOLO class containing the YOLO model.
    :param save_path: The path to save the results.
    :param task_type: The type of task, either 'detection' or 'segmentation'.
    :return: None
    """
    source_imgs = st.sidebar.file_uploader(
        "Select image(s)",
        type=("jpg", "jpeg", "png", "bmp", "webp"),
        accept_multiple_files=True,
    )

    if source_imgs:
        for img_info in source_imgs:
            file_type = os.path.splitext(img_info.name)[1][1:].lower()
            upload_records.append({
                "file_name": img_info.name,
                "file_type": file_type,
                "uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            })

            uploaded_image = Image.open(img_info)
            img_byte_arr = io.BytesIO()
            # Pillow uses 'JPEG' as the format name for .jpg files
            uploaded_image.save(img_byte_arr, format=file_type.upper() if file_type != 'jpg' else 'JPEG')
            img_byte_arr = img_byte_arr.getvalue()
            image = np.array(Image.open(io.BytesIO(img_byte_arr)))
            # Pillow decodes to RGB; convert 3-channel images to BGR for the OpenCV-based pipeline
            if image.ndim == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            st.image(
                img_byte_arr,
                caption=f"上传的图像: {img_info.name}",
                use_column_width=True
            )

            with st.spinner("正在运行..."):
                _display_detected_frames(conf, model, st.empty(), image, save_path, task_type)

def infer_uploaded_video(conf, model, save_path, task_type):
    """
    Execute inference for uploaded video and display the detected objects on the video.
    :param conf: Confidence of YOLO model
    :param model: An instance of the YOLO class containing the YOLO model.
    :param save_path: The path to save the results.
    :param task_type: The type of task, either 'detection' or 'segmentation'.
    :return: None
    """
    source_video = st.sidebar.file_uploader(
        "Select video(s)",
        accept_multiple_files=True
    )

    if source_video:
        for video_file in source_video:
            file_type = os.path.splitext(video_file.name)[1][1:].lower()
            upload_records.append({
                "file_name": video_file.name,
                "file_type": file_type,
                "uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            })

            st.video(video_file)

            if st.button("开始运行"):
                with st.spinner("运行中..."):
                    try:
                        # Write the upload to a named temporary file so OpenCV can open it by path
                        tfile = tempfile.NamedTemporaryFile(suffix=f".{file_type}")
                        tfile.write(video_file.read())
                        tfile.flush()
                        vid_cap = cv2.VideoCapture(tfile.name)
                        st_frame = st.empty()

                        # Decode frames continuously, but run inference and refresh the
                        # display at most once per second of wall-clock time
                        start_time = time.time()
                        while True:
                            success, image = vid_cap.read()
                            if not success:
                                break

                            current_time = time.time()
                            if current_time - start_time >= 1.0:
                                _display_detected_frames(conf, model, st_frame, image, save_path, task_type)
                                start_time = current_time

                        vid_cap.release()
                    except Exception as e:
                        st.error(f"Error loading video: {e}")

def infer_uploaded_webcam(conf, model, save_path, task_type):
    """
    Execute inference for webcam.
    :param conf: Confidence of YOLO model
    :param model: An instance of the YOLO class containing the YOLO model.
    :param save_path: The path to save the results.
    :param task_type: The type of task, either 'detection' or 'segmentation'.
    :return: None
    """
    try:
        # Pressing the stop button triggers a Streamlit rerun, which ends the capture loop
        flag = st.button("Stop webcam")
        vid_cap = cv2.VideoCapture(0)
        st_frame = st.empty()
        while not flag:
            success, image = vid_cap.read()
            if success:
                _display_detected_frames(conf, model, st_frame, image, save_path, task_type)
            else:
                vid_cap.release()
                break
    except Exception as e:
        st.error(f"Error loading video: {str(e)}")