from shared import upload_records
from ultralytics import YOLO
import streamlit as st
import cv2
from PIL import Image
import numpy as np
import tempfile
import datetime
import os
import io
import time
def _display_detected_frames(conf, model, st_frame, image, save_path, task_type):
"""
Display the detected objects on a video frame using the YOLO model.
:param conf (float): Confidence threshold for object detection.
:param model (YOLO): An instance of the YOLO class containing the YOLO model.
:param st_frame (Streamlit object): A Streamlit object to display the detected video.
:param image (numpy array): A numpy array representing the video frame.
:param save_path (str): The path to save the results.
:param task_type (str): The type of task, either 'detection' or 'segmentation'.
:return: None
"""
    # Ensure the image is a 3-channel BGR color image
    if image.ndim == 2 or image.shape[2] == 1:  # grayscale or single-channel image
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:  # 4-channel RGBA image
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
# Resize the image to the standard size expected by the model
image_resized = cv2.resize(image, (640, 480))
    # Perform object detection or segmentation using the YOLO model
    results = model.predict(image_resized, conf=conf)
    # results[0].plot() renders the annotations onto a BGR numpy array; it draws
    # boxes for detection and masks for segmentation, so both task types use the same call
    result_image = results[0].plot()
# Convert from BGR to RGB for Streamlit display
result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    # Scale the result to fit within 550x450 while keeping the aspect ratio,
    # then pad it onto a fixed 750x500 white canvas below
    h, w = result_image_rgb.shape[:2]
    scale_factor = min(550 / w, 450 / h)
new_w, new_h = int(w * scale_factor), int(h * scale_factor)
result_image_resized = cv2.resize(result_image_rgb, (new_w, new_h))
# Pad the image to ensure it is 750x500
padded_image = np.full((500, 750, 3), 255, dtype=np.uint8) # Create a white background
start_x = (750 - new_w) // 2
start_y = (500 - new_h) // 2
padded_image[start_y:start_y+new_h, start_x:start_x+new_w, :] = result_image_resized
# Display the frame with detections or segmentations in the Streamlit app
st_frame.image(
padded_image, # Directly use RGB image for display
caption=f'运行结果',
use_column_width=True
)
# If a save path is provided, save the frame with detections or segmentations
if save_path:
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{task_type}_frame_{timestamp}.png"
save_path_full = os.path.join(save_path, filename)
        # cv2.imwrite expects BGR ordering, so convert the RGB result back before saving
        cv2.imwrite(save_path_full, cv2.cvtColor(result_image_resized, cv2.COLOR_RGB2BGR))
st.write(f"文件保存在: {save_path_full}")
@st.cache_resource
def load_model(model_path):
"""
Loads a YOLO object detection or segmentation model from the specified model_path.
Parameters:
model_path (str): The path to the YOLO model file.
Returns:
A YOLO object detection or segmentation model.
"""
model = YOLO(model_path)
return model
def infer_uploaded_image(conf, model, save_path, task_type):
"""
Execute inference for uploaded images in batch.
:param conf: Confidence of YOLO model
:param model: An instance of the YOLO class containing the YOLO model.
:param save_path: The path to save the results.
:param task_type: The type of task, either 'detection' or 'segmentation'.
:return: None
"""
source_imgs = st.sidebar.file_uploader(
"选择图像",
type=("jpg", "jpeg", "png", 'bmp', 'webp'),
accept_multiple_files=True,
)
if source_imgs:
for img_info in source_imgs:
file_type = os.path.splitext(img_info.name)[1][1:].lower()
upload_records.append({
"file_name": img_info.name,
"file_type": file_type,
"uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
})
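            # Re-encode the upload into an in-memory buffer so the same bytes can be
            # previewed with st.image and decoded into a numpy array for inference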
uploaded_image = Image.open(img_info)
img_byte_arr = io.BytesIO()
uploaded_image.save(img_byte_arr, format=file_type.upper() if file_type != 'jpg' else 'JPEG')
img_byte_arr = img_byte_arr.getvalue()
image = np.array(Image.open(io.BytesIO(img_byte_arr)))
st.image(
img_byte_arr,
caption=f"上传的图像: {img_info.name}",
use_column_width=True
)
with st.spinner("正在运行..."):
_display_detected_frames(conf, model, st.empty(), image, save_path, task_type)
def infer_uploaded_video(conf, model, save_path, task_type):
"""
Execute inference for uploaded video and display the detected objects on the video.
:param conf: Confidence of YOLO model
:param model: An instance of the YOLO class containing the YOLO model.
:param save_path: The path to save the results.
:param task_type: The type of task, either 'detection' or 'segmentation'.
:return: None
"""
source_video = st.sidebar.file_uploader(
"选择视频",
accept_multiple_files=True
)
if source_video:
for video_file in source_video:
file_type = os.path.splitext(video_file.name)[1][1:].lower()
upload_records.append({
"file_name": video_file.name,
"file_type": file_type,
"uploaded_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
})
st.video(video_file)
            if st.button("开始运行", key=f"run_{video_file.name}"):  # unique key per uploaded video
with st.spinner("运行中..."):
try:
                        tfile = tempfile.NamedTemporaryFile()
                        tfile.write(video_file.read())
                        tfile.flush()  # make sure the upload is fully on disk before OpenCV opens it
                        vid_cap = cv2.VideoCapture(tfile.name)
                        st_frame = st.empty()
                        start_time = time.time()
while True:
success, image = vid_cap.read()
if not success:
break
current_time = time.time()
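                            # Run inference on roughly one frame per second of wall-clock time,
                            # so long videos do not overwhelm the Streamlit page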
if current_time - start_time >= 1.0:
_display_detected_frames(conf, model, st_frame, image, save_path, task_type)
start_time = current_time
vid_cap.release()
except Exception as e:
st.error(f"Error loading video: {e}")
def infer_uploaded_webcam(conf, model, save_path, task_type):
"""
Execute inference for webcam.
:param conf: Confidence of YOLO model
:param model: An instance of the YOLO class containing the YOLO model.
:param save_path: The path to save the results.
:param task_type: The type of task, either 'detection' or 'segmentation'.
:return: None
"""
try:
flag = st.button(
"关闭摄像头"
)
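        # Open the default webcam (device index 0) and stream frames into the
        # placeholder below until the close button is pressed or a read fails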
vid_cap = cv2.VideoCapture(0)
st_frame = st.empty()
while not flag:
success, image = vid_cap.read()
if success:
_display_detected_frames(conf, model, st_frame, image, save_path, task_type)
else:
vid_cap.release()
break
except Exception as e:
        st.error(f"Error accessing the webcam: {str(e)}")
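
# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the original app wiring).
# The weights path, module name, and sidebar layout below are assumptions made
# for illustration; adapt them to the actual entry-point script of this Space.
#
#   import streamlit as st
#   from utils import load_model, infer_uploaded_image   # hypothetical module name
#
#   model = load_model("weights/yolov8n.pt")              # assumed weights path
#   conf = st.sidebar.slider("Confidence", 0.1, 1.0, 0.5)
#   infer_uploaded_image(conf, model, save_path="results", task_type="detection")
# ---------------------------------------------------------------------------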