dms3_demo / inference_video_mtl.py
qilongyu
first
91b511b
# -*- coding: utf-8 -*-
import pandas as pd
from inference_mtl import *
from utils.os_util import get_file_paths
from utils.plt_util import plot_scores_mtl, syn_plot_scores_mtl, DrawMTL
from utils.time_util import convert_input_time, convert_stamp
FOURCC = 'mp4v' # FourCC tag handed to cv2.VideoWriter_fourcc for every saved result video; original author's note: use h264 when the output is meant for the internet
def _parse_start_time(vdo_name):
    """Best-effort extraction of the recording start time from a video file name.

    The third ``_``-separated field of the name is read and its first 14 characters are handed to
    ``convert_input_time``.

    @param vdo_name: base name of the video file
    @return: ``(start_time_str, start_time_stamp)``, or ``('', 0)`` when the name cannot be parsed
    """
    try:
        start_time_str = vdo_name.split('_')[2][:14]
        start_time_stamp = convert_input_time(start_time_str, digit=10)
    except Exception:
        # Deliberate best effort: a name that does not follow the expected pattern simply means
        # the results are indexed by a running counter instead of by wall-clock time.
        return '', 0
    return start_time_str, start_time_stamp


def _open_video_writer(path, fps, width, height, detect_mode, syn_plot):
    """Create the cv2.VideoWriter for a result video and announce where it is saved.

    @param path: output video path
    @param fps: frame rate of the source video, reused as output rate in 'frame' mode
    @param width: width of one rendered frame (without the score chart)
    @param height: height of one rendered frame
    @param detect_mode: 'frame' keeps the source fps, anything else writes at 5 fps
    @param syn_plot: when true the canvas is 1.5x wider to hold the score chart on the right
    @return: the opened cv2.VideoWriter (the caller is responsible for releasing it)
    """
    fourcc = cv2.VideoWriter_fourcc(*FOURCC)
    out_fps = fps if detect_mode == 'frame' else 5
    out_size = (int(width * 1.5), height) if syn_plot else (width, height)
    writer = cv2.VideoWriter(path, fourcc, out_fps, out_size)
    print(f"result video save in '{path}'")
    return writer


def _draw_frame(frame, cur_res, res_list, detect_mode, cur_time, syn_plot, width, height):
    """Render one inference result on a frame, optionally with the score chart appended on the right.

    @param frame: the image that was fed to the model
    @param cur_res: value returned by ``inference_xyl`` for this frame, or None when no driver was found
    @param res_list: all score rows collected so far (first element of each row is its index)
    @param detect_mode: 'frame' or 'second'; selects chart window length and index labels
    @param cur_time: time string of the current frame, empty when no start time is known
    @param syn_plot: whether to append the synchronized score chart
    @param width: frame width used to size the chart (chart width is half of it)
    @param height: frame height, also used as chart height
    @return: the rendered image
    """
    if cur_res is None:
        drawn = drawer.draw_ind(frame)
    else:
        preds, probs, box, crop_box, _ = cur_res
        drawn = drawer.draw_result(frame, preds, probs, box, crop_box, use_mask=False, use_frame=False)
    if not syn_plot:
        return drawn

    score_array = np.array([r[1:] for r in res_list])
    if detect_mode == 'second' and cur_time:
        # str() keeps this working when earlier rows carry an integer counter instead of a
        # time string (videos without a parsable start time mixed into the same run).
        indexes = [str(r[0])[-5:] for r in res_list]
    else:
        indexes = list(range(len(res_list)))
    window_length = 300 if detect_mode == 'frame' else 30
    score_chart = syn_plot_scores_mtl(
        tasks, [[0, 1, 2, 3, 4, 5, 6], [7, 8], [9, 10], [11]], classes, indexes, score_array,
        int(0.5 * width), height, window_length, width / 1280)
    return np.concatenate([drawn, score_chart], axis=1)


