|
import os |
|
|
|
from gradio_webrtc import WebRTC |
|
import requests |
|
from PIL import Image |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
from random import choice |
|
import io |
|
|
|
import gradio as gr |
|
|
|
import cv2 |
|
import numpy as np |
|
|
|
from io import BytesIO |
|
import random |
|
import tempfile |
|
from pathlib import Path |
|
|
|
import torch |
|
from transformers import pipeline |
|
|
|
from PIL import Image |
|
|
|
import matplotlib.patches as patches |
|
|
|
|
|
# Object-detection pipelines for the two DETR checkpoints, built once at module
# import time (ResNet-50 backbone first, then ResNet-101).
detector50, detector101 = (
    pipeline(model=checkpoint)
    for checkpoint in ("facebook/detr-resnet-50", "facebook/detr-resnet-101")
)
|
|
|
|
|
|
|
# Palette used for the outlines of detection boxes.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]

# Matplotlib ``fontdict`` for the label/score text drawn next to each box.
fdic = dict(
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)
|
|
|
|
|
def _run_detector(model, in_pil_img):
    """Run one of the module-level DETR pipelines on an image.

    This used to be a second function named ``infer``; the later
    ``infer(in_pil_img)`` definition rebound that name, so the dispatcher was
    unreachable and ``query_data`` ended up calling the wrong function.

    Args:
        model: ``"detr-resnet-101"`` selects ``detector101``; any other value
            falls back to ``detector50``.
        in_pil_img: image passed unchanged to the pipeline.

    Returns:
        The pipeline output. ``get_figure`` iterates it as a sequence of dicts
        with ``"label"``, ``"score"`` and ``"box"`` (``xmin``/``ymin``/``xmax``/
        ``ymax``) keys.
    """
    detector = detector101 if model == "detr-resnet-101" else detector50
    return detector(in_pil_img)


def query_data(model, in_pil_img: Image.Image):
    """Return the raw detections produced by *model* for *in_pil_img*."""
    # Call the dispatcher directly: the public name ``infer`` is the
    # single-argument image-annotating wrapper defined below.
    return _run_detector(model, in_pil_img)


def get_figure(in_pil_img, model="detr-resnet-50"):
    """Draw *in_pil_img* with a labelled rectangle for every detection.

    Args:
        in_pil_img: image to annotate.
        model: detector name forwarded to ``query_data``; defaults to the
            ResNet-50 checkpoint (the fallback detector).

    Returns:
        matplotlib.figure.Figure: a figure that is not registered with pyplot,
        so callers do not have to ``plt.close`` it.
    """
    # Build the figure without pyplot: pyplot keeps every figure alive until it
    # is explicitly closed (a leak when called once per video frame) and relies
    # on global "current figure" state shared by Gradio worker threads.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(16, 10))
    ax = fig.subplots()
    ax.imshow(in_pil_img)

    for prediction in query_data(model, in_pil_img):
        box = prediction["box"]
        x, y = box["xmin"], box["ymin"]
        w, h = box["xmax"] - x, box["ymax"] - y
        ax.add_patch(
            patches.Rectangle((x, y), w, h, fill=False,
                              color=choice(COLORS), linewidth=3)
        )
        ax.text(x, y,
                f"{prediction['label']}: {round(prediction['score'] * 100, 1)}%",
                fontdict=fdic)

    ax.axis("off")
    return fig


def infer(in_pil_img, model="detr-resnet-50"):
    """Return *in_pil_img* annotated with its detections, as a PIL image.

    The figure from ``get_figure`` is rendered to an in-memory PNG, which is
    then reopened with PIL.
    """
    figure = get_figure(in_pil_img, model)
    buf = io.BytesIO()
    figure.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    return Image.open(buf)
|
|
|
|
|
def process_single_frame(frame):
    """Annotate one decoded video frame with object detections.

    Args:
        frame: image array from ``cv2.VideoCapture.read`` (BGR channel order).

    Returns:
        numpy.ndarray: the rendered annotation as an RGB ``uint8`` array of
        shape (H, W, 3).
    """
    # OpenCV decodes frames as BGR; PIL and matplotlib expect RGB.
    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    figure = get_figure(pil_image)
    buf = BytesIO()
    try:
        # NOTE(review): bbox_inches='tight' sizes the output to the drawn
        # content, so frame dimensions may differ between frames — confirm the
        # WebRTC consumer tolerates that.
        figure.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Release the figure so one figure per frame does not accumulate;
        # plt.close is a no-op for figures not managed by pyplot.
        plt.close(figure)
    buf.seek(0)

    # convert('RGB') drops the PNG's alpha channel.
    return np.array(Image.open(buf).convert("RGB"))
|
|
|
|
|
def infer_video(input_video_path):
    """Yield an annotated copy of every frame of a video file.

    Each frame is read with OpenCV, annotated by ``process_single_frame``
    (which returns RGB), and converted back to BGR before being yielded to the
    streaming output.

    Args:
        input_video_path: path of the video selected in the ``gr.Video`` input;
            may be None when nothing has been uploaded.

    Yields:
        numpy.ndarray: the annotated frame in BGR channel order.

    Raises:
        ValueError: if no path is given or OpenCV cannot open the file.
    """
    # Clicking "Detect" before uploading passes an empty value.
    if not input_video_path:
        raise ValueError("无法打开输入视频文件")

    cap = cv2.VideoCapture(input_video_path)
    try:
        if not cap.isOpened():
            raise ValueError("无法打开输入视频文件")

        # CAP_PROP_FRAME_COUNT comes from container metadata and may be 0, -1
        # or an estimate, so it is used only for logging; the loop runs until
        # the decoder stops returning frames.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_count = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                if 0 < total_frames and frame_count < total_frames:
                    print(f"提前结束:在第 {frame_count} 帧时无法读取帧")
                break

            frame_count += 1
            processed_frame = process_single_frame(frame)
            yield cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)

            if frame_count % 30 == 0:
                print(f"已处理 {frame_count}/{total_frames} 帧")
    finally:
        cap.release()
|
|
|
|
|
|
|
with gr.Blocks(title="长沙电网项目",
               css=".gradio-container {background:lightyellow;}") as demo:
    # Page header. The style attribute is double-quoted so the single quotes
    # around the font name no longer terminate it after "font-family:".
    gr.HTML(
        "<div style=\"font-family:'Times New Roman', serif; font-size:16pt; "
        "font-weight:bold; text-align:center; color:royalblue;\">长沙电网项目</div>"
    )

    with gr.Row():
        input_video = gr.Video(label="输入视频")
        # Receive-only video stream: the browser displays the frames that
        # infer_video yields.
        output_video = WebRTC(label="WebRTC Stream",
                              rtc_configuration=None,
                              mode="receive",
                              modality="video")
    detect = gr.Button("Detect", variant="primary")

    # Start streaming annotated frames when the button is clicked.
    output_video.stream(
        fn=infer_video,
        inputs=[input_video],
        outputs=[output_video],
        trigger=detect.click,
    )


if __name__ == "__main__":
    # launch(debug=True) blocks until the server stops; launch exactly once.
    demo.launch(debug=True)
|
|
|
|
|
|
|
|