|
|
|
import pandas as pd |
|
|
|
from inference_mtl import * |
|
from utils.os_util import get_file_paths |
|
from utils.plt_util import plot_scores_mtl, syn_plot_scores_mtl, DrawMTL |
|
from utils.time_util import convert_input_time, convert_stamp |
|
|
|
|
|
# Four-character codec code unpacked into cv2.VideoWriter_fourcc() for every result video.
FOURCC = "mp4v"
|
|
|
|
|
def inference_videos(input_video, save_dir='output/out_vdo', detect_mode='frame', frequency=1.0, continuous=True,
                     show_res=False, plot_score=False, save_score=False, save_vdo=False, save_img=False,
                     img_save_dir=None, syn_plot=True, resize=None, time_delta=1, save_vdo_path=None,
                     save_plt_path=None, save_csv_path=None):
    """Run multi-task inference on one video or on every video in a folder.

    @param input_video: path to a single video file or to a folder of videos
    @param save_dir: output directory for result video / score plot / score csv (created if missing)
    @param detect_mode: sampling mode, 'frame' or 'second' (infer per frame / per second of video)
    @param frequency: sampling-interval multiplier (may be fractional); the frame step is
        max(1, round(base_step * frequency)) with base_step = 1 ('frame') or fps ('second'),
        so larger values mean sparser sampling
    @param continuous: when input_video is a folder, process its videos as one continuous run
        with a single combined set of outputs; otherwise write outputs per video
    @param show_res: show each annotated frame in an OpenCV window while running
    @param plot_score: draw and save the score plot
    @param save_score: save the per-sample scores as a csv file
    @param save_vdo: save the annotated result video
    @param save_img: save sampled frames whose probs[1][1] > 0.8 as jpg images
    @param img_save_dir: image directory (default: <save_dir>/images)
    @param syn_plot: append a live, synchronized score chart to the right of each drawn frame
        (author's note: not supported for the people-count / seat-position tasks)
    @param resize: optional (width, height) that frames are resized to before inference
    @param time_delta: cv2.waitKey delay in ms when showing results; 0 waits for a key press
    @param save_vdo_path: result video path; derived from save_dir and the video/folder name if None
    @param save_plt_path: score plot path; derived the same way if None
    @param save_csv_path: score csv path; derived the same way if None
    @return: (save_vdo_path, save_plt_path, save_csv_path) of the last outputs processed,
        or None when input_video does not exist or contains no videos
    """
    os.makedirs(save_dir, exist_ok=True)
    if save_img:
        img_save_dir = os.path.join(save_dir, 'images') if img_save_dir is None else img_save_dir
        os.makedirs(img_save_dir, exist_ok=True)
    save_count = 0

    # Keep the caller-supplied paths. In per-video mode each video derives its own default
    # paths from these, instead of inheriting (and overwriting) the first video's outputs.
    user_vdo_path, user_plt_path, user_csv_path = save_vdo_path, save_plt_path, save_csv_path

    separately_save = True
    if os.path.isfile(input_video):
        vdo_list = [input_video]
    elif os.path.isdir(input_video):
        vdo_list = get_file_paths(input_video, mod='vdo')
        if continuous:
            # normpath strips a trailing separator so basename() does not return ''.
            title = os.path.basename(os.path.normpath(input_video))
            save_vdo_path = os.path.join(save_dir, title + '.mp4') if user_vdo_path is None else user_vdo_path
            save_plt_path = os.path.join(save_dir, title + '.jpg') if user_plt_path is None else user_plt_path
            save_csv_path = os.path.join(save_dir, title + '.csv') if user_csv_path is None else user_csv_path
            separately_save = False
    else:
        print(f'No {input_video}')
        return

    if not vdo_list:
        print(f'No videos found in {input_video}')
        return

    if save_score:
        columns = ['index']
        for ti, task in enumerate(tasks):
            # The first 7 tasks get one column per entry of classes[ti]; later tasks get one column.
            columns += [f"{task}-{sc}" for sc in classes[ti]] if ti < 7 else [task]

    out_video = None
    if save_vdo and not separately_save:
        # Continuous mode writes one combined video, sized/timed after the first input video.
        probe = cv2.VideoCapture(vdo_list[0])
        width = int(probe.get(cv2.CAP_PROP_FRAME_WIDTH)) if resize is None else resize[0]
        height = int(probe.get(cv2.CAP_PROP_FRAME_HEIGHT)) if resize is None else resize[1]
        fps = probe.get(cv2.CAP_PROP_FPS)
        probe.release()
        # In 'second' mode the output plays at a fixed 5 fps; the syn_plot chart adds 0.5 * width.
        out_video = cv2.VideoWriter(save_vdo_path, cv2.VideoWriter_fourcc(*FOURCC),
                                    fps if detect_mode == 'frame' else 5,
                                    (int(width * 1.5), height) if syn_plot else (width, height))
        print(f"result video save in '{save_vdo_path}'")

    res_list = []
    for vdo_path in vdo_list:
        vdo_name = os.path.basename(vdo_path)
        try:
            # e.g. CARVIDEO_<id>_20230708132305000_..._.mp4: the third '_' field starts with the time.
            start_time_str = vdo_name.split('_')[2][:14]
            start_time_stamp = convert_input_time(start_time_str, digit=10)
        except Exception:
            # No parseable start time: rows are indexed by sample count instead of wall-clock time.
            start_time_str, start_time_stamp = '', 0

        cap = cv2.VideoCapture(vdo_path)
        cur_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if not cur_frames:
            cap.release()
            continue
        fps = cap.get(cv2.CAP_PROP_FPS)
        cur_seconds = int(cur_frames / (fps + 1e-6))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"video:{vdo_name} width:{width} height:{height} fps:{fps:.1f} frames:{cur_frames} "
              f"seconds: {cur_seconds} start_time:{start_time_str}")

        if resize is not None:
            width, height = resize

        cur_res_list = []
        if separately_save:
            title = os.path.splitext(vdo_name)[0]
            save_vdo_path = os.path.join(save_dir, title + '_res.mp4') if user_vdo_path is None else user_vdo_path
            save_plt_path = os.path.join(save_dir, title + '_res.jpg') if user_plt_path is None else user_plt_path
            save_csv_path = os.path.join(save_dir, title + '_res.csv') if user_csv_path is None else user_csv_path
            if save_vdo:
                out_video = cv2.VideoWriter(save_vdo_path, cv2.VideoWriter_fourcc(*FOURCC),
                                            fps if detect_mode == 'frame' else 5,
                                            (int(1.5 * width), height) if syn_plot else (width, height))
                print(f"result video save in '{save_vdo_path}'")

        # Tracking state fed back into inference_xyl from the previous sampled frame.
        base_box, dx, dy, dl = None, None, None, None

        step = 1 if detect_mode == 'frame' else fps
        step = max(1, round(step * frequency))
        count = 0
        try:
            for i in range(0, cur_frames, step):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i)
                ret, frame = cap.read()
                if not ret:
                    break
                count += 1

                # Resize whenever EITHER dimension differs from the target size, so every frame
                # matches the VideoWriter's frame size.
                if resize is not None and (frame.shape[0] != resize[1] or frame.shape[1] != resize[0]):
                    frame = cv2.resize(frame, resize)

                cur_res = inference_xyl(frame, vis=False, base_box=base_box, dx=dx, dy=dy, dl=dl)
                if cur_res is None:
                    preds = [0] * len(tasks)
                    probs = [[0] * len(sub_classes) for sub_classes in classes]
                    msg_list = ['no driver']
                    base_box, dx, dy, dl = None, None, None, None
                else:
                    preds, probs, box, crop_box, msg_list = cur_res
                    # Carry this frame's crop box and last three predictions into the next call.
                    base_box, dx, dy, dl = crop_box.copy(), preds[-3], preds[-2], preds[-1]

                if start_time_stamp:
                    cur_time = convert_stamp(start_time_stamp + round(i / fps))
                    index = cur_time + f'-{i}'
                else:
                    cur_time = ''
                    index = count
                # One row per sample: index followed by all class probabilities, flattened.
                row = tuple([index] + [round(p, 3) for sub_probs in probs for p in sub_probs])
                cur_res_list.append(row)
                res_list.append(row)

                if count and not count % 10:
                    print("{} {} => {}".format(i, cur_time, '\t'.join(msg_list)))

                # Hard-coded save condition: probability of class 1 of task 1 above 0.8.
                if save_img and probs[1][1] > 0.8:
                    if cur_time:
                        img_name = f"{convert_stamp(convert_input_time(cur_time), '%Y%m%d%H%M%S')}_{i}.jpg"
                    else:
                        # splitext works for any container extension, not only '.mp4'.
                        img_name = os.path.splitext(vdo_name)[0] + f'_{i}.jpg'
                    cv2.imwrite(os.path.join(img_save_dir, img_name), frame)
                    save_count += 1

                if show_res or save_vdo:
                    if cur_res is None:
                        drawn = drawer.draw_ind(frame)
                    else:
                        drawn = drawer.draw_result(frame, preds, probs, box, crop_box,
                                                   use_mask=False, use_frame=False)
                    if syn_plot:
                        score_array = np.array([r[1:] for r in res_list])
                        if detect_mode == 'second' and cur_time:
                            indexes = [r[0][-5:] for r in res_list]
                        else:
                            indexes = list(range(len(res_list)))
                        window_length = 300 if detect_mode == 'frame' else 30
                        # The chart is 0.5 * width wide so that frame + chart matches the writer size.
                        # The task-index groups and the width / 1280 scale are passed through as-is
                        # (presumably chart layout and size scaling — TODO confirm in plt_util).
                        score_chart = syn_plot_scores_mtl(
                            tasks, [[0, 1, 2, 3, 4, 5, 6], [7, 8], [9, 10], [11]], classes, indexes, score_array,
                            int(0.5 * width), height, window_length, width / 1280)
                        drawn = np.concatenate([drawn, score_chart], axis=1)

                    if show_res:
                        cv2.namedWindow(title, 0)  # 0 == cv2.WINDOW_NORMAL (resizable window)
                        cv2.imshow(title, drawn)
                        cv2.waitKey(time_delta)
                    if save_vdo:
                        out_video.write(drawn)
        finally:
            cap.release()
            if separately_save and out_video is not None:
                # Finalize this video's output file before moving on to the next input.
                out_video.release()
                out_video = None

        if separately_save:
            if show_res:
                cv2.destroyWindow(title)
            if plot_score:
                res = np.array([r[1:] for r in cur_res_list])
                plot_scores_mtl(tasks, task_types, classes, res, title, detect_mode, save_dir=save_dir,
                                save_path=save_plt_path, show=show_res)
            if save_score:
                df = pd.DataFrame(cur_res_list, columns=columns)
                df.to_csv(save_csv_path, index=False, float_format='%.3f')

    if out_video is not None:
        out_video.release()

    if save_img:
        print(f"total save {save_count} images")

    if not separately_save:
        if show_res:
            cv2.destroyWindow(title)
        if plot_score:
            res = np.array([r[1:] for r in res_list])
            plot_scores_mtl(tasks, task_types, classes, res, title, detect_mode, save_dir=save_dir,
                            save_path=save_plt_path, show=show_res)
        if save_score:
            df = pd.DataFrame(res_list, columns=columns)
            df.to_csv(save_csv_path, index=False, float_format='%.3f')

    return save_vdo_path, save_plt_path, save_csv_path
|
|
|
|
|
if __name__ == '__main__':
    # Local debugging run on a single clip with the interactive view, score plot/csv and result video.
    demo_video = '/Users/didi/Desktop/CARVIDEO_03afqwj7uw5d801e_20230708132305000_20230708132325000.mp4'
    demo_options = dict(
        save_dir='/Users/didi/Desktop/error_res_v1.0',
        detect_mode='second',
        frequency=0.2,
        plot_score=True,
        save_score=True,
        syn_plot=True,
        save_vdo=True,
        save_img=False,
        continuous=False,
        show_res=True,
        resize=(1280, 720),
        time_delta=1,
    )
    inference_videos(demo_video, **demo_options)
|
|