def inference_videos(input_video, save_dir='output/out_vdo', detect_mode='frame', frequency=1.0, continuous=True,
                     show_res=False, plot_score=False, save_score=False, save_vdo=False, save_img=False,
                     img_save_dir=None, syn_plot=True, resize=None, time_delta=1, save_vdo_path=None,
                     save_plt_path=None, save_csv_path=None):
    """
    Run multi-task inference on one video or on every video of a directory.

    Frames are sampled from each video, passed to ``inference_xyl`` and the per-class scores are
    collected. Depending on the flags the results are shown in a window, written to a result
    video, plotted as a score chart and/or exported to a csv file.

    @param input_video: input video path, either a directory or a single video file
    @param save_dir: directory for all outputs (created when missing)
    @param detect_mode: sampling mode, 'second' / 'frame' = infer once per second / per frame
    @param frequency: multiplier applied to the sampling step; larger values mean larger gaps,
                      fractional values are allowed (the step is never smaller than one frame)
    @param continuous: when a directory is given, treat its videos as one continuous recording
                       and produce a single set of outputs named after the directory
    @param show_res: whether to display the rendered result while inferring
    @param plot_score: whether to draw the score chart
    @param save_score: whether to save the scores as a csv file
    @param save_vdo: whether to save the rendered result video
    @param save_img: whether to save selected frames as images; the selection rule is
                     hard-coded below (probs[1][1] > 0.8)
    @param img_save_dir: directory for saved frames, defaults to ``<save_dir>/images``
    @param syn_plot: whether to append a dynamic, synchronized score chart to each rendered frame
                     (the person-count / seat tasks are not supported by that chart)
    @param resize: optional ``(width, height)`` every frame is resized to
    @param time_delta: delay passed to cv2.waitKey between displayed frames; 0 waits for a key press
    @param save_vdo_path: explicit result video path, derived from the input name when None
    @param save_plt_path: explicit score chart path, derived from the input name when None
    @param save_csv_path: explicit score csv path, derived from the input name when None
    @return: ``(save_vdo_path, save_plt_path, save_csv_path)`` of the last processed output set,
             or None when ``input_video`` does not exist. When explicit paths are given and
             several videos are processed separately, every video writes to those same paths.
    """
    os.makedirs(save_dir, exist_ok=True)
    if save_img:
        if img_save_dir is None:
            img_save_dir = os.path.join(save_dir, 'images')
        os.makedirs(img_save_dir, exist_ok=True)

    # Keep the caller-supplied paths apart from the derived ones. Otherwise the defaults derived
    # for the first video would be reused (and overwritten) by every following video.
    user_vdo_path, user_plt_path, user_csv_path = save_vdo_path, save_plt_path, save_csv_path

    save_count = 0
    separately_save = True
    if os.path.isfile(input_video):
        vdo_list = [input_video]
    elif os.path.isdir(input_video):
        vdo_list = get_file_paths(input_video, mod='vdo')
        if continuous:
            title = os.path.basename(input_video)
            if save_vdo_path is None:
                save_vdo_path = os.path.join(save_dir, title + '.mp4')
            if save_plt_path is None:
                save_plt_path = os.path.join(save_dir, title + '.jpg')
            if save_csv_path is None:
                save_csv_path = os.path.join(save_dir, title + '.csv')
            separately_save = False
    else:
        print(f'No {input_video}')
        return

    if save_score:
        # The first seven tasks get one csv column per class, the remaining tasks one column each.
        columns = ['index']
        for ti, task in enumerate(tasks):
            if ti < 7:
                columns += [f"{task}-{sc}" for sc in classes[ti]]
            else:
                columns.append(task)

    out_video = None
    if save_vdo and not separately_save and vdo_list:
        # One shared writer for the whole directory, sized after the first video.
        probe = cv2.VideoCapture(vdo_list[0])
        width = int(probe.get(cv2.CAP_PROP_FRAME_WIDTH)) if resize is None else resize[0]
        height = int(probe.get(cv2.CAP_PROP_FRAME_HEIGHT)) if resize is None else resize[1]
        fps = probe.get(cv2.CAP_PROP_FPS)
        probe.release()
        out_video = _open_video_writer(save_vdo_path, fps, width, height, detect_mode, syn_plot)

    res_list = []
    window_open = False
    try:
        for vdo_path in vdo_list:
            vdo_name = os.path.basename(vdo_path)
            start_time_str, start_time_stamp = _parse_start_time(vdo_name)
            cur_res_list = []
            cap = cv2.VideoCapture(vdo_path)
            try:
                cur_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                if not cur_frames:
                    continue
                fps = cap.get(cv2.CAP_PROP_FPS)
                cur_seconds = int(cur_frames / (fps + 1e-6))
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                print(f"video:{vdo_name} width:{width} height:{height} fps:{fps:.1f} frames:{cur_frames} "
                      f"seconds: {cur_seconds} start_time:{start_time_str}")
                if resize is not None:
                    width, height = resize

                if separately_save:
                    title = os.path.splitext(vdo_name)[0]
                    save_vdo_path = (os.path.join(save_dir, title + '_res.mp4')
                                     if user_vdo_path is None else user_vdo_path)
                    save_plt_path = (os.path.join(save_dir, title + '_res.jpg')
                                     if user_plt_path is None else user_plt_path)
                    save_csv_path = (os.path.join(save_dir, title + '_res.csv')
                                     if user_csv_path is None else user_csv_path)
                    if save_vdo:
                        out_video = _open_video_writer(save_vdo_path, fps, width, height, detect_mode, syn_plot)

                # Tracking state fed back into the next inference call; reset when no driver is found.
                base_box, dx, dy, dl = None, None, None, None
                step = 1 if detect_mode == 'frame' else fps
                step = max(1, round(step * frequency))
                count = 0
                for i in range(0, cur_frames, step):
                    cap.set(cv2.CAP_PROP_POS_FRAMES, i)
                    ret, frame = cap.read()
                    if not ret:
                        break
                    count += 1
                    # Resize when EITHER dimension differs; the writer drops frames of the wrong size.
                    if resize is not None and (frame.shape[0] != resize[1] or frame.shape[1] != resize[0]):
                        frame = cv2.resize(frame, resize)

                    cur_res = inference_xyl(frame, vis=False, base_box=base_box, dx=dx, dy=dy, dl=dl)
                    if cur_res is None:
                        probs = [[0] * len(sub_classes) for sub_classes in classes]
                        msg_list = ['no driver']
                        base_box, dx, dy, dl = None, None, None, None
                    else:
                        preds, probs, _, crop_box, msg_list = cur_res
                        base_box, dx, dy, dl = crop_box.copy(), preds[-3], preds[-2], preds[-1]

                    flat_probs = [round(p, 3) for sub_probs in probs for p in sub_probs]
                    if start_time_stamp:
                        elapsed = round(i / fps) if fps > 0 else 0
                        cur_time = convert_stamp(start_time_stamp + elapsed)
                        row = tuple([cur_time + f'-{i}'] + flat_probs)
                    else:
                        cur_time = ''
                        row = tuple([count] + flat_probs)
                    cur_res_list.append(row)
                    res_list.append(row)

                    if count % 10 == 0:
                        print("{} {} => {}".format(i, cur_time, '\t'.join(msg_list)))

                    if save_img and probs[1][1] > 0.8:  # TODO: make the save condition configurable
                        if cur_time:
                            img_name = f"{convert_stamp(convert_input_time(cur_time), '%Y%m%d%H%M%S')}_{i}.jpg"
                        else:
                            # splitext instead of replace('.mp4', ...) so non-mp4 inputs get a .jpg name too.
                            img_name = f"{os.path.splitext(vdo_name)[0]}_{i}.jpg"
                        cv2.imwrite(os.path.join(img_save_dir, img_name), frame)
                        save_count += 1

                    if show_res or save_vdo:
                        drawn = _draw_frame(frame, cur_res, res_list, detect_mode, cur_time, syn_plot,
                                            width, height)
                        if show_res:
                            cv2.namedWindow(title, 0)
                            window_open = True
                            cv2.imshow(title, drawn)
                            cv2.waitKey(time_delta)
                        if save_vdo:
                            out_video.write(drawn)
            finally:
                cap.release()
                if separately_save and out_video is not None:
                    # Finalize the per-video result file before the next video starts.
                    out_video.release()
                    out_video = None

            if separately_save:
                if window_open:
                    cv2.destroyWindow(title)
                    window_open = False
                if plot_score:
                    res = np.array([r[1:] for r in cur_res_list])
                    plot_scores_mtl(tasks, task_types, classes, res, title, detect_mode, save_dir=save_dir,
                                    save_path=save_plt_path, show=show_res)
                if save_score:
                    df = pd.DataFrame(cur_res_list, columns=columns)
                    df.to_csv(save_csv_path, index=False, float_format='%.3f')
    finally:
        if out_video is not None:
            out_video.release()

    if save_img:
        print(f"total save {save_count} images")

    if not separately_save:
        if window_open:
            cv2.destroyWindow(title)
        if plot_score:
            res = np.array([r[1:] for r in res_list])
            plot_scores_mtl(tasks, task_types, classes, res, title, detect_mode, save_dir=save_dir,
                            save_path=save_plt_path, show=show_res)
        if save_score:
            df = pd.DataFrame(res_list, columns=columns)
            df.to_csv(save_csv_path, index=False, float_format='%.3f')
    return save_vdo_path, save_plt_path, save_csv_path
if __name__ == '__main__':
    # Demo run on a single local recording: sample per second (step scaled by 0.2),
    # show the result live and export video, score chart and csv into the output directory.
    inference_videos(
        '/Users/didi/Desktop/CARVIDEO_03afqwj7uw5d801e_20230708132305000_20230708132325000.mp4',
        save_dir='/Users/didi/Desktop/error_res_v1.0',
        detect_mode='second',
        frequency=0.2,
        continuous=False,
        show_res=True,
        plot_score=True,
        save_score=True,
        save_vdo=True,
        save_img=False,
        syn_plot=True,
        resize=(1280, 720),
        time_delta=1,
    )