ake178178 commited on
Commit
ed78f65
·
verified ·
1 Parent(s): 1d51954

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -22
app.py CHANGED
@@ -1,33 +1,46 @@
1
  import streamlit as st
2
- from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
3
- import tensorflow as tf
4
  import numpy as np
5
- from tensorflow.keras.applications.mobilenet import preprocess_input, decode_predictions
 
 
6
 
7
  st.title("物体识别应用")
8
  st.write("通过摄像头识别物体,从左到右显示主要物体的名称")
9
 
10
- # 加载 MobileNet 预训练模型
11
- model = tf.keras.applications.MobileNet(weights="imagenet")
 
12
 
13
- class ObjectDetectionTransformer(VideoTransformerBase):
14
- def transform(self, frame):
15
- img = frame.to_ndarray(format="bgr24")
16
 
17
- # 将图像调整大小并进行预处理
18
- image_resized = cv2.resize(img, (224, 224))
19
- image_array = np.expand_dims(image_resized, axis=0)
20
- processed_image = preprocess_input(image_array)
 
 
 
 
 
 
 
21
 
22
- # 进行物体识别
23
- preds = model.predict(processed_image)
24
- decoded_preds = decode_predictions(preds, top=3)[0]
25
- objects = [f"{label}: {round(score * 100, 2)}%" for (_, label, score) in decoded_preds]
26
- detected_text = " | ".join(objects)
 
27
 
28
- # 将检测结果写在图像上
29
- cv2.putText(img, detected_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
30
- return img
31
 
32
- # 使用 streamlit-webrtc 启动摄像头流
33
- webrtc_streamer(key="object-detection", video_transformer_factory=ObjectDetectionTransformer)
 
 
 
 
 
1
  import streamlit as st
2
+ import cv2
 
3
  import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow.keras.applications.mobilenet import preprocess_input
6
+ from tensorflow.keras.applications.mobilenet import decode_predictions
7
 
8
  st.title("物体识别应用")
9
  st.write("通过摄像头识别物体,从左到右显示主要物体的名称")
10
 
11
+ # 设置摄像头
12
+ video_capture = cv2.VideoCapture(0)
13
+ stframe = st.empty()
14
 
15
+ # 加载 MobileNet SSD 预训练模型
16
+ model = tf.keras.applications.MobileNet(weights="imagenet")
 
17
 
18
+ def detect_objects(frame):
19
+ # 预处理图像
20
+ image_resized = cv2.resize(frame, (224, 224))
21
+ image_array = np.expand_dims(image_resized, axis=0)
22
+ processed_image = preprocess_input(image_array)
23
+
24
+ # 使用模型进行预测
25
+ preds = model.predict(processed_image)
26
+ decoded_preds = decode_predictions(preds, top=3)[0] # 取前3个结果
27
+ objects = [f"{label}: {round(score * 100, 2)}%" for (_, label, score) in decoded_preds]
28
+ return objects
29
 
30
+ # 读取摄像头流并显示
31
+ while True:
32
+ ret, frame = video_capture.read()
33
+ if not ret:
34
+ st.write("无法读取摄像头数据。")
35
+ break
36
 
37
+ # 检测物体
38
+ objects = detect_objects(frame)
39
+ detected_text = " | ".join(objects)
40
 
41
+ # 显示检测结果
42
+ cv2.putText(frame, detected_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
43
+
44
+ # 将 OpenCV 图像格式转换为 Streamlit 显示格式
45
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
46
+ stframe.image(frame_rgb, caption="检测到的物体", channels="RGB")