|
import gradio as gr |
|
import cv2 |
|
import numpy as np |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
from transformers import pipeline |
|
import torch |
|
from random import choice |
|
from io import BytesIO |
|
import os |
|
from datetime import datetime |
|
|
|
|
|
# Object-detection pipeline (DETR with a ResNet-101 backbone).
# The target device is passed at construction time: a transformers pipeline
# remembers the device it was created with and moves its inputs there, so
# moving only `detector.model` to CUDA afterwards would leave the inputs on
# the CPU and fail with a device-mismatch error at the first inference.
detector = pipeline(
    model="facebook/detr-resnet-101",
    use_fast=True,
    device=0 if torch.cuda.is_available() else -1,  # 0 = first GPU, -1 = CPU
)

# Palette of pastel colours used for bounding boxes.
COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
          "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
          "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]

# Matplotlib font properties for the "label: score%" captions.
fdic = {
    "style": "italic",
    "size": 15,
    "color": "yellow",
    "weight": "bold",
}
|
|
|
def query_data(in_pil_img: Image.Image):
    """Run the module-level ``detector`` on one image.

    The raw pipeline output is printed to stdout and then returned unchanged.
    """
    detections = detector(in_pil_img)
    print(f"检测结果:{detections}")
    return detections
|
|
|
def get_annotated_image(in_pil_img):
    """Detect objects in an image and return a rendering with the boxes drawn.

    Detections come from ``query_data``. Each one is drawn as an unfilled
    rectangle plus a ``"<label>: <score>%"`` caption at the box's top-left
    corner. The box colour is picked from ``COLORS`` by hashing the label, so
    a given class keeps the same colour from frame to frame (a random colour
    per box made the boxes flicker when applied to consecutive video frames).

    Args:
        in_pil_img: PIL image to run detection on and draw over.

    Returns:
        np.ndarray: RGB image array of the rendered figure. The figure is saved
        with a tight bounding box, so its size generally differs from the input
        image; callers needing the input size must resize the result.
    """
    # Function-scope imports keep this block self-contained.
    # The object-oriented Figure API is used instead of pyplot: pyplot keeps a
    # process-wide "current figure" registry (not something to touch from
    # request-handler threads), and the original plt.close() was skipped if
    # anything raised, leaking the figure. A bare Figure is simply garbage
    # collected.
    import zlib
    from matplotlib.figure import Figure
    from matplotlib.patches import Rectangle

    in_results = query_data(in_pil_img)

    fig = Figure(figsize=(16, 10))
    ax = fig.add_subplot()
    ax.imshow(in_pil_img)

    for prediction in in_results:
        box = prediction['box']
        label = prediction['label']
        score = round(prediction['score'] * 100, 1)
        # crc32 is stable across runs, unlike the salted built-in str hash().
        color = COLORS[zlib.crc32(str(label).encode("utf-8")) % len(COLORS)]

        ax.add_patch(Rectangle((box['xmin'], box['ymin']),
                               box['xmax'] - box['xmin'],
                               box['ymax'] - box['ymin'],
                               fill=False, color=color, linewidth=3))
        ax.text(box['xmin'], box['ymin'], f"{label}: {score}%", fontdict=fdic)

    ax.axis("off")

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    with Image.open(buf) as png:
        annotated_image = png.convert('RGB')
    return np.array(annotated_image)
|
|
|
def process_video(input_video_path):
    """Run object detection on every frame of a video and save the result.

    Each frame is converted from BGR to RGB, annotated by
    ``get_annotated_image``, converted back, resized to the input frame size if
    needed, and written to a new MP4 file under ``./output_videos``.

    Args:
        input_video_path: Path of the video to process, as supplied by the
            ``gr.Video`` input (``None`` when nothing was uploaded).

    Returns:
        str: Path of the annotated output video.

    Raises:
        ValueError: If no input was given, the input cannot be opened, or the
            output video file cannot be created.
    """
    if not input_video_path:
        raise ValueError("未提供输入视频")

    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise ValueError("无法打开输入视频文件")

    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # OpenCV can report 0 (or NaN) when the container stores no frame
        # rate; a writer opened at that rate does not give a usable file, so
        # fall back to a common default. `not (x > 0)` also catches NaN.
        if not (fps > 0):
            fps = 25.0

        output_dir = './output_videos'
        os.makedirs(output_dir, exist_ok=True)
        # Microseconds keep names unique when two runs start in the same second.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        output_video_path = os.path.join(output_dir, f"output_{timestamp}.mp4")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        # VideoWriter does not raise on failure; it silently drops frames.
        if not out.isOpened():
            raise ValueError("无法创建输出视频文件")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            annotated_frame = get_annotated_image(Image.fromarray(rgb_frame))
            bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)

            # The rendered figure is not the input size; the writer requires
            # every frame to match the (width, height) it was opened with.
            if bgr_frame.shape[:2] != (height, width):
                bgr_frame = cv2.resize(bgr_frame, (width, height))

            print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}")
            out.write(bgr_frame)
    finally:
        # Release handles even when detection or writing raises midway.
        cap.release()
        if out is not None:
            out.release()

    return output_video_path
|
|
|
with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
    # Page heading. The style attribute is double-quoted so the single-quoted
    # font-family name inside it no longer terminates the attribute early
    # (which previously discarded the size/weight/alignment/colour styles).
    # `serif` is left unquoted so it is the CSS generic family, not a name.
    gr.HTML(
        "<div style=\"font-family:'Times New Roman', serif; font-size:16pt; "
        "font-weight:bold; text-align:center; color:royalblue;\">"
        "基于AI的安全风险识别及防控应用</div>"
    )

    with gr.Row():
        input_video = gr.Video(label="输入视频")
        detect_button = gr.Button("开始检测", variant="primary")
        output_video = gr.Video(label="输出视频")

    detect_button.click(process_video, inputs=input_video, outputs=output_video)

# Only start the server when run as a script, so importing this module (e.g.
# for tests or `gradio` reload mode, which looks up `demo`) has no side effect.
if __name__ == "__main__":
    demo.launch()