qilongyu committed
Commit: 446f9ef
Parent(s): 3c46991
Add application file
This view is limited to 50 files because it contains too many changes.
- .gitattributes +3 -0
- app.py +179 -0
- data/dms3/v1.0/dms3_mtl_v1.0.onnx +3 -0
- data/dms3/v1.0/dms3_mtl_v1.0.pth +3 -0
- data/dms3/v1.0/去掉危险动作litehead.png +0 -0
- files/dms3_icon.png +0 -0
- files/frame_b.png +0 -0
- files/info_box.png +0 -0
- head_detect/__pycache__/demo.cpython-38.pyc +0 -0
- head_detect/demo.py +59 -0
- head_detect/demo_ori.py +58 -0
- head_detect/head_detector/__pycache__/face_utils.cpython-38.pyc +0 -0
- head_detect/head_detector/__pycache__/head_detectorv4.cpython-38.pyc +0 -0
- head_detect/head_detector/face_detector.py +258 -0
- head_detect/head_detector/face_utils.py +49 -0
- head_detect/head_detector/head_detectorv2.py +168 -0
- head_detect/head_detector/head_detectorv3.py +216 -0
- head_detect/head_detector/head_detectorv4.py +216 -0
- head_detect/head_detector/pose.py +55 -0
- head_detect/models/HeadDetectorv1.6.onnx +3 -0
- head_detect/utils_quailty_assurance/__pycache__/utils_quailty_assurance.cpython-38.pyc +0 -0
- head_detect/utils_quailty_assurance/draw_tools.py +253 -0
- head_detect/utils_quailty_assurance/metrics.py +123 -0
- head_detect/utils_quailty_assurance/result_to_coco.py +108 -0
- head_detect/utils_quailty_assurance/utils_quailty_assurance.py +79 -0
- head_detect/utils_quailty_assurance/video2imglist.py +57 -0
- inference_mtl.py +266 -0
- inference_video_mtl.py +224 -0
- models/__pycache__/shufflenet2_att_m.cpython-38.pyc +0 -0
- models/module/__pycache__/activation.cpython-38.pyc +0 -0
- models/module/__pycache__/conv.cpython-38.pyc +0 -0
- models/module/__pycache__/init_weights.cpython-38.pyc +0 -0
- models/module/__pycache__/norm.cpython-38.pyc +0 -0
- models/module/activation.py +17 -0
- models/module/blocks.py +300 -0
- models/module/conv.py +340 -0
- models/module/fpn.py +165 -0
- models/module/init_weights.py +44 -0
- models/module/norm.py +55 -0
- models/shufflenet2_att_m.py +265 -0
- utils/__pycache__/common.cpython-38.pyc +0 -0
- utils/__pycache__/images.cpython-38.pyc +0 -0
- utils/__pycache__/labels.cpython-38.pyc +0 -0
- utils/__pycache__/multiprogress.cpython-38.pyc +0 -0
- utils/__pycache__/os_util.cpython-38.pyc +0 -0
- utils/__pycache__/plt_util.cpython-38.pyc +0 -0
- utils/__pycache__/time_util.cpython-38.pyc +0 -0
- utils/common.py +227 -0
- utils/export_util.py +118 -0
- utils/images.py +539 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/dms3/v1.0/dms3_mtl_v1.0.pth filter=lfs diff=lfs merge=lfs -text
+data/dms3/v1.0/dms3_mtl_v1.0.onnx filter=lfs diff=lfs merge=lfs -text
+head_detect/models/HeadDetectorv1.6.onnx filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,179 @@
+import os
+
+import gradio as gr
+
+from inference_mtl import inference_xyl
+from inference_video_mtl import inference_videos
+
+
+SAVE_DIR = "/nfs/volume-236-2/eval_model/model_res"
+# SAVE_DIR = "/Users/didi/Desktop/model_res"
+CURRENT_DIR = ''
+os.makedirs(SAVE_DIR, exist_ok=True)
+
+PLOT_DIR = ""
+TASKS = ["ems", "eye", 'mouth', 'glass', 'mask', 'smoke', 'phone', "eyelid_r", "eyelid_l", 'shift_x', 'shift_y', 'expand']
+TASK_IDX = 0
+
+
+def last_img():
+    global TASK_IDX
+    TASK_IDX -= 1
+    if TASK_IDX < 0:
+        TASK_IDX = len(TASKS) - 1
+    plt_path = f"{PLOT_DIR}/{TASKS[TASK_IDX]}.jpg"
+
+    if not os.path.exists(plt_path):
+        return
+
+    return plt_path
+
+
+def next_img():
+    global TASK_IDX
+    TASK_IDX += 1
+    if TASK_IDX >= len(TASKS):
+        TASK_IDX = 0
+    plt_path = f"{PLOT_DIR}/{TASKS[TASK_IDX]}.jpg"
+
+    if not os.path.exists(plt_path):
+        return
+
+    return plt_path
+
+
+def inference_img(inp, engine):
+    inp = inp[:, :, ::-1]
+
+    if engine == "DMS3":
+        drawn = inference_xyl(inp, vis=False, return_drawn=True)
+        drawn = drawn[:, :, ::-1]
+        return drawn
+
+
+def inference_txt(inp, engine):
+    inp = inp[:, :, ::-1]
+
+    if engine == "DMS3":
+        msg_list = inference_xyl(inp, vis=False, return_drawn=False)[-1]
+        return "\n".join(msg_list)
+
+
+def inference_video(video, engine, device_type_, show_score_, syn_plot_):
+    global CURRENT_DIR
+    CURRENT_DIR = os.path.join(SAVE_DIR, engine)
+    os.makedirs(CURRENT_DIR, exist_ok=True)
+
+    base_name = os.path.splitext(os.path.basename(video))[0][:-8]
+
+    video_info = []
+    if device_type_ == "B100":
+        resize = (1280, 720)
+        video_info.append("720")
+    elif device_type_ == "B200":
+        resize = (1280, 960)
+        video_info.append("960")
+    else:
+        resize = None
+
+    is_syn_plot = syn_plot_ == "是"    # "是" = yes
+    is_show_score = show_score_ == "是"
+
+    video_info.append("1" if is_syn_plot else "0")
+    video_info.append(base_name + "_res.mp4")
+
+    save_video_name = "_".join(video_info)
+    save_plt_dir = base_name + "_plots"
+    save_csv_name = base_name + "_res.csv"
+
+    save_video_path = f"{CURRENT_DIR}/{save_video_name}"
+    save_plt_path = f"{CURRENT_DIR}/{save_plt_dir}"
+    save_csv_path = f"{CURRENT_DIR}/{save_csv_name}"
+
+    global PLOT_DIR, TASK_IDX
+    PLOT_DIR = save_plt_path
+    TASK_IDX = 0
+
+    if os.path.exists(save_video_path) and os.path.exists(save_plt_path):
+        if not is_show_score:
+            return save_video_path, f"{save_plt_path}/{TASKS[TASK_IDX]}.jpg", None
+        elif os.path.exists(save_csv_path):
+            return save_video_path, f"{save_plt_path}/{TASKS[TASK_IDX]}.jpg", save_csv_path
+
+    inference_videos(
+        video, save_dir=SAVE_DIR, detect_mode='second', frequency=0.2, plot_score=True, save_score=is_show_score,
+        syn_plot=is_syn_plot, save_vdo=True, save_img=False, continuous=False, show_res=False, resize=resize,
+        time_delta=1, save_vdo_path=save_video_path, save_plt_path=save_plt_path, save_csv_path=save_csv_path)
+
+    if not is_show_score:
+        return save_video_path, f"{save_plt_path}/{TASKS[TASK_IDX]}.jpg", None
+
+    return save_video_path, f"{save_plt_path}/{TASKS[TASK_IDX]}.jpg", save_csv_path
+
+
+def reset_state():
+    return None, None, None, None
+
+
+with gr.Blocks() as demo:
+    gr.HTML("""<h1 align="center">算法可视化Demo</h1>""")  # "Algorithm Visualization Demo"
+
+    with gr.Tab("Text"):  # tab 1
+        with gr.Row():  # row: children laid out side by side (multiple columns)
+            with gr.Column():  # column: children stacked vertically (multiple rows)
+                input_img1 = gr.Image(label="", show_label=False)
+                engine1 = gr.Dropdown(["DMS3", "EMS", "TURN2"], label="算法引擎", value="DMS3")  # label: "algorithm engine"
+                btn_txt = gr.Button(value="Submit", label="Submit Image", variant="primary")
+                btn_clear1 = gr.Button("Clear")
+
+            out1 = gr.Text(label="")
+
+        btn_txt.click(inference_txt, inputs=[input_img1, engine1], outputs=out1, show_progress=True)  # wire up click trigger
+        btn_clear1.click(reset_state, inputs=[], outputs=[input_img1, out1])
+
+    with gr.Tab("Image"):  # tab 2
+        with gr.Row():
+            with gr.Column():
+                input_img2 = gr.Image(label="", show_label=False)
+                engine2 = gr.Dropdown(["DMS3", "EMS", "TURN2"], label="算法引擎", value="DMS3")
+                btn_img = gr.Button(value="Submit", label="Submit Image", variant="primary")
+                btn_clear2 = gr.Button("Clear")
+
+            out2 = gr.Image(label="", show_label=False)
+
+        btn_img.click(inference_img, inputs=[input_img2, engine2], outputs=out2, show_progress=True)
+        btn_clear2.click(reset_state, inputs=[], outputs=[input_img2, out2])
+
+    with gr.Tab("Video"):  # tab 3
+        with gr.Row():
+
+            with gr.Column():
+                input_vdo = gr.Video(label="", show_label=False)
+                engine3 = gr.Dropdown(["DMS3", "EMS", "TURN2"], label="算法引擎", value="DMS3")
+                device_type = gr.Radio(["原始", "B100", "B200"], value="原始", label="Device Type")  # single choice; "原始" = original resolution
+
+                with gr.Row():
+                    show_score = gr.Radio(["是", "否"], value="是", label="分数明细")  # yes/no; label "score details"
+                    syn_plot = gr.Radio(["是", "否"], value="是", label="Syn Plot")  # yes/no
+
+                btn_vdo = gr.Button(value="Submit", label="Submit Video", variant="primary")
+                btn_clear3 = gr.Button("Clear")
+
+            with gr.Column():
+                out3 = gr.PlayableVideo(label="", show_label=False)
+                out3_plt = gr.Image(label=f"{TASKS[TASK_IDX]}", show_label=False)
+
+                with gr.Row():
+                    btn_before = gr.Button(value="上一张", label="before")  # "previous image"
+                    btn_next = gr.Button(value="下一张", label="next", variant="primary")  # "next image"
+
+                out3_df = gr.DataFrame()
+
+        btn_vdo.click(inference_video, inputs=[input_vdo, engine3, device_type, show_score, syn_plot],
+                      outputs=[out3, out3_plt, out3_df], show_progress=True)
+        btn_clear3.click(reset_state, inputs=[], outputs=[input_vdo, out3, out3_plt, out3_df])
+
+        btn_before.click(last_img, inputs=[], outputs=out3_plt)
+        btn_next.click(next_img, inputs=[], outputs=out3_plt)
+
+demo.queue().launch(share=True, inbrowser=True, favicon_path="files/dms3_icon.png")
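Gradio passes `gr.Image` inputs to callbacks as RGB arrays, while the OpenCV-based inference code in this repo works in BGR; the `[:, :, ::-1]` slices in `inference_img`/`inference_txt` convert on the way in and back again on the way out. A minimal standalone sketch of that channel reversal (not part of the commit):

import numpy as np

# Gradio delivers HxWx3 RGB uint8; OpenCV-style code expects BGR.
rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255           # pure red in RGB

bgr = rgb[:, :, ::-1]       # reverse the channel axis: RGB -> BGR
assert bgr[0, 0, 2] == 255  # red now sits in BGR's last channel

# The slice is its own inverse, which is why inference_img flips twice.
assert np.array_equal(bgr[:, :, ::-1], rgb)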
data/dms3/v1.0/dms3_mtl_v1.0.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4f06fd029196afea2522b194b1ba292a10a31f2792ca5c3d12f35ffd5530f02
+size 2176811
data/dms3/v1.0/dms3_mtl_v1.0.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:483c1b16abc39462272a35006cc92993d98c385f8d05cc2ad869ad5d041b17e7
+size 2444005
data/dms3/v1.0/去掉危险动作litehead.png
ADDED
files/dms3_icon.png
ADDED
files/frame_b.png
ADDED
files/info_box.png
ADDED
head_detect/__pycache__/demo.cpython-38.pyc
ADDED
Binary file (1.69 kB)
head_detect/demo.py
ADDED
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+import os
+import sys
+
+import cv2
+
+sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
+from head_detect.head_detector.head_detectorv4 import HeadDetector
+from head_detect.utils_quailty_assurance.utils_quailty_assurance import find_files, write_json
+from utils.images import crop_face_square_rate
+import numpy as np
+
+ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
+model_path = f"{ROOT_DIR}/models/HeadDetectorv1.6.onnx"
+
+headDetector = HeadDetector(onnx_path=model_path)
+
+
+def detect_driver_face(img, show_res=False, show_crop=False):
+    if isinstance(img, str):
+        img = cv2.imread(img)
+    else:
+        img = img.copy()
+    image_height, image_width = img.shape[:2]
+    # origin_image = cv2.resize(img, (1280, 720))
+    # origin_image = img[:, w//2:, :]
+    # short_side = min(image_height, image_width)
+    width_shift, height_shift = image_width // 2, 0
+    # cv2.imwrite("square.jpg", img[height_shift:, width_shift:, :])
+    bboxes = headDetector.run(img[:, width_shift:, :], get_largest=True)  # face detection: keep only the largest face
+    if not bboxes:
+        return [0, 0, 0, 0], 0
+    box = [int(width_shift + bboxes[0][0]), int(height_shift + bboxes[0][1]),
+           int(width_shift + bboxes[0][2]), int(height_shift + bboxes[0][3])]
+
+    # box[0], box[1] = max(0, box[0]), max(0, box[1])
+    # box[2], box[3] = min(w, box[2]), min(h, box[3])
+    score = bboxes[0][-1]
+
+    if (box[2] - box[0]) == 0 or (box[3] - box[1]) == 0:
+        return [0, 0, 0, 0], 0
+
+    # print(box, pred, score)
+    if show_res:
+        x0, y0, x1, y1 = box
+        print(box)
+        cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), thickness=2)
+        if show_crop:
+            _, crop_box = crop_face_square_rate(img, box, rate=-0.07)
+            cx0, cy0, cx1, cy1 = crop_box
+            print(crop_box)
+            cv2.rectangle(img, (cx0, cy0), (cx1, cy1), (0, 255, 0), thickness=2)
+        cv2.imshow('res', img)
+        cv2.waitKey(0)
+
+    return box, score
+
+
+if __name__ == "__main__":
+    detect_driver_face('../../test_visual/look_down.jpg', show_res=True, show_crop=True)
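Note that `detect_driver_face` runs the detector only on the right half of the frame (presumably the driver's side for this camera placement) and then adds `width_shift` back so the returned box is in full-frame coordinates. A standalone sketch of that bookkeeping, with made-up numbers:

frame_w, frame_h = 1280, 720
width_shift, height_shift = frame_w // 2, 0  # the detector only sees img[:, width_shift:, :]

# Hypothetical detection in right-half coordinates: x0, y0, x1, y1, score.
x0, y0, x1, y1, score = 100, 200, 260, 380, 0.93

# Shift back into full-frame coordinates, as detect_driver_face does.
box = [int(width_shift + x0), int(height_shift + y0),
       int(width_shift + x1), int(height_shift + y1)]
print(box, score)  # [740, 200, 900, 380] 0.93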
head_detect/demo_ori.py
ADDED
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+import os
+import sys
+
+import cv2
+
+sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
+from head_detect.head_detector.head_detectorv2 import HeadDetector
+from head_detect.utils_quailty_assurance.utils_quailty_assurance import find_files, write_json
+from utils.images import crop_face_square_rate
+import numpy as np
+
+ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
+model_path = f"{ROOT_DIR}/models/Mona_HeadDetector_v1_straght.onnx"
+
+headDetector = HeadDetector(onnx_path=model_path)
+
+
+def detect_driver_face(img, show_res=False, show_crop=False):
+    if isinstance(img, str):
+        img = cv2.imread(img)
+    else:
+        img = img.copy()
+    h, w = img.shape[:2]
+    # origin_image = cv2.resize(img, (1280, 720))
+    # origin_image = img[:, w//2:, :]
+    bboxes = headDetector.run(img[:, w//2:, :], get_largest=True)  # face detection: keep only the largest face
+    if not bboxes:
+        return [0, 0, 0, 0], 0
+
+    box = [int(b) for b in bboxes[0][:4]]
+    box[0] += w // 2
+    box[2] += w // 2
+    # box[0], box[1] = max(0, box[0]), max(0, box[1])
+    # box[2], box[3] = min(w, box[2]), min(h, box[3])
+    score = bboxes[0][-1]
+
+    if (box[2] - box[0]) == 0 or (box[3] - box[1]) == 0:
+        return [0, 0, 0, 0], 0
+
+    # print(box, pred, score)
+    if show_res:
+        x0, y0, x1, y1 = box
+        print(box)
+        cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), thickness=2)
+        if show_crop:
+            _, crop_box = crop_face_square_rate(img, box, rate=-0.07)
+            cx0, cy0, cx1, cy1 = crop_box
+            print(crop_box)
+            cv2.rectangle(img, (cx0, cy0), (cx1, cy1), (0, 255, 0), thickness=2)
+        cv2.imshow('res', img)
+        cv2.waitKey(0)
+
+    return box, score
+
+
+if __name__ == "__main__":
+    detect_driver_face('../../test_visual/look_down.jpg', show_res=True, show_crop=True)
head_detect/head_detector/__pycache__/face_utils.cpython-38.pyc
ADDED
Binary file (1.55 kB)
head_detect/head_detector/__pycache__/head_detectorv4.cpython-38.pyc
ADDED
Binary file (6.65 kB)
head_detect/head_detector/face_detector.py
ADDED
@@ -0,0 +1,258 @@
+# for 1280x720 images
+import os
+import cv2
+import copy
+import onnxruntime
+import numpy as np
+from face_detector.face_utils import letterbox
+# from face_utils import letterbox
+from head_detect.utils_quailty_assurance.utils_quailty_assurance import write_json  # needed by the __main__ block below (import path follows demo.py)
+
+class FaceDetector:
+    def __init__(self, onnx_path="models/smoke_phone_mosaic_v2.onnx"):
+
+        self.onnx_path = onnx_path
+        self.onnx_session = onnxruntime.InferenceSession(self.onnx_path)
+        self.input_name = self.get_input_name(self.onnx_session)
+        self.output_name = self.get_output_name(self.onnx_session)
+        self.multi_class = 3  # change 1: the output covers the four classes [normal, smoke, phone, drink]
+
+    def get_output_name(self, onnx_session):
+        output_name = []
+        for node in onnx_session.get_outputs():
+            output_name.append(node.name)
+        return output_name
+
+    def get_input_name(self, onnx_session):
+        input_name = []
+        for node in onnx_session.get_inputs():
+            input_name.append(node.name)
+        return input_name
+
+    def get_input_feed(self, input_name, image_tensor):
+        input_feed = {}
+        for name in input_name:
+            input_feed[name] = image_tensor
+        return input_feed
+
+    def after_process(self, pred):
+        # input size 320x192, downsampled by 8/16/32 -> output grids of 40/20/10
+        stride = np.array([8., 16., 32.])
+        x = [pred[0], pred[1], pred[2]]
+        # ============ yolov5 parameters: start ============
+        nc = 1
+        no = 16 + self.multi_class
+        nl = 3
+        na = 3
+        # grid = [torch.zeros(1).to(device)] * nl
+        grid = [np.zeros(1)] * nl
+        anchor_grid = np.array([[[[[[4., 5.]]],
+                                  [[[8., 10.]]],
+                                  [[[13., 16.]]]]],
+                                [[[[[23., 29.]]],
+                                  [[[43., 55.]]],
+                                  [[[73., 105.]]]]],
+                                [[[[[146., 217.]]],
+                                  [[[231., 300.]]],
+                                  [[[335., 433.]]]]]])
+        # ============ yolov5-0.5 parameters: end ============
+        z = []
+        for i in range(len(x)):
+
+            bs, ny, nx = x[i].shape[0], x[i].shape[2], x[i].shape[3]
+
+            if grid[i].shape[2:4] != x[i].shape[2:4]:
+                grid[i] = self._make_grid(nx, ny)
+
+            y = np.full_like(x[i], 0)
+
+            # y[..., [0,1,2,3,4,15]] = self.sigmoid_v(x[i][..., [0,1,2,3,4,15]])
+            y[..., [0, 1, 2, 3, 4, 15, 16, 17, 18]] = self.sigmoid_v(x[i][..., [0, 1, 2, 3, 4, 15, 16, 17, 18]])
+            # apply sigmoid to the face confidence and the dangerous-action confidences at once
+
+            y[..., 5:15] = x[i][..., 5:15]
+
+            y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid[i]) * stride[i]  # xy
+            y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i]  # wh
+
+            y[..., 5:7] = y[..., 5:7] * anchor_grid[i] + grid[i] * stride[i]  # landmark x1 y1
+            y[..., 7:9] = y[..., 7:9] * anchor_grid[i] + grid[i] * stride[i]  # landmark x2 y2
+            y[..., 9:11] = y[..., 9:11] * anchor_grid[i] + grid[i] * stride[i]  # landmark x3 y3
+            y[..., 11:13] = y[..., 11:13] * anchor_grid[i] + grid[i] * stride[i]  # landmark x4 y4
+            y[..., 13:15] = y[..., 13:15] * anchor_grid[i] + grid[i] * stride[i]  # landmark x5 y5
+
+            z.append(y.reshape((bs, -1, no)))
+        return np.concatenate(z, 1)
+
+    def _make_grid(self, nx, ny):
+        yv, xv = np.meshgrid(np.arange(ny), np.arange(nx), indexing='ij')
+        return np.stack((xv, yv), 2).reshape((1, 1, ny, nx, 2)).astype(float)
+
+    def sigmoid_v(self, array):
+        return np.reciprocal(np.exp(-array) + 1.0)
+
+    def img_process(self, orgimg, long_side=320, stride_max=32):
+
+        # orgimg = cv2.imread(img_path)
+        img0 = copy.deepcopy(orgimg)
+        h0, w0 = orgimg.shape[:2]  # orig hw
+        r = long_side / max(h0, w0)  # resize image to img_size
+        if r != 1:  # always resize down, only resize up if training with augmentation
+            interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+            img0 = cv2.resize(img0, (int(w0 * r), int(h0 * r)), interpolation=interp)
+        img = letterbox(img0, new_shape=(192, 320), auto=False)[0]  # auto=True minimal rectangle, auto=False fixed size
+        # cv2.imwrite("convert.jpg", img=img)
+        # Convert
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).transpose(2, 0, 1).copy()  # BGR to RGB, to 3x416x416
+        img = img.astype("float32")  # uint8 to fp16/32
+        img /= 255.0  # 0 - 255 to 0.0 - 1.0
+        img = img[np.newaxis, :]
+
+        return img, orgimg
+
+    def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None):
+        # Rescale coords (xyxy) from img1_shape to img0_shape
+        if ratio_pad is None:  # calculate from img0_shape
+            gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
+            pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+        else:
+            gain = ratio_pad[0][0]
+            pad = ratio_pad[1]
+        coords[:, [0, 2, 5, 7, 9, 11, 13]] -= pad[0]  # x padding
+        coords[:, [1, 3, 6, 8, 10, 12, 14]] -= pad[1]  # y padding
+
+        coords[:, [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]] /= gain
+
+        return coords
+
+    def non_max_suppression(self, boxes, confs, iou_thres=0.6):
+
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = confs.flatten().argsort()[::-1]
+        keep = []
+        while order.size > 0:
+            i = order[0]
+            keep.append(i)
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            w = np.maximum(0.0, xx2 - xx1 + 1)
+            h = np.maximum(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            inds = np.where(ovr <= iou_thres)[0]
+            order = order[inds + 1]
+
+        return boxes[keep]
+
+    def nms(self, pred, conf_thres=0.1, iou_thres=0.5):
+        xc = pred[..., 4] > conf_thres
+        pred = pred[xc]
+        pred[:, 15:] *= pred[:, 4:5]
+
+        # best class only
+        confs = np.amax(pred[:, 15:16], 1, keepdims=True)
+        pred[..., 0:4] = self.xywh2xyxy(pred[..., 0:4])
+        return self.non_max_suppression(pred, confs, iou_thres)
+
+    def xywh2xyxy(self, x):
+        # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+        y = np.zeros_like(x)
+        y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+        y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+        y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+        y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+        return y
+
+    def get_largest_face(self, pred):
+        """Get the largest face in the image.
+
+        Args:
+            pred: face detections
+
+        Returns:
+            int: index of the largest face
+        """
+        max_index = 0
+        max_value = 0
+        for index in range(len(pred)):
+            xmin, ymin, xmax, ymax = pred[index][:4]
+            w = xmax - xmin
+            h = ymax - ymin
+            if w * h > max_value:
+                max_value = w * h
+                max_index = index
+        return max_index
+
+    def run(self, ori_image, get_largest=True):
+        # detail_dict = {}
+
+        img, orgimg = self.img_process(ori_image, long_side=320)  # [1,3,640,640]
+        # print(img.shape)
+        input_feed = self.get_input_feed(self.input_name, img)
+        pred = self.onnx_session.run(self.output_name, input_feed=input_feed)
+        # detail_dict["before_decoder"] = [i.tolist() for i in pred]
+        pred = self.after_process(pred)  # torch-style post-processing
+        # detail_dict["after_decoder"] = copy.deepcopy(pred.tolist())
+
+        pred = self.nms(pred[0], 0.3, 0.5)
+        # detail_dict["after_nms"] = copy.deepcopy(pred.tolist())
+        pred = self.scale_coords(img.shape[2:], pred, orgimg.shape)
+        # detail_dict["after_nms"] = copy.deepcopy(pred.tolist())
+
+        if get_largest and pred.shape[0] != 0:
+            pred_index = self.get_largest_face(pred)
+            pred = pred[[pred_index]]
+        bboxes = pred[:, [0, 1, 2, 3, 15]]
+        landmarks = pred[:, 5:15]
+        # change point: also fetch the dangerous-action label and confidence
+        multi_class = np.argmax(pred[:, 16:], axis=1)
+        multi_conf = np.amax(pred[:, 16:], axis=1)
+        landmarks = np.reshape(landmarks, (-1, 5, 2))
+        return bboxes, landmarks, multi_class, multi_conf
+
+
+if __name__ == '__main__':
+    onnxmodel = FaceDetector()
+    image_path = "/tmp-data/QACode/yolov5_quality_assurance/ONNX_Inference/smoke_phone/03ajqj5uqb013bc5_1614529990_1614530009.jpg"
+    # image = np.ones(shape=[1,3,192,320], dtype=np.float32)
+    img = cv2.imread(image_path)
+    img_resize = cv2.resize(img, (320, 180))
+    img_resize = cv2.copyMakeBorder(img_resize, 6, 6, 0, 0, cv2.BORDER_CONSTANT, cv2.BORDER_CONSTANT, value=(114, 114, 114))
+    # img_resize = cv2.cvtColor(img_resize, cv2.COLOR_BGR2RGB)
+    img_resize = cv2.cvtColor(img_resize, cv2.COLOR_BGR2RGB).transpose(2, 0, 1).copy()  # BGR to RGB, and swap h,w,c to c,h,w
+    img_resize = img_resize.astype("float32")  # uint8 to fp16/32
+    img_resize /= 255.0  # 0 - 255 to 0.0 - 1.0
+    img_resize = img_resize[np.newaxis, :]  # add the batch dimension
+    detail_dict = {}
+    input_feed = onnxmodel.get_input_feed(onnxmodel.input_name, img_resize)
+    pred = onnxmodel.onnx_session.run(onnxmodel.output_name, input_feed=input_feed)
+    detail_dict["before_decoder"] = [i.tolist() for i in pred]
+    pred = onnxmodel.after_process(pred)  # torch-style post-processing
+    detail_dict["after_decoder"] = copy.deepcopy(pred.tolist())
+    print(pred)
+    write_json("test.json", detail_dict)
+    pred = onnxmodel.nms(pred[0], 0.3, 0.5)
+    # detail_dict["after_nms"] = copy.deepcopy(pred.tolist())
+    # pred = onnxmodel.scale_coords(img.shape[2:], pred, img.shape)
+    detail_dict["after_nms"] = copy.deepcopy(pred.tolist())
+    write_json("test.json", detail_dict)
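`after_process` implements the standard YOLOv5 head decode: per grid cell, center = (2*sigmoid(t_xy) - 0.5 + cell) * stride and size = (2*sigmoid(t_wh))^2 * anchor. A one-cell numpy sketch with hypothetical raw outputs (not from the model):

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

stride = 8.0                          # first detection level
anchor = np.array([4.0, 5.0])         # first anchor of that level
cell = np.array([12.0, 7.0])          # grid-cell index (cx, cy)
t = np.array([0.2, -0.1, 0.5, 0.3])   # raw tx, ty, tw, th (made up)

xy = (sigmoid(t[:2]) * 2.0 - 0.5 + cell) * stride  # box center, in input-image pixels
wh = (sigmoid(t[2:]) * 2.0) ** 2 * anchor          # box width/height, in pixels
print(xy, wh)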
head_detect/head_detector/face_utils.py
ADDED
@@ -0,0 +1,49 @@
+import cv2
+import math
+import numpy as np
+
+
+def make_divisible(x, divisor):
+    # Returns x evenly divisible by divisor
+    return math.ceil(x / divisor) * divisor
+
+def check_img_size(img_size, s=32):
+    # Verify img_size is a multiple of stride s
+    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
+    if new_size != img_size:
+        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
+    return new_size
+
+
+def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
+    # Resize and pad image while meeting stride-multiple constraints
+    shape = im.shape[:2]  # current shape [height, width]
+    if isinstance(new_shape, int):
+        new_shape = (new_shape, new_shape)
+
+    # Scale ratio (new / old)
+    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+    if not scaleup:  # only scale down, do not scale up (for better val mAP)
+        r = min(r, 1.0)
+
+    # Compute padding
+    ratio = r, r  # width, height ratios
+    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+    if auto:  # minimum rectangle
+        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
+    elif scaleFill:  # stretch
+        dw, dh = 0.0, 0.0
+        new_unpad = (new_shape[1], new_shape[0])
+        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
+
+    dw /= 2  # divide padding into 2 sides
+    dh /= 2
+
+    if shape[::-1] != new_unpad:  # resize
+        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
+    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
+    return im, ratio, (dw, dh)
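As a quick sanity check of `letterbox` with the settings used elsewhere in this commit (`new_shape=(192, 320)`, `auto=False`): a 720x1280 frame scales by 0.25 to 180x320 and gets 6 px of gray padding on top and bottom. A minimal sketch, assuming the repo layout makes the module importable as below:

import numpy as np
from head_detect.head_detector.face_utils import letterbox

im = np.zeros((720, 1280, 3), dtype=np.uint8)  # H x W x C test frame
out, ratio, (dw, dh) = letterbox(im, new_shape=(192, 320), auto=False)
print(out.shape)      # (192, 320, 3): resized to 180x320, padded 6 px top and bottom
print(ratio, dw, dh)  # (0.25, 0.25) 0.0 6.0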
head_detect/head_detector/head_detectorv2.py
ADDED
@@ -0,0 +1,168 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    : face_detectorv2.py
+@Time    : 2022/06/23 18:51:01
+@Author  : Xie WenZhen
+@Version : 1.0
+@Contact : xiewenzhen@didiglobal.com
+@Desc    : head detector, decode-removed version v2 (decoding folded into the ONNX model)
+'''
+
+# here put the import lib
+import os
+import cv2
+import copy
+import onnxruntime
+import numpy as np
+from head_detect.head_detector.face_utils import letterbox
+
+class HeadDetector:
+    def __init__(self, onnx_path="models/Mona_HeadDetector_v1_straght.onnx"):
+        self.onnx_path = onnx_path
+        self.onnx_session = onnxruntime.InferenceSession(self.onnx_path)
+        self.input_name = self.get_input_name(self.onnx_session)
+        self.output_name = self.get_output_name(self.onnx_session)
+
+    def get_output_name(self, onnx_session):
+        output_name = []
+        for node in onnx_session.get_outputs():
+            output_name.append(node.name)
+        return output_name
+
+    def get_input_name(self, onnx_session):
+        input_name = []
+        for node in onnx_session.get_inputs():
+            input_name.append(node.name)
+        return input_name
+
+    def get_input_feed(self, input_name, image_tensor):
+
+        input_feed = {}
+        for name in input_name:
+            input_feed[name] = image_tensor
+        return input_feed
+
+    def img_process(self, orgimg, long_side=320, stride_max=32):
+
+        # orgimg = cv2.imread(img_path)
+        img0 = copy.deepcopy(orgimg)
+        h0, w0 = orgimg.shape[:2]  # orig hw
+        r = long_side / max(h0, w0)  # resize image to img_size
+        if r != 1:  # always resize down, only resize up if training with augmentation
+            # interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+            interp = cv2.INTER_LINEAR
+
+            img0 = cv2.resize(img0, (int(w0 * r), int(h0 * r)), interpolation=interp)
+        img = letterbox(img0, new_shape=(192, 320), auto=False)[0]  # auto=True minimal rectangle, auto=False fixed size
+        # cv2.imwrite("convert.jpg", img=img)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).transpose(2, 0, 1).copy()  # BGR to RGB, to 3x416x416
+        img = img.astype("float32")  # uint8 to fp16/32
+        img /= 255.0  # 0 - 255 to 0.0 - 1.0
+        img = img[np.newaxis, :]
+
+        return img, orgimg
+
+    def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None):
+        # Rescale coords (xyxy) from img1_shape to img0_shape
+        if ratio_pad is None:  # calculate from img0_shape
+            gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
+            pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+        else:
+            gain = ratio_pad[0][0]
+            pad = ratio_pad[1]
+        coords[:, [0, 2]] -= pad[0]  # x padding
+        coords[:, [1, 3]] -= pad[1]  # y padding
+
+        coords[:, [0, 1, 2, 3]] /= gain
+
+        return coords
+
+    def non_max_suppression(self, boxes, confs, iou_thres=0.6):
+
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = confs.flatten().argsort()[::-1]
+        keep = []
+        while order.size > 0:
+            i = order[0]
+            keep.append(i)
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            w = np.maximum(0.0, xx2 - xx1 + 1)
+            h = np.maximum(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            inds = np.where(ovr <= iou_thres)[0]
+            order = order[inds + 1]
+
+        return boxes[keep]
+
+    def nms(self, pred, conf_thres=0.1, iou_thres=0.5):
+        xc = pred[..., 4] > conf_thres
+        pred = pred[xc]
+        # pred[:, 15:] *= pred[:, 4:5]
+
+        # best class only
+        confs = np.amax(pred[:, 4:5], 1, keepdims=True)
+        pred[..., 0:4] = self.xywh2xyxy(pred[..., 0:4])
+        return self.non_max_suppression(pred, confs, iou_thres)
+
+    def xywh2xyxy(self, x):
+        # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+        y = np.zeros_like(x)
+        y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+        y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+        y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+        y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+        return y
+
+    def get_largest_face(self, pred):
+        """Get the largest face in the image.
+
+        Args:
+            pred: face detections
+
+        Returns:
+            int: index of the largest face
+        """
+        max_index = 0
+        max_value = 0
+        for index in range(len(pred)):
+            xmin, ymin, xmax, ymax = pred[index][:4]
+            w = xmax - xmin
+            h = ymax - ymin
+            if w * h > max_value:
+                max_value = w * h
+                max_index = index
+        return max_index
+
+    def run(self, ori_image, get_largest=True):
+        img, orgimg = self.img_process(ori_image, long_side=320)  # [1,3,640,640]
+        # print(img.shape)
+        input_feed = self.get_input_feed(self.input_name, img)
+        pred = self.onnx_session.run(self.output_name, input_feed=input_feed)
+
+        # pred = self.after_process(pred)  # torch-style post-processing
+        pred = self.nms(pred[0], 0.3, 0.5)
+        pred = self.scale_coords(img.shape[2:], pred, orgimg.shape)
+        if get_largest and pred.shape[0] != 0:
+            pred_index = self.get_largest_face(pred)
+            pred = pred[[pred_index]]
+        bboxes = pred[:, [0, 1, 2, 3, 4]]
+        return bboxes.tolist()
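`non_max_suppression` above is classic greedy IoU NMS: sort by confidence, keep the best box, discard everything overlapping it beyond `iou_thres`, then repeat on what remains. A self-contained toy run of the same logic on hypothetical boxes:

import numpy as np

def greedy_nms(boxes, confs, iou_thres=0.6):
    # Same logic as HeadDetector.non_max_suppression; boxes are (N, 4) xyxy.
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = confs.argsort()[::-1]   # indices sorted by descending confidence
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[np.where(iou <= iou_thres)[0] + 1]  # drop heavy overlaps
    return keep

boxes = np.array([[10, 10, 60, 60], [12, 12, 62, 62], [100, 100, 150, 150]], dtype=float)
confs = np.array([0.9, 0.8, 0.7])
print(greedy_nms(boxes, confs, iou_thres=0.5))  # [0, 2]: box 1 is suppressed by box 0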
head_detect/head_detector/head_detectorv3.py
ADDED
@@ -0,0 +1,216 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    : face_detectorv2.py
+@Time    : 2022/06/23 18:51:01
+@Author  : Xie WenZhen
+@Version : 1.0
+@Contact : xiewenzhen@didiglobal.com
+@Desc    : head detector, decode-removed version
+'''
+
+# here put the import lib
+import os
+import cv2
+import copy
+import onnxruntime
+import numpy as np
+from head_detect.head_detector.face_utils import letterbox
+
+class HeadDetector:
+    def __init__(self, onnx_path="models/head_models/HeadDetectorv1.1.onnx"):
+        self.onnx_path = onnx_path
+        self.onnx_session = onnxruntime.InferenceSession(self.onnx_path)
+        self.input_name = self.get_input_name(self.onnx_session)
+        self.output_name = self.get_output_name(self.onnx_session)
+
+    def get_output_name(self, onnx_session):
+        output_name = []
+        for node in onnx_session.get_outputs():
+            output_name.append(node.name)
+        return output_name
+
+    def get_input_name(self, onnx_session):
+        input_name = []
+        for node in onnx_session.get_inputs():
+            input_name.append(node.name)
+        return input_name
+
+    def get_input_feed(self, input_name, image_tensor):
+
+        input_feed = {}
+        for name in input_name:
+            input_feed[name] = image_tensor
+        return input_feed
+
+    def after_process(self, pred):
+        # input size 320x192, downsampled by 8/16/32 -> output grids of 40/20/10
+        stride = np.array([8., 16., 32.])
+        x = [pred[0], pred[1], pred[2]]
+        # ============ yolov5 parameters: start ============
+        nl = 3
+
+        # grid = [torch.zeros(1).to(device)] * nl
+        grid = [np.zeros(1)] * nl
+        anchor_grid = np.array([[[[[[4., 5.]]],
+                                  [[[8., 10.]]],
+                                  [[[13., 16.]]]]],
+                                [[[[[23., 29.]]],
+                                  [[[43., 55.]]],
+                                  [[[73., 105.]]]]],
+                                [[[[[146., 217.]]],
+                                  [[[231., 300.]]],
+                                  [[[335., 433.]]]]]])
+        # ============ yolov5-0.5 parameters: end ============
+        z = []
+        for i in range(len(x)):
+
+            ny, nx = x[i].shape[1], x[i].shape[2]
+            if grid[i].shape[2:4] != x[i].shape[2:4]:
+                grid[i] = self._make_grid(nx, ny)
+
+            y = np.full_like(x[i], 0)
+
+            # y[..., [0,1,2,3,4,15]] = self.sigmoid_v(x[i][..., [0,1,2,3,4,15]])
+            y[..., [0, 1, 2, 3, 4]] = self.sigmoid_v(x[i][..., [0, 1, 2, 3, 4]])
+            # apply sigmoid to the face confidence and the dangerous-action confidence
+            y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid[i]) * stride[i]  # xy
+            y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i]  # wh
+
+            z.append(y.reshape((1, -1, 6)))
+        return np.concatenate(z, 1)
+
+    def _make_grid(self, nx, ny):
+        yv, xv = np.meshgrid(np.arange(ny), np.arange(nx), indexing='ij')
+        return np.stack((xv, yv), 2).reshape((1, ny, nx, 2)).astype(float)
+
+    def sigmoid_v(self, array):
+        return np.reciprocal(np.exp(-array) + 1.0)
+
+    def img_process(self, orgimg, long_side=320, stride_max=32):
+
+        # orgimg = cv2.imread(img_path)
+        img0 = copy.deepcopy(orgimg)
+        h0, w0 = orgimg.shape[:2]  # orig hw
+        r = long_side / max(h0, w0)  # resize image to img_size
+        if r != 1:  # always resize down, only resize up if training with augmentation
+            # interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+            interp = cv2.INTER_LINEAR
+
+            img0 = cv2.resize(img0, (int(w0 * r), int(h0 * r)), interpolation=interp)
+        img = letterbox(img0, new_shape=(320, 320), auto=False)[0]  # auto=True minimal rectangle, auto=False fixed size
+        # cv2.imwrite("convert1.jpg", img=img)
+        # Convert
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).transpose(2, 0, 1).copy()  # BGR to RGB, to 3x416x416
+        img = img.astype("float32")  # uint8 to fp16/32
+        img /= 255.0  # 0 - 255 to 0.0 - 1.0
+        img = img[np.newaxis, :]
+
+        return img, orgimg
+
+    def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None):
+        # Rescale coords (xyxy) from img1_shape to img0_shape
+        if ratio_pad is None:  # calculate from img0_shape
+            gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
+            pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+        else:
+            gain = ratio_pad[0][0]
+            pad = ratio_pad[1]
+        coords[:, [0, 2]] -= pad[0]  # x padding
+        coords[:, [1, 3]] -= pad[1]  # y padding
+
+        coords[:, [0, 1, 2, 3]] /= gain
+
+        return coords
+
+    def non_max_suppression(self, boxes, confs, iou_thres=0.6):
+
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = confs.flatten().argsort()[::-1]
+        keep = []
+        while order.size > 0:
+            i = order[0]
+            keep.append(i)
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            w = np.maximum(0.0, xx2 - xx1 + 1)
+            h = np.maximum(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            inds = np.where(ovr <= iou_thres)[0]
+            order = order[inds + 1]
+
+        return boxes[keep]
+
+    def nms(self, pred, conf_thres=0.1, iou_thres=0.5):
+        xc = pred[..., 4] > conf_thres
+        pred = pred[xc]
+        # pred[:, 15:] *= pred[:, 4:5]
+
+        # best class only
+        confs = np.amax(pred[:, 4:5], 1, keepdims=True)
+        pred[..., 0:4] = self.xywh2xyxy(pred[..., 0:4])
+        return self.non_max_suppression(pred, confs, iou_thres)
+
+    def xywh2xyxy(self, x):
+        # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+        y = np.zeros_like(x)
+        y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+        y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+        y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+        y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+        return y
+
+    def get_largest_face(self, pred):
+        """Get the largest face in the image.
+
+        Args:
+            pred: face detections
+
+        Returns:
+            int: index of the largest face
+        """
+        max_index = 0
+        max_value = 0
+        for index in range(len(pred)):
+            xmin, ymin, xmax, ymax = pred[index][:4]
+            w = xmax - xmin
+            h = ymax - ymin
+            if w * h > max_value:
+                max_value = w * h
+                max_index = index
+        return max_index
+
+    def run(self, ori_image, get_largest=True):
+        img, orgimg = self.img_process(ori_image, long_side=320)  # [1,3,640,640]
+        # print(img.shape)
+        input_feed = self.get_input_feed(self.input_name, img)
+        pred = self.onnx_session.run(self.output_name, input_feed=input_feed)
+        pred = self.after_process(pred)  # torch-style post-processing
+        pred = self.nms(pred[0], 0.3, 0.5)
+        # detail_dict["after_nms"] = copy.deepcopy(pred.tolist())
+        pred = self.scale_coords(img.shape[2:], pred, orgimg.shape)
+        # detail_dict["after_nms"] = copy.deepcopy(pred.tolist())
+
+        if get_largest and pred.shape[0] != 0:
+            pred_index = self.get_largest_face(pred)
+            pred = pred[[pred_index]]
+        bboxes = pred[:, [0, 1, 2, 3, 4]]
+        return bboxes.tolist()
head_detect/head_detector/head_detectorv4.py
ADDED
@@ -0,0 +1,216 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    : face_detectorv2.py
+@Time    : 2022/06/23 18:51:01
+@Author  : Xie WenZhen
+@Version : 1.0
+@Contact : xiewenzhen@didiglobal.com
+@Desc    : head detector, decode-removed version
+'''
+
+# here put the import lib
+import os
+import cv2
+import copy
+import onnxruntime
+import numpy as np
+from head_detect.head_detector.face_utils import letterbox
+
+class HeadDetector:
+    def __init__(self, onnx_path="models/head_models/HeadDetectorv1.3.onnx"):
+        self.onnx_path = onnx_path
+        self.onnx_session = onnxruntime.InferenceSession(self.onnx_path)
+        self.input_name = self.get_input_name(self.onnx_session)
+        self.output_name = self.get_output_name(self.onnx_session)
+
+    def get_output_name(self, onnx_session):
+        output_name = []
+        for node in onnx_session.get_outputs():
+            output_name.append(node.name)
+        return output_name
+
+    def get_input_name(self, onnx_session):
+        input_name = []
+        for node in onnx_session.get_inputs():
+            input_name.append(node.name)
+        return input_name
+
+    def get_input_feed(self, input_name, image_tensor):
+
+        input_feed = {}
+        for name in input_name:
+            input_feed[name] = image_tensor
+        return input_feed
+
+    def after_process(self, pred):
+        # input size 320x192, downsampled by 8/16/32 -> output grids of 40/20/10
+        stride = np.array([8., 16., 32.])
+        x = [pred[0], pred[1], pred[2]]
+        # ============ yolov5 parameters: start ============
+        nl = 3
+
+        # grid = [torch.zeros(1).to(device)] * nl
+        grid = [np.zeros(1)] * nl
+        anchor_grid = np.array([[[[[[4., 5.]]],
+                                  [[[8., 10.]]],
+                                  [[[13., 16.]]]]],
+                                [[[[[23., 29.]]],
+                                  [[[43., 55.]]],
+                                  [[[73., 105.]]]]],
+                                [[[[[146., 217.]]],
+                                  [[[231., 300.]]],
+                                  [[[335., 433.]]]]]])
+        # ============ yolov5-0.5 parameters: end ============
+        z = []
+        for i in range(len(x)):
+
+            ny, nx = x[i].shape[1], x[i].shape[2]
+            if grid[i].shape[2:4] != x[i].shape[2:4]:
+                grid[i] = self._make_grid(nx, ny)
+
+            y = np.full_like(x[i], 0)
+
+            # y[..., [0,1,2,3,4,15]] = self.sigmoid_v(x[i][..., [0,1,2,3,4,15]])
+            y[..., [0, 1, 2, 3, 4]] = self.sigmoid_v(x[i][..., [0, 1, 2, 3, 4]])
+            # apply sigmoid to the face confidence and the dangerous-action confidence
+            y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid[i]) * stride[i]  # xy
+            y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i]  # wh
+
+            z.append(y.reshape((1, -1, 6)))
+        return np.concatenate(z, 1)
+
+    def _make_grid(self, nx, ny):
+        yv, xv = np.meshgrid(np.arange(ny), np.arange(nx), indexing='ij')
+        return np.stack((xv, yv), 2).reshape((1, ny, nx, 2)).astype(float)
+
+    def sigmoid_v(self, array):
+        return np.reciprocal(np.exp(-array) + 1.0)
+
+    def img_process(self, orgimg, long_side=320, stride_max=32):
+
+        # orgimg = cv2.imread(img_path)
+        img0 = copy.deepcopy(orgimg)
+        h0, w0 = orgimg.shape[:2]  # orig hw
+        r = long_side / max(h0, w0)  # resize image to img_size
+        if r != 1:  # always resize down, only resize up if training with augmentation
+            # interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+            interp = cv2.INTER_LINEAR
+
+            img0 = cv2.resize(img0, (int(w0 * r), int(h0 * r)), interpolation=interp)
+        img = letterbox(img0, new_shape=(320, 288), auto=False)[0]  # auto=True minimal rectangle, auto=False fixed size
+        # cv2.imwrite("convert1.jpg", img=img)
+        # Convert
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).transpose(2, 0, 1).copy()  # BGR to RGB, to 3x416x416
+        img = img.astype("float32")  # uint8 to fp16/32
+        img /= 255.0  # 0 - 255 to 0.0 - 1.0
+        img = img[np.newaxis, :]
+
+        return img, orgimg
+
+    def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None):
+        # Rescale coords (xyxy) from img1_shape to img0_shape
+        if ratio_pad is None:  # calculate from img0_shape
+            gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
+            pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+        else:
+            gain = ratio_pad[0][0]
+            pad = ratio_pad[1]
+        coords[:, [0, 2]] -= pad[0]  # x padding
+        coords[:, [1, 3]] -= pad[1]  # y padding
+
+        coords[:, [0, 1, 2, 3]] /= gain
+
+        return coords
+
+    def non_max_suppression(self, boxes, confs, iou_thres=0.6):
+
+        x1 = boxes[:, 0]
+        y1 = boxes[:, 1]
+        x2 = boxes[:, 2]
+        y2 = boxes[:, 3]
+        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+        order = confs.flatten().argsort()[::-1]
+        keep = []
+        while order.size > 0:
+            i = order[0]
+            keep.append(i)
+            xx1 = np.maximum(x1[i], x1[order[1:]])
+            yy1 = np.maximum(y1[i], y1[order[1:]])
+            xx2 = np.minimum(x2[i], x2[order[1:]])
+            yy2 = np.minimum(y2[i], y2[order[1:]])
+            w = np.maximum(0.0, xx2 - xx1 + 1)
+            h = np.maximum(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            ovr = inter / (areas[i] + areas[order[1:]] - inter)
+            inds = np.where(ovr <= iou_thres)[0]
+            order = order[inds + 1]
+
+        return boxes[keep]
+
+    def nms(self, pred, conf_thres=0.1, iou_thres=0.5):
+        xc = pred[..., 4] > conf_thres
+        pred = pred[xc]
+        # pred[:, 15:] *= pred[:, 4:5]
+
+        # best class only
+        confs = np.amax(pred[:, 4:5], 1, keepdims=True)
+        pred[..., 0:4] = self.xywh2xyxy(pred[..., 0:4])
+        return self.non_max_suppression(pred, confs, iou_thres)
+
+    def xywh2xyxy(self, x):
+        # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+        y = np.zeros_like(x)
+        y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+        y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+        y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+        y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+        return y
+
+    def get_largest_face(self, pred):
+        """Get the largest face in the image.
+
+        Args:
+            pred: face detections
+
+        Returns:
+            int: index of the largest face
+        """
+        max_index = 0
+        max_value = 0
+        for index in range(len(pred)):
+            xmin, ymin, xmax, ymax = pred[index][:4]
+            w = xmax - xmin
+            h = ymax - ymin
+            if w * h > max_value:
+                max_value = w * h
+                max_index = index
+        return max_index
+
+    def run(self, ori_image, get_largest=True):
+        img, orgimg = self.img_process(ori_image, long_side=320)  # [1,3,640,640]
+        # print(img.shape)
+        input_feed = self.get_input_feed(self.input_name, img)
+        pred = self.onnx_session.run(self.output_name, input_feed=input_feed)
+        pred = self.after_process(pred)  # torch-style post-processing
+        pred = self.nms(pred[0], 0.3, 0.5)
+        # detail_dict["after_nms"] = copy.deepcopy(pred.tolist())
+        pred = self.scale_coords(img.shape[2:], pred, orgimg.shape)
+        # detail_dict["after_nms"] = copy.deepcopy(pred.tolist())
+
+        if get_largest and pred.shape[0] != 0:
+            pred_index = self.get_largest_face(pred)
+            pred = pred[[pred_index]]
+        bboxes = pred[:, [0, 1, 2, 3, 4]]
+        return bboxes.tolist()
|
| 216 |
+
|
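A minimal usage sketch of the detector above. The class name and constructor here are assumptions, since the top of this file is outside the hunk; `run` returns `[[x1, y1, x2, y2, conf]]` in original-image coordinates, or an empty list when no face is found.

import cv2

detector = HeadDetector("head_detect/models/HeadDetectorv1.6.onnx")  # assumed class name / signature
frame = cv2.imread("driver.jpg")                                     # placeholder image path
boxes = detector.run(frame, get_largest=True)                        # keep only the largest face
if boxes:
    x1, y1, x2, y2, conf = boxes[0]
    print(f"face ({x1:.0f},{y1:.0f})-({x2:.0f},{y2:.0f}) conf={conf:.3f}")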
head_detect/head_detector/pose.py
ADDED
@@ -0,0 +1,55 @@

import numpy as np
import math

class Pose:
    def __init__(self):
        pt3d = np.zeros((3, 5))
        pt3d[0, :] = [-0.3207, 0.3101, -0.0011, -0.2578, 0.2460]
        pt3d[1, :] = [0.2629, 0.2631, -0.0800, -0.4123, -0.4127]
        pt3d[2, :] = [0.9560, 0.9519, 1.3194, 0.9921, 0.9899]
        self.pt3d = pt3d * 1e5

    def __call__(self, pt2d):
        # pt2d = np.asarray(pt2d, np.float)
        pt2d = np.reshape(pt2d, (5, 2)).transpose()
        pt3d = self.pt3d
        # Following the paper "Optimum Fiducials Under Weak Perspective Projection",
        # use a weak perspective projection.
        # Subtract the means to eliminate t, which makes R easy to solve for.
        pt2dm = np.zeros(pt2d.shape)
        pt3dm = np.zeros(pt3d.shape)
        pt2dm[0, :] = pt2d[0, :] - np.mean(pt2d[0, :])
        pt2dm[1, :] = pt2d[1, :] - np.mean(pt2d[1, :])
        pt3dm[0, :] = pt3d[0, :] - np.mean(pt3d[0, :])
        pt3dm[1, :] = pt3d[1, :] - np.mean(pt3d[1, :])
        pt3dm[2, :] = pt3d[2, :] - np.mean(pt3d[2, :])
        # Solve for R by least squares.
        R1 = np.dot(np.dot(np.mat(np.dot(pt3dm, pt3dm.T)).I, pt3dm), pt2dm[0, :].T)
        R2 = np.dot(np.dot(np.mat(np.dot(pt3dm, pt3dm.T)).I, pt3dm), pt2dm[1, :].T)
        # Compute the focal scale f.
        f = (math.sqrt(R1[0, 0] ** 2 + R1[0, 1] ** 2 + R1[0, 2] ** 2) + math.sqrt(
            R2[0, 0] ** 2 + R2[0, 1] ** 2 + R2[0, 2] ** 2)) / 2
        R1 = R1 / f

        R2 = R2 / f
        R3 = np.cross(R1, R2)
        # Recover the rotation angles from the rotation matrix R.
        phi = math.atan(R2[0, 2] / R3[0, 2])
        gamma = math.atan(-R1[0, 2] / (math.sqrt(R1[0, 0] ** 2 + R1[0, 1] ** 2)))
        theta = math.atan(R1[0, 1] / R1[0, 0])

        # Recompute the rotation-translation matrix from R to solve for t.
        pt3d = np.row_stack((pt3d, np.ones((1, pt3d.shape[1]))))
        R1_orig = np.dot(np.dot(np.mat(np.dot(pt3d, pt3d.T)).I, pt3d), pt2d[0, :].T)
        R2_orig = np.dot(np.dot(np.mat(np.dot(pt3d, pt3d.T)).I, pt3d), pt2d[1, :].T)

        t3d = np.array([R1_orig[0, 3], R2_orig[0, 3], 0]).reshape((3, 1))
        pitch = phi * 180 / np.pi
        yaw = gamma * 180 / np.pi
        roll = theta * 180 / np.pi

        return pitch, yaw, roll

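Pose solves a weak-perspective head pose from five 2-D landmarks matched against the 3-D template in `__init__`; judging by the template coordinates, the expected point order is left eye, right eye, nose tip, then left and right mouth corner (that ordering is inferred, not stated in the file). A minimal sketch with made-up pixel coordinates:

import numpy as np

pose = Pose()
# Five (x, y) landmarks in pixels; values are purely illustrative.
pt2d = np.array([[120, 100], [180, 100], [150, 140], [128, 175], [172, 175]], dtype=float)
pitch, yaw, roll = pose(pt2d)  # Euler angles in degrees
print(round(pitch, 1), round(yaw, 1), round(roll, 1))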
head_detect/models/HeadDetectorv1.6.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4fb959a1126c460b7becda05ba121c6ebd0b2a2fcc116e14af065aec69430068
size 1227728
head_detect/utils_quailty_assurance/__pycache__/utils_quailty_assurance.cpython-38.pyc
ADDED
Binary file (2.92 kB).
head_detect/utils_quailty_assurance/draw_tools.py
ADDED
@@ -0,0 +1,253 @@
import os
import cv2
import math
import numpy as np
import matplotlib.pyplot as plt

def show_results(img, bboxe, landmark, mask_label=None, emotion_label=None):
    h, w, c = img.shape
    tl = 1 or round(0.002 * (h + w) / 2) + 1  # line/font thickness
    x1, y1, x2, y2, confidence = bboxe
    cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), thickness=tl, lineType=cv2.LINE_AA)
    clors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]

    for i in range(5):
        point_x = int(landmark[i][0])
        point_y = int(landmark[i][1])
        cv2.circle(img, (point_x, point_y), tl + 1, clors[i], -1)

    tf = max(tl - 1, 1)  # font thickness
    label = str(confidence)[:5]
    cv2.putText(img, label, (int(x2), int(y2) - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
    if mask_label is not None:
        labels_dict = {0: 'Mask', 1: 'NoMask'}
        color_dict = {0: [0, 0, 255], 1: [255, 255, 0]}
        cv2.putText(img, labels_dict[mask_label], (int(x1), int(y1) - 5), 0, tl, color_dict[mask_label], thickness=tf, lineType=cv2.LINE_AA)
    if emotion_label is not None:
        emotion_str = "smile:{:.2f}".format(emotion_label)
        cv2.putText(img, emotion_str, (int(x1), int(y1) - 30), 0, tl, [255, 153, 255], thickness=tf, lineType=cv2.LINE_AA)

    return img

def draw_bboxes_landmarks(img, bboxe, landmark, multi_label=None, multi_conf=None):
    h, w, c = img.shape
    tl = 1 or round(0.002 * (h + w) / 2) + 1  # line/font thickness
    x1, y1, x2, y2, confidence = bboxe
    cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), thickness=tl, lineType=cv2.LINE_AA)
    clors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]

    for i in range(5):
        point_x = int(landmark[i][0])
        point_y = int(landmark[i][1])
        cv2.circle(img, (point_x, point_y), tl + 1, clors[i], -1)

    tf = max(tl - 1, 1)  # font thickness
    label = f"face:{confidence:.3f}"
    cv2.putText(img, label, (int(x1), int(y1) - 2), 0, tl / 2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
    if multi_label is not None:
        labels_dict = {0: "normal", 1: "smoke", 2: "phone", 3: "drink"}
        multi_str = f"{labels_dict[multi_label]}:{multi_conf:.3f}"
        if y1 > 10:
            cv2.putText(img, multi_str, (int(x1), int(y1) - 17), 0, tl / 2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
        else:
            cv2.putText(img, multi_str, (int(x1), int(y2) + 5), 0, tl / 2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
    return img

def draw_bboxes(img, bboxe, multi_label=None, multi_conf=None):
    h, w, c = img.shape
    tl = 1 or round(0.002 * (h + w) / 2) + 1  # line/font thickness
    x1, y1, x2, y2, confidence = bboxe
    cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), thickness=tl, lineType=cv2.LINE_AA)
    tf = max(tl - 1, 1)  # font thickness
    label = f"face:{confidence:.3f}"
    cv2.putText(img, label, (int(x1), int(y1) - 2), 0, tl / 2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
    if multi_label is not None:
        labels_dict = {0: "normal", 1: "smoke", 2: "phone", 3: "drink"}
        multi_str = f"{labels_dict[multi_label]}:{multi_conf:.3f}"
        if y1 > 10:
            cv2.putText(img, multi_str, (int(x1), int(y1) - 17), 0, tl / 2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
        else:
            cv2.putText(img, multi_str, (int(x1), int(y2) + 5), 0, tl / 2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
    return img

def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size=80):
    height, width = img.shape[:2]
    tl = 1 or round(0.002 * (height + width) / 2) + 1  # line/font thickness
    pitch = pitch * np.pi / 180
    yaw = -(yaw * np.pi / 180)
    roll = roll * np.pi / 180

    if tdx is not None and tdy is not None:
        tdx = tdx
        tdy = tdy
    else:
        tdx = width / 2
        tdy = height / 2

    # X-Axis pointing to right, drawn in red
    x1 = size * (math.cos(yaw) * math.cos(roll)) + tdx
    y1 = size * (math.cos(pitch) * math.sin(roll) + math.cos(roll)
                 * math.sin(pitch) * math.sin(yaw)) + tdy

    # Y-Axis pointing down, drawn in green
    x2 = size * (-math.cos(yaw) * math.sin(roll)) + tdx
    y2 = size * (math.cos(pitch) * math.cos(roll) - math.sin(pitch)
                 * math.sin(yaw) * math.sin(roll)) + tdy

    # Z-Axis (out of the screen) drawn in blue
    x3 = size * (math.sin(yaw)) + tdx
    y3 = size * (-math.cos(yaw) * math.sin(pitch)) + tdy

    cv2.line(img, (int(tdx), int(tdy)), (int(x1), int(y1)), (0, 0, 255), 3)
    cv2.line(img, (int(tdx), int(tdy)), (int(x2), int(y2)), (0, 255, 0), 3)
    cv2.line(img, (int(tdx), int(tdy)), (int(x3), int(y3)), (255, 0, 0), 2)

    return img

def split_num(sort_lst):
    if not sort_lst:
        return []
    len_lst = len(sort_lst)
    i = 0
    split_lst = []
    tmp_lst = [sort_lst[i]]
    while True:
        if i + 1 == len_lst:
            break
        next_n = sort_lst[i + 1]
        if sort_lst[i] + 1 == next_n:
            tmp_lst.append(next_n)
        else:
            split_lst.append(tmp_lst)
            tmp_lst = [next_n]
        i += 1
    split_lst.append(tmp_lst)
    return split_lst

def expand_Scope(nums, lenght):
    if nums[0] == 0:
        start = nums[0]
        end = nums[-1] + 2
    else:
        start = nums[0] - 1
        end = nums[-1] + 1
    if nums[-1] == lenght - 1:
        start = nums[0] - 2
        end = nums[-1]
    else:
        start = nums[0] - 1
        end = nums[-1] + 1
    return start, end

def draw_plot(xs, ys=None, dt=None, title="Orign Yaw angle", c='C0', label='Yaw', time=None, y_line=None, std=None, outline_dict=None, **kwargs):
    """ plot result of KF with color `c`, optionally displaying the variance
    of `xs`. Returns the list of lines generated by plt.plot()"""

    if ys is None and dt is not None:
        ys = xs
        xs = np.arange(0, len(ys) * dt, dt)
    if ys is None:
        ys = xs
        xs = range(len(ys))

    lines = plt.title(title)
    lines = plt.plot(xs, ys, color=c, label=label, **kwargs)
    if time is None:
        return lines
    x0 = time * dt
    y0 = ys[time]
    lines = plt.scatter(time * dt, ys[time], s=20, color='b')

    lines = plt.plot([x0, x0], [y0, 0], 'r-', lw=2)
    if y_line is None:
        return lines

    lines = plt.axhline(y=y_line, color='g', linestyle='--')
    y_line_list = np.full(len(ys), y_line)
    std_top = y_line_list + std
    std_btm = y_line_list - std
    plt.plot(xs, std_top, linestyle=':', color='k', lw=2)
    plt.plot(xs, std_btm, linestyle=':', color='k', lw=2)
    up_outline_lst = outline_dict["up"]
    down_outline_lst = outline_dict["down"]

    for nums in up_outline_lst:
        start, end = expand_Scope(nums, len(ys))
        plt.fill_between(xs[start:end + 1], std_btm[start:end + 1], np.array(ys)[start:end + 1], facecolor='red', alpha=0.3)
    for nums in down_outline_lst:
        start, end = expand_Scope(nums, len(ys))
        plt.fill_between(xs[start:end + 1], np.array(ys)[start:end + 1], std_top[start:end + 1], facecolor='red', alpha=0.3)

    plt.fill_between(xs, std_btm, std_top,
                     facecolor='yellow', alpha=0.2)

    return lines

def draw_plot_old(xs, ys=None, dt=None, title="Orign Yaw angle", c='C0', label='Yaw', time=None, y_line=None, std=None, outline_dict=None, **kwargs):
    """ plot result of KF with color `c`, optionally displaying the variance
    of `xs`. Returns the list of lines generated by plt.plot()"""

    if ys is None and dt is not None:
        ys = xs
        xs = np.arange(0, len(ys) * dt, dt)
    if ys is None:
        ys = xs
        xs = range(len(ys))

    lines = plt.title(title)
    lines = plt.plot(xs, ys, color=c, label=label, **kwargs)
    if time is None:
        return lines
    x0 = time * dt
    y0 = ys[time]
    lines = plt.scatter(time * dt, ys[time], s=20, color='b')

    lines = plt.plot([x0, x0], [y0, 0], 'r-', lw=2)
    if y_line is None:
        return lines

    lines = plt.axhline(y=y_line, color='g', linestyle='--')
    y_line_list = np.full(len(ys), y_line)
    std_top = y_line_list + std[0]
    std_btm = y_line_list - std[1]
    plt.plot(xs, std_top, linestyle=':', color='k', lw=2)
    plt.plot(xs, std_btm, linestyle=':', color='k', lw=2)
    up_outline_lst = outline_dict["up"]
    down_outline_lst = outline_dict["down"]

    for nums in up_outline_lst:
        start, end = expand_Scope(nums, len(ys))
        plt.fill_between(xs[start:end + 1], std_btm[start:end + 1], np.array(ys)[start:end + 1], facecolor='red', alpha=0.3)
    for nums in down_outline_lst:
        start, end = expand_Scope(nums, len(ys))
        plt.fill_between(xs[start:end + 1], np.array(ys)[start:end + 1], std_top[start:end + 1], facecolor='red', alpha=0.3)

    plt.fill_between(xs, std_btm, std_top,
                     facecolor='yellow', alpha=0.2)

    return lines

def draw_sticker(src, offset, pupils, landmarks,
                 blink_thd=0.22,
                 arrow_color=(0, 125, 255), copy=False):
    if copy:
        src = src.copy()

    left_eye_hight = landmarks[33, 1] - landmarks[40, 1]
    left_eye_width = landmarks[39, 0] - landmarks[35, 0]

    right_eye_hight = landmarks[87, 1] - landmarks[94, 1]
    right_eye_width = landmarks[93, 0] - landmarks[89, 0]

    # for mark in landmarks.reshape(-1, 2).astype(int):
    #     cv2.circle(src, tuple(mark), radius=1,
    #                color=(0, 0, 255), thickness=-1)

    if left_eye_hight / left_eye_width > blink_thd:
        cv2.arrowedLine(src, tuple(pupils[0].astype(int)),
                        tuple((offset + pupils[0]).astype(int)), arrow_color, 2)

    if right_eye_hight / right_eye_width > blink_thd:
        cv2.arrowedLine(src, tuple(pupils[1].astype(int)),
                        tuple((offset + pupils[1]).astype(int)), arrow_color, 2)

    return src

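split_num groups a sorted index list into runs of consecutive integers, and expand_Scope pads a run by one index on each side so draw_plot can shade it with fill_between; a small illustration:

print(split_num([3, 4, 5, 9, 10, 15]))  # [[3, 4, 5], [9, 10], [15]]
print(expand_Scope([3, 4, 5], 20))      # (2, 6)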
head_detect/utils_quailty_assurance/metrics.py
ADDED
@@ -0,0 +1,123 @@
# Model validation metrics
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

def getLabel2idx(labels):
    label2idx = dict()
    for i in labels:
        if i not in label2idx:
            label2idx[i] = len(label2idx)
    return label2idx

def calculate_all_prediction(confMatrix):
    '''
    Overall accuracy: the sum of the diagonal divided by the total count.
    '''
    total_sum = confMatrix.sum()
    correct_sum = (np.diag(confMatrix)).sum()
    prediction = round(100 * float(correct_sum) / float(total_sum), 2)
    return prediction

def calculate_label_prediction(confMatrix, labelidx):
    '''
    Per-class precision: correct predictions of the class divided by all predictions of the class.
    '''
    label_total_sum = confMatrix.sum(axis=0)[labelidx]
    label_correct_sum = confMatrix[labelidx][labelidx]
    prediction = 0
    if label_total_sum != 0:
        prediction = round(100 * float(label_correct_sum) / float(label_total_sum), 2)
    return prediction

def calculate_label_recall(confMatrix, labelidx):
    '''
    Per-class recall.
    '''
    label_total_sum = confMatrix.sum(axis=1)[labelidx]
    label_correct_sum = confMatrix[labelidx][labelidx]
    recall = 0
    if label_total_sum != 0:
        recall = round(100 * float(label_correct_sum) / float(label_total_sum), 2)
    return recall

def calculate_f1(prediction, recall):
    if (prediction + recall) == 0:
        return 0
    return round(2 * prediction * recall / (prediction + recall), 2)

def ap_per_class(tp, conf, pred_cls, target_cls):
    # tp = np.squeeze(tp, axis=1)
    conf = np.squeeze(conf, axis=1)
    pred_cls = np.squeeze(pred_cls, axis=1)
    # target_cls = target_cls.reshape((-1, 1))

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
    # print(f"labels face:{len(target_cls)}, pred face:{len(pred_cls)}")

    # Find unique classes
    unique_classes = np.unique(target_cls)

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
    s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
    ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = (target_cls == c).sum()  # number of labels
        n_p = i.sum()  # number of predictions

        if n_p == 0 or n_l == 0:
            continue
        else:
            # Accumulate FPs and TPs
            fpc = (1 - tp[i]).cumsum(0)
            tpc = tp[i].cumsum(0)

            # Recall
            recall = tpc / (n_l + 1e-16)  # recall curve
            r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0])  # r at pr_score, negative x, xp because xp decreases

            # Precision
            precision = tpc / (tpc + fpc)  # precision curve
            p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0])  # p at pr_score

            # AP from recall-precision curve
            for j in range(tp.shape[1]):
                ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
    # Compute F1 score (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + 1e-16)
    return p, r, ap, f1, unique_classes.astype('int32')


def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves
    # Arguments
        recall:    The recall curve (list)
        precision: The precision curve (list)
    # Returns
        Average precision, precision curve, recall curve
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01]))
    mpre = np.concatenate(([1.], precision, [0.]))

    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap, mpre, mrec

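compute_ap integrates the monotone precision envelope over recall; a toy sanity check with made-up curves (the breakpoints align with the 101-point COCO grid, so the trapezoid integral is exact here):

import numpy as np

recall = np.array([0.2, 0.5, 1.0])
precision = np.array([1.0, 0.8, 0.6])
ap, mpre, mrec = compute_ap(recall, precision)
print(round(float(ap), 2))  # 0.82 = 0.2*1.0 + 0.3*0.9 + 0.5*0.7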
head_detect/utils_quailty_assurance/result_to_coco.py
ADDED
@@ -0,0 +1,108 @@
# -*- encoding: utf-8 -*-
'''
@File    :   result_to_coco.py
@Time    :   2022/04/27 15:54:13
@Author  :   Xie WenZhen
@Version :   1.0
@Contact :   xiewenzhen@didiglobal.com
@Desc    :   None
'''

# here put the import lib
import os
import shutil
from tqdm import tqdm
from copy import deepcopy

def parse_label_resultv1(label_path, img_width, img_higth):
    with open(label_path, 'r') as fr:
        labelList = fr.readlines()
    face_list = []
    for label in labelList:
        label = label.strip().split()
        x = float(label[1])
        y = float(label[2])
        w = float(label[3])
        h = float(label[4])
        x1 = (x - w / 2) * img_width
        y1 = (y - h / 2) * img_higth
        x2 = (x + w / 2) * img_width
        y2 = (y + h / 2) * img_higth
        face_list.append(deepcopy([x1, y1, x2 - x1, y2 - y1]))
    return face_list

def parse_label_resultv2(label_path, img_width, img_higth):
    with open(label_path, 'r') as fr:
        labelList = fr.readlines()
    face_list = []
    mark_list = []
    category_id_list = []
    for label in labelList:
        label = label.strip().split()
        x = float(label[1])
        y = float(label[2])
        w = float(label[3])
        h = float(label[4])
        x1 = (x - w / 2) * img_width
        y1 = (y - h / 2) * img_higth
        x2 = (x + w / 2) * img_width
        y2 = (y + h / 2) * img_higth
        face_list.append(deepcopy([x1, y1, x2 - x1, y2 - y1]))
        ######
        mx0_ = float(label[5]) * img_width
        my0_ = float(label[6]) * img_higth
        mx1_ = float(label[7]) * img_width
        my1_ = float(label[8]) * img_higth
        mx2_ = float(label[9]) * img_width
        my2_ = float(label[10]) * img_higth
        mx3_ = float(label[11]) * img_width
        my3_ = float(label[12]) * img_higth
        mx4_ = float(label[13]) * img_width
        my4_ = float(label[14]) * img_higth
        mark_list.append(deepcopy([mx0_, my0_, mx1_, my1_, mx2_, my2_, mx3_, my3_, mx4_, my4_]))
        #####
        category_id = int(label[15])
        category_id_list.append(deepcopy(category_id))

    return face_list, mark_list, category_id_list

def generate_coco_labels(bbox, img_height, img_width, keypoint, filename, category_id):
    images_info = {}
    images_info["file_name"] = filename
    images_info['id'] = 0
    images_info['height'] = img_height
    images_info['width'] = img_width

    anno = {}
    anno['keypoints'] = keypoint
    anno['image_id'] = 0
    anno['id'] = 0
    anno['num_keypoints'] = 13  # all keypoints are labelled
    anno['bbox'] = bbox
    anno['iscrowd'] = 0
    anno['area'] = anno['bbox'][2] * anno['bbox'][3]
    anno['category_id'] = category_id
    final_output = {"images": images_info,
                    "annotations": anno}
    return final_output

def get_largest_face(face_dict_list):
    if len(face_dict_list) == 1:
        return face_dict_list[0]["bbox"], face_dict_list[0]["label"]
    max_id = 0
    max_area = 0
    for idx, face_dict in enumerate(face_dict_list):
        recent_area = face_dict["bbox"][2] * face_dict["bbox"][3]
        if recent_area > max_area:
            max_id = idx
            max_area = recent_area
    return face_dict_list[max_id]["bbox"], face_dict_list[max_id]["label"]

def main():
    print("Hello, World!")


if __name__ == "__main__":
    main()
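parse_label_resultv1 converts YOLO-normalized rows (class, cx, cy, w, h) into pixel-space COCO [x, y, w, h]; a worked example (the label file below is created purely for illustration):

with open("demo_label.txt", "w") as f:  # hypothetical one-line YOLO label
    f.write("0 0.5 0.5 0.25 0.5\n")
print(parse_label_resultv1("demo_label.txt", 640, 480))
# -> [[240.0, 120.0, 160.0, 240.0]]: x = (0.5 - 0.25/2) * 640 = 240, y = (0.5 - 0.5/2) * 480 = 120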
head_detect/utils_quailty_assurance/utils_quailty_assurance.py
ADDED
@@ -0,0 +1,79 @@
## Checks over the two parts of the label data
import os
import json
import torch
import numpy as np  # needed by the non-tensor branch of xywh2xyxy
from copy import deepcopy

def write_json(json_save_path, result_dict):
    with open(json_save_path, "w") as f:
        f.write(json.dumps(result_dict))

def read_json(json_path):
    with open(json_path, 'r') as f:
        result_dict = json.load(f)
    return result_dict

def parse_label(label_path):
    with open(label_path, 'r') as fr:
        labelList = fr.readlines()
    label_list = []
    for label in labelList:
        label = label.strip().split()
        l = int(label[-1])
        label_list.append(deepcopy(l))
    return label_list

def find_files(file_path, file_type='mp4'):
    if not os.path.exists(file_path):
        return []
    result = []
    for root, dirs, files in os.walk(file_path, topdown=False):
        for name in files:
            if name.split(".")[-1] == file_type:
                result.append(os.path.join(root, name))
    return result

def find_images(file_path):
    if not os.path.exists(file_path):
        return []
    result = []
    for root, dirs, files in os.walk(file_path, topdown=False):
        for name in files:
            if name.split(".")[-1] in ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng']:
                result.append(os.path.join(root, name))
    return result

def box_iou(box1, box2):
    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) -
             torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    # iou = inter / (area1 + area2 - inter)
    return inter / (area1[:, None] + area2 - inter)

def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y

def match_Iou(label5_list, label16_list):
    label15_torch = torch.tensor(label5_list)
    label16_torch = torch.tensor(label16_list)
    bbox5 = xywh2xyxy(label15_torch[:, 1:5])
    bbox16 = xywh2xyxy(label16_torch[:, 1:5])
    label5 = label15_torch[:, 0]
    label16 = label16_torch[:, 15]
    ious, i = box_iou(bbox5, bbox16).max(1)
    final_result = label16_torch[i]
    final_result[:, 15] = label5[i]
    return final_result

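box_iou computes the pairwise IoU matrix between two xyxy box sets; a quick check with values chosen for illustration:

import torch

b1 = torch.tensor([[0., 0., 10., 10.]])
b2 = torch.tensor([[5., 5., 15., 15.],
                   [20., 20., 30., 30.]])
print(box_iou(b1, b2))  # tensor([[0.1429, 0.0000]]) -- 25 / (100 + 100 - 25)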
head_detect/utils_quailty_assurance/video2imglist.py
ADDED
@@ -0,0 +1,57 @@
# Extract key frames from a video at a time-based rate
# -*- coding: utf-8 -*-

import cv2
import math
import os


def video2img(flv="ceshivedio.flv", rate=0.3, start=1, end=100):
    list_all = []
    vc = cv2.VideoCapture(flv)
    c = 1
    fps = dps(flv)
    if rate > fps:
        print("the fps is %s, set the rate=fps" % (fps))
        rate = fps

    if vc.isOpened():
        rval = True
    else:
        rval = False

    j = 1.0
    count = 0.0
    while rval:
        count = fps / rate * (j - 1)
        rval, frame = vc.read()
        if (c - 1) == int(count):
            j += 1
            if (math.floor(c / fps)) >= start and (math.floor(c / fps)) < end:
                if frame is not None:
                    list_all.append(frame)
        c += 1

        if (math.floor(c / fps)) >= end:
            break
    print("[ %d ] pictures from '%s' " % (len(list_all), flv))
    vc.release()

    return list_all

def dps(vedio):
    video = cv2.VideoCapture(vedio)
    # (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
    fps = video.get(cv2.CAP_PROP_FPS)
    video.release()
    return fps

if __name__ == "__main__":
    video_path = "/tmp-data/QACode/QAMaterial/2022-02-26video/B200C视频/低头抬头+打电话+打哈欠+抽烟.mp4"
    imglist = video2img(video_path, rate=0.1, start=0, end=100000)
    print(len(imglist))
    os.makedirs("tmp", exist_ok=True)
    for idx, image in enumerate(imglist):
        cv2.imwrite("{}/sample_{}.jpg".format("tmp", idx), image)
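video2img keeps roughly `rate` frames per second of footage (one frame every fps/rate source frames), gated to the [start, end) second window; a small sketch with a placeholder path:

frames = video2img("some_clip.mp4", rate=2.0, start=0, end=30)  # placeholder video path
print(len(frames))  # about 2 * 30 = 60 frames, depending on the source fps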
inference_mtl.py
ADDED
@@ -0,0 +1,266 @@
# -*- coding:utf-8 -*-
import os

import cv2
import numpy as np
import torch

from head_detect.demo import detect_driver_face
from models.shufflenet2_att_m import ShuffleNetV2
from utils.images import expand_box_rate, crop_with_pad, show
from utils.os_util import get_file_paths
from utils.plt_util import DrawMTL

project = 'dms3'
version = 'v1.0'

if version in ["v0.1", "v0.2"]:
    input_size = [160, 160]
    model_size = '0.5x'
    stack_lite_head = 1
    resume = 'data/dms3/v0.2/epoch100_11_0.910_0.865_0.983_0.984_0.964_0.923_0.952_0.831_0.801_0.948_0.948_0.833.pth'
    onnx_export_path = 'data/dms3/v0.2/dms3_mtl_v0.2.onnx'
    tasks = ["ems", "eye", 'mouth', 'glass', 'mask', 'smoke', 'phone', "eyelid_r", "eyelid_l", 'shift_x', 'shift_y', 'expand']
    classes = [['normal', 'left', 'down', 'right', 'ind'], ['normal', 'close', 'ind'], ['normal', 'yawn', 'ind'],
               ['normal', 'glass', 'ind'], ['normal', 'mask', 'ind'], ['normal', 'smoke'], ['normal', 'phone'],
               ['distance'], ['distance'], ['distance'], ['distance'], ['distance']]
    task_types = [0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3]
    num_classes = [5, 3, 3, 3, 3, 2, 2, 1, 1, 1, 1, 1]
    reg_relative_max = [-1.] * 7 + [0.05, 0.05, 0.7, 0.7, 0.1]
    reg_dimension = [-1.] * 7 + [3, 3, 3, 3, 2]
    expand_rate = -0.075

elif version == 'v1.0':
    input_size = [160, 160]
    model_size = '0.5x'
    stack_lite_head = [1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1]
    resume = 'data/dms3/v1.0/dms3_mtl_v1.0.pth'
    onnx_export_path = 'data/dms3/v1.0/dms3_mtl_v1.0.onnx'
    tasks = ["ems", "eye", 'mouth', 'glass', 'mask', 'smoke', 'phone', "eyelid_r", "eyelid_l", 'shift_x', 'shift_y',
             'expand']
    classes = [['normal', 'left', 'down', 'right', 'ind'], ['normal', 'close', 'ind'], ['normal', 'yawn', 'ind'],
               ['normal', 'glass', 'ind'], ['normal', 'mask', 'ind'], ['normal', 'smoke'], ['normal', 'phone'],
               ['distance'], ['distance'], ['distance'], ['distance'], ['distance']]
    task_types = [0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3]
    num_classes = [5, 3, 3, 3, 3, 2, 2, 1, 1, 1, 1, 1]
    reg_relative_max = [-1.] * 7 + [0.05, 0.05, 0.7, 0.7, 0.1]
    reg_dimension = [-1.] * 7 + [3, 3, 3, 3, 2]
    expand_rate = -0.075

else:
    raise NotImplementedError

# device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device = "cpu"
model = ShuffleNetV2(num_tasks=len(tasks), task_types=task_types, num_classes=num_classes, model_size=model_size,
                     with_last_conv=0, stack_lite_head=stack_lite_head, lite_head_channels=-1, onnx=False)
model.load_state_dict(torch.load(resume, map_location=device)['state_dict'])
model.to(device)
print(f"loading {resume}")
model.eval()

drawer = DrawMTL(project, tasks, task_types, classes)


def onnx_export(export_name='test.onnx'):
    model_export = ShuffleNetV2(num_tasks=len(tasks), task_types=task_types, num_classes=num_classes, model_size=model_size,
                                with_last_conv=0, stack_lite_head=stack_lite_head, lite_head_channels=-1, onnx=True)
    model_export.load_state_dict(torch.load(resume, map_location=device)['state_dict'])
    model_export.eval()

    example = torch.randn(1, 3, input_size[1], input_size[0])
    torch.onnx.export(
        model_export,  # model being run
        example,  # model input (or a tuple for multiple inputs)
        export_name,
        verbose=False,
        # store the trained parameter weights inside the model file
        training=False,
        input_names=['input'],
        output_names=tasks,
        do_constant_folding=True
    )


def inference_xyl(input_img, vis=False, base_box=None, dx=None, dy=None, dl=None, return_drawn=False):
    """
    Single-image inference: head detection plus multi-task classification (with xy shift).
    :param input_img: input image path/cv_array
    :param vis: visualize the result
    :param base_box: base crop box
    :param dx: shift along x
    :param dy: shift along y
    :param dl: relative expansion of the crop box
    :param return_drawn:
    :return: preds->list [ems_pred, eye_pred, ..., dist_r, dist_l],
             probs->list [[ems_probs], [eye_probs], ..., [1], [1]]
    """
    if isinstance(input_img, str) and os.path.isfile(input_img):
        img = cv2.imread(input_img)
        img_name = os.path.basename(input_img)
    else:
        img = input_img
        img_name = 'image'

    # img = cv2.resize(img, dsize=None, fx=0.5, fy=0.5)
    if base_box is None or dx is None or dy is None or dl is None or \
            abs(dx) >= 0.35 or abs(dy) >= 0.35 or abs(dl) >= 0.1:
        box, score = detect_driver_face(img.copy())
        print(box)
        # if box == [0, 0, 0, 0]:
        #     return
        base_box = expand_box_rate(img, box, rate=expand_rate)
        print(base_box)
    else:
        box = None
        w, h = base_box[2] - base_box[0], base_box[3] - base_box[1]
        assert w == h
        x0, y0, x1, y1 = base_box
        x0, x1 = x0 + int(w * dx), x1 + int(w * dx)
        y0, y1 = y0 + int(h * dy), y1 + int(h * dy)

        expand = int(h * dl / (1 + 2 * dl))
        x0, y0 = x0 + expand, y0 + expand
        x1, y1 = x1 - expand, y1 - expand

        base_box = [x0, y0, x1, y1]

    crop_img = crop_with_pad(img, base_box)
    crop_img = cv2.resize(crop_img, tuple(input_size))
    crop_img = crop_img.astype(np.float32)
    crop_img = (crop_img - 128) / 127.
    crop_img = crop_img.transpose([2, 0, 1])
    crop_img = torch.from_numpy(crop_img).to(device)
    crop_img = crop_img.view(1, *crop_img.size())

    outputs = model(crop_img)
    preds, probs = [], []
    msg_list = []
    for ti, outs in enumerate(outputs):
        if task_types[ti] == 0:
            sub_probs = torch.softmax(outs, dim=1).cpu().detach().numpy()[0]
            sub_pred = np.argmax(sub_probs)
            # msg_list.append(f'{tasks[ti].upper()}: {classes[ti][sub_pred]} {sub_probs[sub_pred]:.3f}')
            msg_list.append(f'{tasks[ti].upper()}: {sub_probs}')
            preds.append(sub_pred)
            probs.append([round(x, 3) for x in sub_probs])
        elif task_types[ti] == 1:
            sub_pred = outs.cpu().detach().item() * (base_box[2] - base_box[0]) * reg_relative_max[ti] / reg_dimension[ti]
            msg_list.append(f'{tasks[ti].upper()}: {sub_pred:.6f}')
            preds.append(sub_pred)
            probs.append([round(sub_pred, 3)])
        elif task_types[ti] in [2, 3]:
            sub_pred = outs.cpu().detach().item() * reg_relative_max[ti] / reg_dimension[ti]
            msg_list.append(f'{tasks[ti].upper()}: {sub_pred:.6f}')
            preds.append(sub_pred)
            probs.append([round(sub_pred, 3)])

    # print('\n'.join(msg_list))

    if vis:
        # drawn = draw_texts(img, msg_list, box, crop_box, use_mask=True)
        # show(drawn, img_name)
        drawn = drawer.draw_result(img, preds, probs, box, base_box)
        # drawn = drawer.draw_ind(img)
        show(drawn, img_name)

    if return_drawn:
        drawn = drawer.draw_result(img, preds, probs, box, base_box)
        return drawn

    return preds, probs, box, base_box, msg_list


def inference_xyl_dir(data_dir, vis):
    img_paths = get_file_paths(data_dir)
    for p in img_paths:
        print(f"\n{os.path.basename(p)}")
        inference_xyl(p, vis=vis)


def inference_with_basebox(input_img, base_box):
    if isinstance(input_img, str) and os.path.isfile(input_img):
        img = cv2.imread(input_img)
    else:
        img = input_img

    crop_img = crop_with_pad(img, base_box)
    crop_img = cv2.resize(crop_img, tuple(input_size))
    crop_img = crop_img.astype(np.float32)
    crop_img = (crop_img - 128) / 127.
    crop_img = crop_img.transpose([2, 0, 1])
    crop_img = torch.from_numpy(crop_img).to(device)
    crop_img = crop_img.view(1, *crop_img.size())

    outputs = model(crop_img)
    preds, probs = [], []
    msg_list = []
    for ti, outs in enumerate(outputs):
        if task_types[ti] == 0:
            sub_probs = torch.softmax(outs, dim=1).cpu().detach().numpy()[0]
            sub_pred = np.argmax(sub_probs)
            # msg_list.append(f'{tasks[ti].upper()}: {classes[ti][sub_pred]} {sub_probs[sub_pred]:.3f}')
            msg_list.append(f'{tasks[ti].upper()}: {sub_probs}')
            preds.append(sub_pred)
            # probs.append([round(x, 3) for x in sub_probs])
            probs += sub_probs.tolist()
        elif task_types[ti] == 1:
            sub_pred = outs.cpu().detach().item() * (base_box[2] - base_box[0]) * reg_relative_max[ti] / reg_dimension[
                ti]
            msg_list.append(f'{tasks[ti].upper()}: {sub_pred:.6f}')
            preds.append(sub_pred)
            # probs.append([round(sub_pred, 3)])
            probs.append(sub_pred)
        elif task_types[ti] in [2, 3]:
            sub_pred = outs.cpu().detach().item() * reg_relative_max[ti] / reg_dimension[ti]
            msg_list.append(f'{tasks[ti].upper()}: {sub_pred:.6f}')
            preds.append(sub_pred)
            # probs.append([round(sub_pred, 3)])
            probs.append(sub_pred)

    return probs


def generate_onnx_result(data_dir, label_file, save_file):
    from utils.labels import load_labels, save_labels
    from utils.multiprogress import MultiThreading

    label_dict = load_labels(label_file)

    def kernel(img_name):
        img_path = os.path.join(data_dir, img_name)
        box = [int(b) for b in label_dict[img_name][:4]]
        probs = inference_with_basebox(img_path, box)
        print(img_name, probs)
        return img_name, probs

    exe = MultiThreading(label_dict.keys(), 8)
    res = exe.run(kernel)
    res_dict = {r[0]: r[1] for r in res}
    save_labels(save_file, res_dict)


if __name__ == '__main__':
    inference_xyl('/Users/didi/Desktop/MTL/dataset/images/20210712062707_09af469be0e74945994d0d6e9e0cbe36_209949.jpg', vis=True)
    # inference_xyl_dir('test_mtl', vis=True)

    # generate_onnx_result('/Users/didi/Desktop/MTL/dataset/images',
    #                      '/Users/didi/Desktop/test_mtl_raw_out/test_mtl_raw_label.txt',
    #                      '/Users/didi/Desktop/test_mtl_raw_out/test_mtl_raw_torch_prob.txt')

    # onnx_export_pipeline(onnx_export_path, export_func=onnx_export, view_net=True, simplify=True)

"""
20210712062707_09af469be0e74945994d0d6e9e0cbe36_209949.jpg
EMS: [1.1826814e-06 5.6046390e-09 9.9999881e-01 6.0338436e-09 5.0259811e-08]
EYE: [9.999689e-01 3.101773e-05 1.121569e-07]
MOUTH: [9.991779e-01 5.185943e-06 8.168342e-04]
GLASS: [9.9998879e-01 1.0724665e-05 4.9968133e-07]
MASK: [9.9702114e-01 2.5327259e-03 4.4608722e-04]
SMOKE: [0.94698274 0.05301724]
PHONE: [9.9996448e-01 3.5488123e-05]
EYELID_R: [1.9863424]
EYELID_L: [2.1488686]
SHIFT_X: [0.00669874]
SHIFT_Y: [-0.00329317]
EXPAND: [-0.07628014]
"""
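After onnx_export, a quick onnxruntime pass can confirm the exported graph produces one output per task head; a sketch assuming onnxruntime is installed (the input/output names follow the torch.onnx.export call above):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(onnx_export_path)
dummy = np.random.randn(1, 3, input_size[1], input_size[0]).astype(np.float32)
outs = sess.run(tasks, {'input': dummy})  # one output tensor per task head
for name, out in zip(tasks, outs):
    print(name, np.asarray(out).shape)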
inference_video_mtl.py
ADDED
@@ -0,0 +1,224 @@
# -*- coding:utf-8 -*-
import pandas as pd

from inference_mtl import *
from utils.os_util import get_file_paths
from utils.plt_util import plot_scores_mtl, syn_plot_scores_mtl, DrawMTL
from utils.time_util import convert_input_time, convert_stamp


FOURCC = 'x264'  # output video codec [h264] for internet


def inference_videos(input_video, save_dir='output/out_vdo', detect_mode='frame', frequency=1.0, continuous=True,
                     show_res=False, plot_score=False, save_score=False, save_vdo=False, save_img=False, img_save_dir=None,
                     syn_plot=True, resize=None, time_delta=1, save_vdo_path=None, save_plt_path=None, save_csv_path=None):
    """
    @param input_video: input video path, either a directory or a single video
    @param save_dir: directory for the output videos
    @param detect_mode: sampling mode, second/frame, i.e. run inference once per second or once per frame
    @param frequency: sampling-frequency multiplier (0-n); larger values mean larger intervals, fractions allowed
    @param show_res: whether to visualize results during inference
    @param continuous: when a directory is given, whether to treat its videos as one continuous stream
    @param plot_score: whether to draw the score chart
    @param save_score: whether to save the inference results as a csv file
    @param save_vdo: whether to save the result video
    @param save_img: whether to save intermediate frames; the save condition is configured separately
    @param img_save_dir: image save directory
    @param syn_plot: whether to draw a synchronized score chart (not supported for occupancy/seat tasks)
    @param resize: output resize
    @param time_delta: display interval; 0 advances frame by key press
    @param save_vdo_path: video save path, defaults to None
    @param save_plt_path: score-chart save path, defaults to None
    @param save_csv_path: per-frame score save path, defaults to None
    @return:
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if save_img:
        img_save_dir = os.path.join(save_dir, 'images') if img_save_dir is None else img_save_dir
        if not os.path.exists(img_save_dir):
            os.makedirs(img_save_dir)
        save_count = 0

    separately_save = True
    if os.path.isfile(input_video):
        vdo_list = [input_video]
    elif os.path.isdir(input_video):
        vdo_list = get_file_paths(input_video, mod='vdo')
        if continuous:
            title = os.path.basename(input_video)
            save_vdo_path = os.path.join(save_dir, title + '.mp4') if save_vdo_path is None else save_vdo_path
            save_plt_path = os.path.join(save_dir, title + '.jpg') if save_plt_path is None else save_plt_path
            save_csv_path = os.path.join(save_dir, title + '.csv') if save_csv_path is None else save_csv_path
            separately_save = False
            frames, seconds = 0, 0
    else:
        print(f'No {input_video}')
        return

    if save_score:
        columns = ['index']
        for ti, task in enumerate(tasks):
            if ti < 7:
                sub_columns = [f"{task}-{sc}" for sc in classes[ti]]
            else:
                sub_columns = [task]
            columns += sub_columns

    if save_vdo and not separately_save:
        video = cv2.VideoCapture(vdo_list[0])
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) if resize is None else resize[0]
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) if resize is None else resize[1]
        fps = video.get(cv2.CAP_PROP_FPS)
        fourcc = cv2.VideoWriter_fourcc(*FOURCC)
        out_video = cv2.VideoWriter(save_vdo_path, fourcc, fps if detect_mode == 'frame' else 5,
                                    (int(width*1.5), height) if syn_plot else (width, height))
        print(f"result video save in '{save_vdo_path}'")

    res_list = []
    for vdo_path in vdo_list:
        vdo_name = os.path.basename(vdo_path)
        try:
            start_time_str = vdo_name.split('_')[2][:14]
            start_time_stamp = convert_input_time(start_time_str, digit=10)
        except:
            start_time_str, start_time_stamp = '', 0

        cap = cv2.VideoCapture(vdo_path)
        cur_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if not cur_frames:
            continue
        fps = cap.get(cv2.CAP_PROP_FPS)
        cur_seconds = int(cur_frames / (fps + 1e-6))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"video:{vdo_name} width:{width} height:{height} fps:{fps:.1f} frames:{cur_frames} "
              f"seconds: {cur_seconds} start_time:{start_time_str}")

        if resize is not None:
            width, height = resize

        cur_res_list = []
        if separately_save:
            title = os.path.splitext(vdo_name)[0]
            save_vdo_path = os.path.join(save_dir, title + '_res.mp4') if save_vdo_path is None else save_vdo_path
            save_plt_path = os.path.join(save_dir, title + '_res.jpg') if save_plt_path is None else save_plt_path
            save_csv_path = os.path.join(save_dir, title + '_res.csv') if save_csv_path is None else save_csv_path
            if save_vdo:
                fourcc = cv2.VideoWriter_fourcc(*FOURCC)
                out_video = cv2.VideoWriter(save_vdo_path, fourcc, fps if detect_mode == 'frame' else 5,
                                            (int(1.5*width), height) if syn_plot else (width, height))
                print(f"result video save in '{save_vdo_path}'")
        else:
            frames += cur_frames
            seconds += cur_seconds

        base_box, dx, dy, dl = None, None, None, None

        step = 1 if detect_mode == 'frame' else fps
        step = max(1, round(step * frequency))
        count = 0
        for i in range(0, cur_frames, step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if not ret:
                # print('video end!')
                break
            count += 1

            if resize is not None and frame.shape[0] != resize[1] and frame.shape[1] != resize[0]:
                frame = cv2.resize(frame, resize)

            cur_res = inference_xyl(frame, vis=False, base_box=base_box, dx=dx, dy=dy, dl=dl)
            if cur_res is None:
                preds = [0] * len(tasks)
                probs = [[0] * len(sub_classes) for sub_classes in classes]
                msg_list = ['no driver']
                base_box, dx, dy, dl = None, None, None, None
            else:
                preds, probs, box, crop_box, msg_list = cur_res
                base_box, dx, dy, dl = crop_box.copy(), preds[-3], preds[-2], preds[-1]

            if start_time_stamp:
                time_stamp = start_time_stamp + round(i/fps)
                cur_time = convert_stamp(time_stamp)
                cur_res_list.append(tuple([cur_time+f'-{i}'] + [round(p, 3) for sub_probs in probs for p in sub_probs]))
                res_list.append(tuple([cur_time+f'-{i}'] + [round(p, 3) for sub_probs in probs for p in sub_probs]))
            else:
                cur_time = ''
                cur_res_list.append(tuple([count] + [round(p, 3) for sub_probs in probs for p in sub_probs]))
                res_list.append(tuple([count] + [round(p, 3) for sub_probs in probs for p in sub_probs]))

            if not count % 10 and count:
                msg = "{} {} => {}".format(i, cur_time, '\t'.join(msg_list))
                print(msg)

            if save_img and probs[1][1] > 0.8:  # Todo: adjust the save condition as needed
                img_name = vdo_name.replace(".mp4", f'_{i}.jpg') if not cur_time else \
                    f"{convert_stamp(convert_input_time(cur_time), '%Y%m%d%H%M%S')}_{i}.jpg"
                img_save_path = os.path.join(img_save_dir, img_name)
                cv2.imwrite(img_save_path, frame)
                save_count += 1

            if show_res or save_vdo:
                drawn = drawer.draw_ind(frame) if cur_res is None else \
                    drawer.draw_result(frame, preds, probs, box, crop_box, use_mask=False, use_frame=False)
                if syn_plot:
                    score_array = np.array([r[1:] for r in res_list])
                    if detect_mode == 'second' and cur_time:
                        indexes = [r[0][-5:] for r in res_list]
                    else:
                        indexes = list(range(len(res_list)))
                    window_length = 300 if detect_mode == 'frame' else 30
                    assert len(score_array) == len(indexes)
                    score_chart = syn_plot_scores_mtl(
                        tasks, [[0, 1, 2, 3, 4, 5, 6], [7, 8], [9, 10], [11]], classes, indexes, score_array,
                        int(0.5*width), height, window_length, width/1280)

                    drawn = np.concatenate([drawn, score_chart], axis=1)

                if show_res:
                    cv2.namedWindow(title, 0)
                    # cv2.moveWindow(title, 0, 0)
                    # cv2.setWindowProperty(title, cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
                    cv2.imshow(title, drawn)
                    cv2.waitKey(time_delta)

                # write the frame after processing
                if save_vdo:
                    out_video.write(drawn)
+
|
| 192 |
+
if separately_save:
|
| 193 |
+
if show_res:
|
| 194 |
+
cv2.destroyWindow(title)
|
| 195 |
+
if plot_score:
|
| 196 |
+
res = np.array([r[1:] for r in cur_res_list])
|
| 197 |
+
plot_scores_mtl(tasks, task_types, classes, res, title, detect_mode, save_dir=save_dir,
|
| 198 |
+
save_path=save_plt_path, show=show_res)
|
| 199 |
+
if save_score:
|
| 200 |
+
df = pd.DataFrame(cur_res_list, columns=columns)
|
| 201 |
+
df.to_csv(save_csv_path, index=False, float_format='%.3f')
|
| 202 |
+
|
| 203 |
+
if save_img:
|
| 204 |
+
print(f"total save {save_count} images")
|
| 205 |
+
|
| 206 |
+
if not separately_save:
|
| 207 |
+
if show_res:
|
| 208 |
+
cv2.destroyWindow(title)
|
| 209 |
+
if plot_score:
|
| 210 |
+
res = np.array([r[1:] for r in res_list])
|
| 211 |
+
plot_scores_mtl(tasks, task_types, classes, res, title, detect_mode, save_dir=save_dir,
|
| 212 |
+
save_path=save_plt_path, show=show_res)
|
| 213 |
+
if save_score:
|
| 214 |
+
df = pd.DataFrame(res_list, columns=columns)
|
| 215 |
+
df.to_csv(save_csv_path, index=False, float_format='%.3f')
|
| 216 |
+
|
| 217 |
+
return save_vdo_path, save_plt_path, save_vdo_path
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
if __name__ == '__main__':
|
| 221 |
+
inference_videos('/Users/didi/Desktop/CARVIDEO_03afqwj7uw5d801e_20230708132305000_20230708132325000.mp4',
|
| 222 |
+
save_dir='/Users/didi/Desktop/error_res_v1.0',
|
| 223 |
+
detect_mode='second', frequency=0.2, plot_score=True, save_score=True, syn_plot=True,
|
| 224 |
+
save_vdo=True, save_img=False, continuous=False, show_res=True, resize=(1280, 720), time_delta=1)
|
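
inference_videos also accepts a directory of clips. A minimal batch-mode sketch (not part of the commit; it assumes the module-level model/drawer setup in inference_video_mtl.py runs on import, and the paths below are placeholders):

from inference_video_mtl import inference_videos

# Score every clip under a folder; with continuous=True the per-clip results
# are merged into one result plot/csv named after the folder.
inference_videos('/path/to/clips', save_dir='/path/to/out', detect_mode='second',
                 frequency=0.2, continuous=True, plot_score=True, save_score=True,
                 save_vdo=False, save_img=False, show_res=False)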
models/__pycache__/shufflenet2_att_m.cpython-38.pyc
ADDED
Binary file (6.68 kB)

models/module/__pycache__/activation.cpython-38.pyc
ADDED
Binary file (563 Bytes)

models/module/__pycache__/conv.cpython-38.pyc
ADDED
Binary file (8.6 kB)

models/module/__pycache__/init_weights.cpython-38.pyc
ADDED
Binary file (1.53 kB)

models/module/__pycache__/norm.cpython-38.pyc
ADDED
Binary file (1.56 kB)

models/module/activation.py
ADDED
@@ -0,0 +1,17 @@
import torch.nn as nn

activations = {'ReLU': nn.ReLU,
               'LeakyReLU': nn.LeakyReLU,
               'ReLU6': nn.ReLU6,
               'SELU': nn.SELU,
               'ELU': nn.ELU,
               None: nn.Identity
               }


def act_layers(name):
    assert name in activations.keys()
    if name == 'LeakyReLU':
        return nn.LeakyReLU(negative_slope=0.1, inplace=True)
    else:
        return activations[name](inplace=True)
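
A quick illustration of the registry above (a sketch, not part of the commit):

from models.module.activation import act_layers

relu = act_layers('ReLU')        # nn.ReLU(inplace=True)
leaky = act_layers('LeakyReLU')  # nn.LeakyReLU(negative_slope=0.1, inplace=True)
noop = act_layers(None)          # nn.Identity, a no-op placeholder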
models/module/blocks.py
ADDED
@@ -0,0 +1,300 @@
# -*- coding:utf-8 –*-
import torch
import torch.nn as nn
from models.attention.attention_blocks import *
import torch.nn.functional as F
from utils.common import log_warn


# Optionally adds an attention block after each shuffle block
class ShuffleV2Block(nn.Module):
    def __init__(self, inp, oup, mid_channels, ksize, stride, attention='', ratio=16, loc='side', onnx=False):
        super(ShuffleV2Block, self).__init__()
        self.onnx = onnx
        self.stride = stride
        assert stride in [1, 2]

        self.mid_channels = mid_channels
        self.ksize = ksize
        pad = ksize // 2
        self.pad = pad
        self.inp = inp

        outputs = oup - inp

        branch_main = [
            # pw
            nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # dw
            nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),
            nn.BatchNorm2d(mid_channels),
            # pw-linear
            nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
            nn.BatchNorm2d(outputs),
            nn.ReLU(inplace=True),
        ]
        self.branch_main = nn.Sequential(*branch_main)

        if stride == 2:
            branch_proj = [
                # dw
                nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                # pw-linear
                nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
            ]
            self.branch_proj = nn.Sequential(*branch_proj)
        else:
            self.branch_proj = None

        if attention:
            self.loc = loc
            att_out = outputs if loc == 'side' else oup
            if attention.lower() == 'se':
                self.att_block = SELayer(att_out, reduction=ratio)
            elif attention.lower() == 'cbam':
                self.att_block = CBAM(att_out, ratio)
            elif attention.lower() == 'gc':
                self.att_block = GCBlock(att_out, ratio=ratio)
            else:
                raise NotImplementedError
        else:
            self.att_block = None

    def forward(self, old_x):
        if self.stride == 1:
            x_proj, x = self.channel_shuffle(old_x)
        else:
            x_proj = old_x
            x_proj = self.branch_proj(x_proj)
            x = old_x
        x = self.branch_main(x)
        if self.att_block and self.loc == 'side':
            x = self.att_block(x)
        x = torch.cat((x_proj, x), 1)
        if self.att_block and self.loc == 'after':
            x = self.att_block(x)
        return x

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = x.data.size()
        if self.onnx:
            # The ONNX model is later converted to an IFX model; the IFX engine stores
            # data in NCHW with n=1, so keep shape transforms in NCHW (n=1) where possible.
            x = x.reshape(1, batchsize * num_channels // 2, 2, height * width)
            x = x.permute(0, 2, 1, 3)
            z = num_channels // 2
            x = x.reshape(1, -1, height, width)
            # Avoid x[0]/x[1] indexing for the split; prefer torch operators
            x1, x2 = torch.split(x, split_size_or_sections=[z, z], dim=1)
            return x1, x2
        else:
            x = x.reshape(batchsize * num_channels // 2, 2, height * width)
            x = x.permute(1, 0, 2)
            x = x.reshape(2, -1, num_channels // 2, height, width)
            return x[0], x[1]


# Quantization-friendly variant of the shuffle block
class QuantizableShuffleV2Block(ShuffleV2Block):

    def __init__(self, *args, **kwargs):
        if kwargs.get('attention', ''):
            log_warn('Quantizable model does not support attention blocks')
            kwargs['attention'] = ''
        super(QuantizableShuffleV2Block, self).__init__(*args, **kwargs)
        self.quantized_funcs = nn.quantized.FloatFunctional()

    def forward(self, old_x):
        if self.branch_proj is None:
            x_proj, x = self.channel_shuffle(old_x)
        else:
            x_proj = old_x
            x_proj = self.branch_proj(x_proj)
            x = old_x
        x = self.branch_main(x)
        x = self.quantized_funcs.cat((x_proj, x), 1)
        return x

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = x.data.size()
        x = x.reshape(batchsize * num_channels // 2, 2, height * width)
        x = x.permute(1, 0, 2)
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]


class ShuffleV2BlockSK(nn.Module):
    def __init__(self, inp, oup, mid_channels, *, ksize, stride):
        super(ShuffleV2BlockSK, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.mid_channels = mid_channels
        self.ksize = ksize
        pad = ksize // 2
        self.pad = pad
        self.inp = inp

        outputs = oup - inp

        branch_main = [
            # pw
            nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # dw
            # nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),
            SKConv(mid_channels, 2, mid_channels, stride=stride, use_relu=False),
            # SKConv2(mid_channels, 2, mid_channels, 4, stride=stride),
            nn.BatchNorm2d(mid_channels),
            # pw-linear
            nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
            nn.BatchNorm2d(outputs),
            nn.ReLU(inplace=True),
        ]
        self.branch_main = nn.Sequential(*branch_main)

        if stride == 2:
            branch_proj = [
                # dw
                nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
                # SKConv(inp, 2, inp, stride=stride, use_relu=False),
                # SKConv2(inp, 2, inp, 4, stride=stride),
                nn.BatchNorm2d(inp),
                # pw-linear
                nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
            ]
            self.branch_proj = nn.Sequential(*branch_proj)
        else:
            self.branch_proj = None

    def forward(self, old_x):
        if self.stride == 1:
            x_proj, x = self.channel_shuffle(old_x)
            return torch.cat((x_proj, self.branch_main(x)), 1)
        elif self.stride == 2:
            x_proj = old_x
            x = old_x
            return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1)

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = x.data.size()
        assert (num_channels % 4 == 0)
        x = x.reshape(batchsize * num_channels // 2, 2, height * width)
        x = x.permute(1, 0, 2)
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]


class SnetBlock(ShuffleV2Block):
    """
    Custom channel_shuffle; everything else is identical to ShuffleV2Block.
    """

    def channel_shuffle(self, x):
        g = 2
        x = x.reshape(x.shape[0], g, x.shape[1] // g, x.shape[2], x.shape[3])
        x = x.permute(0, 2, 1, 3, 4)
        x = x.reshape(x.shape[0], -1, x.shape[3], x.shape[4])
        x_proj = x[:, :(x.shape[1] // 2), :, :]
        x = x[:, (x.shape[1] // 2):, :, :]
        return x_proj, x


def conv_bn_relu(inp, oup, kernel_size, stride, pad):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, pad, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True)
    )


class CEM(nn.Module):
    """
    Context Enhancement Module
    A modified FPN: c4, c5 and the global feature each go through a 1x1 conv
    (c5 is upsampled, the global feature is broadcast), then the results are summed.
    Supports feat_stride 8 and 16.
    TODO: apply one more conv at the end
    """

    def __init__(self, in_channels1, in_channels2, in_channels3,
                 feat_stride=16, squeeze_channels=245, use_relu=False):
        super(CEM, self).__init__()
        self.feat_stride = feat_stride
        assert feat_stride in [8, 16], f"{feat_stride} not supported, select from [8, 16]"
        if feat_stride == 8:
            self.conv3 = nn.Conv2d(in_channels1 // 2, squeeze_channels, 1, bias=True)
        self.conv4 = nn.Conv2d(in_channels1, squeeze_channels, 1, bias=True)
        self.conv5 = nn.Conv2d(in_channels2, squeeze_channels, 1, bias=True)
        self.conv_last = nn.Conv2d(in_channels3, squeeze_channels, 1, bias=True)
        self.use_relu = use_relu
        if use_relu:
            self.relu7 = nn.ReLU(inplace=True)

    def forward(self, inputs):
        if self.feat_stride == 8:
            c3_lat = self.conv3(inputs[0])
            c4_lat = self.conv4(inputs[1])
            c4_lat = F.interpolate(c4_lat, size=[c3_lat.size(2), c3_lat.size(3)], mode="nearest")
            c5_lat = self.conv5(inputs[2])
            c5_lat = F.interpolate(c5_lat, size=[c3_lat.size(2), c3_lat.size(3)], mode="nearest")
            glb_lat = self.conv_last(inputs[3])
            out = c3_lat + c4_lat + c5_lat + glb_lat
        else:
            c4_lat = self.conv4(inputs[0])
            c5_lat = self.conv5(inputs[1])
            c5_lat = F.interpolate(c5_lat, size=[c4_lat.size(2), c4_lat.size(3)], mode="nearest")  # upsample
            glb_lat = self.conv_last(inputs[2])
            out = c4_lat + c5_lat + glb_lat

        if self.use_relu:
            out = self.relu7(out)
        return out


class CEM_a(nn.Module):
    """
    Context Enhancement Module (currently identical to CEM)
    A modified FPN: c4, c5 and the global feature each go through a 1x1 conv
    (c5 is upsampled, the global feature is broadcast), then the results are summed.
    Supports feat_stride 8 and 16.
    TODO: apply one more conv at the end
    """

    def __init__(self, in_channels1, in_channels2, in_channels3,
                 feat_stride=16, squeeze_channels=245, use_relu=False):
        super(CEM_a, self).__init__()
        self.feat_stride = feat_stride
        assert feat_stride in [8, 16], f"{feat_stride} not supported, select from [8, 16]"
        if feat_stride == 8:
            self.conv3 = nn.Conv2d(in_channels1 // 2, squeeze_channels, 1, bias=True)
        self.conv4 = nn.Conv2d(in_channels1, squeeze_channels, 1, bias=True)
        self.conv5 = nn.Conv2d(in_channels2, squeeze_channels, 1, bias=True)
        self.conv_last = nn.Conv2d(in_channels3, squeeze_channels, 1, bias=True)
        self.use_relu = use_relu
        if use_relu:
            self.relu7 = nn.ReLU(inplace=True)

    def forward(self, inputs):
        if self.feat_stride == 8:
            c3_lat = self.conv3(inputs[0])
            c4_lat = self.conv4(inputs[1])
            c4_lat = F.interpolate(c4_lat, size=[c3_lat.size(2), c3_lat.size(3)], mode="nearest")
            c5_lat = self.conv5(inputs[2])
            c5_lat = F.interpolate(c5_lat, size=[c3_lat.size(2), c3_lat.size(3)], mode="nearest")
            glb_lat = self.conv_last(inputs[3])
            out = c3_lat + c4_lat + c5_lat + glb_lat
        else:
            c4_lat = self.conv4(inputs[0])
            c5_lat = self.conv5(inputs[1])
            c5_lat = F.interpolate(c5_lat, size=[c4_lat.size(2), c4_lat.size(3)], mode="nearest")  # upsample
            glb_lat = self.conv_last(inputs[2])
            out = c4_lat + c5_lat + glb_lat

        if self.use_relu:
            out = self.relu7(out)
        return out
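
A shape sanity check for the shuffle block (a sketch, not part of the commit; it assumes the repo root is on PYTHONPATH and the models.attention package star-imported at the top of blocks.py is present):

import torch
from models.module.blocks import ShuffleV2Block

x = torch.randn(1, 24, 40, 40)
down = ShuffleV2Block(24, 48, mid_channels=24, ksize=3, stride=2)  # downsampling block
keep = ShuffleV2Block(24, 48, mid_channels=24, ksize=3, stride=1)  # splits channels in half internally
y = down(x)   # -> (1, 48, 20, 20)
z = keep(y)   # -> (1, 48, 20, 20)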
models/module/conv.py
ADDED
@@ -0,0 +1,340 @@
"""
ConvModule is adapted from MMDetection.
RepVGGConvModule is adapted from RepVGG: Making VGG-style ConvNets Great Again.
"""
import warnings

import numpy as np
import torch
import torch.nn as nn

from models.module.activation import act_layers
from models.module.init_weights import kaiming_init, constant_init
from models.module.norm import build_norm_layer


class ConvModule(nn.Module):
    """A conv block that contains conv/norm/activation layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
        conv_cfg (dict): Config dict for convolution layer.
        norm_cfg (dict): Config dict for normalization layer.
        activation (str): activation layer, "ReLU" by default.
        inplace (bool): Whether to use inplace mode for activation.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 activation='ReLU',
                 inplace=True,
                 order=('conv', 'norm', 'act')):
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        assert activation is None or isinstance(activation, str)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.activation = activation
        self.inplace = inplace
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_cfg is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = False if self.with_norm else True
        self.with_bias = bias

        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')

        # build convolution layer
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = self.conv.padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
            self.add_module(self.norm_name, norm)

        # build activation layer
        if self.activation:
            self.act = act_layers(self.activation)

        # Use msra init by default
        self.init_weights()

    @property
    def norm(self):
        return getattr(self, self.norm_name)

    def init_weights(self):
        if self.activation == 'LeakyReLU':
            nonlinearity = 'leaky_relu'
        else:
            nonlinearity = 'relu'
        kaiming_init(self.conv, nonlinearity=nonlinearity)
        if self.with_norm:
            constant_init(self.norm, 1, bias=0)

    def forward(self, x, norm=True):
        for layer in self.order:
            if layer == 'conv':
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and self.activation:
                x = self.act(x)
        return x


class DepthwiseConvModule(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias='auto',
                 norm_cfg=dict(type='BN'),
                 activation='ReLU',
                 inplace=True,
                 order=('depthwise', 'dwnorm', 'act', 'pointwise', 'pwnorm', 'act')):
        super(DepthwiseConvModule, self).__init__()
        assert activation is None or isinstance(activation, str)
        self.activation = activation
        self.inplace = inplace
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 6
        assert set(order) == set(['depthwise', 'dwnorm', 'act', 'pointwise', 'pwnorm', 'act'])

        self.with_norm = norm_cfg is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = False if self.with_norm else True
        self.with_bias = bias

        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')

        # build convolution layer
        self.depthwise = nn.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size,
                                   stride=stride,
                                   padding=padding,
                                   dilation=dilation,
                                   groups=in_channels,
                                   bias=bias)
        self.pointwise = nn.Conv2d(in_channels,
                                   out_channels,
                                   kernel_size=1,
                                   stride=1,
                                   padding=0,
                                   bias=bias)

        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.depthwise.in_channels
        self.out_channels = self.pointwise.out_channels
        self.kernel_size = self.depthwise.kernel_size
        self.stride = self.depthwise.stride
        self.padding = self.depthwise.padding
        self.dilation = self.depthwise.dilation
        self.transposed = self.depthwise.transposed
        self.output_padding = self.depthwise.output_padding

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            _, self.dwnorm = build_norm_layer(norm_cfg, in_channels)
            _, self.pwnorm = build_norm_layer(norm_cfg, out_channels)

        # build activation layer
        if self.activation:
            self.act = act_layers(self.activation)

        # Use msra init by default
        self.init_weights()

    def init_weights(self):
        if self.activation == 'LeakyReLU':
            nonlinearity = 'leaky_relu'
        else:
            nonlinearity = 'relu'
        kaiming_init(self.depthwise, nonlinearity=nonlinearity)
        kaiming_init(self.pointwise, nonlinearity=nonlinearity)
        if self.with_norm:
            constant_init(self.dwnorm, 1, bias=0)
            constant_init(self.pwnorm, 1, bias=0)

    def forward(self, x, norm=True):
        for layer_name in self.order:
            if layer_name != 'act':
                layer = self.__getattr__(layer_name)
                x = layer(x)
            elif layer_name == 'act' and self.activation:
                x = self.act(x)
        return x


class RepVGGConvModule(nn.Module):
    """
    RepVGG Conv Block from paper RepVGG: Making VGG-style ConvNets Great Again
    https://arxiv.org/abs/2101.03697
    https://github.com/DingXiaoH/RepVGG
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 activation='ReLU',
                 padding_mode='zeros',
                 deploy=False):
        super(RepVGGConvModule, self).__init__()
        assert activation is None or isinstance(activation, str)
        self.activation = activation

        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        assert kernel_size == 3
        assert padding == 1

        padding_11 = padding - kernel_size // 2

        # build activation layer
        if self.activation:
            self.act = act_layers(self.activation)

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=True,
                                         padding_mode=padding_mode)

        else:
            self.rbr_identity = nn.BatchNorm2d(
                num_features=in_channels) if out_channels == in_channels and stride == 1 else None

            self.rbr_dense = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                                     kernel_size=kernel_size, stride=stride, padding=padding,
                                                     groups=groups, bias=False),
                                           nn.BatchNorm2d(num_features=out_channels))

            self.rbr_1x1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                                   kernel_size=1, stride=stride, padding=padding_11,
                                                   groups=groups, bias=False),
                                         nn.BatchNorm2d(num_features=out_channels))
            print('RepVGG Block, identity = ', self.rbr_identity)

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.act(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
    # You can get the equivalent kernel and bias at any time and do whatever you want,
    # for example, apply some penalties or constraints during training, just like you do to the other models.
    # May be useful for quantization or pruning.
    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch[0].weight
            running_mean = branch[1].running_mean
            running_var = branch[1].running_var
            gamma = branch[1].weight
            beta = branch[1].bias
            eps = branch[1].eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def repvgg_convert(self):
        kernel, bias = self.get_equivalent_kernel_bias()
        return kernel.detach().cpu().numpy(), bias.detach().cpu().numpy(),
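
A minimal usage sketch for the two conv wrappers (not part of the commit; assumes the repo root is on PYTHONPATH):

import torch
from models.module.conv import ConvModule, DepthwiseConvModule

conv = ConvModule(16, 32, 3, padding=1, norm_cfg=dict(type='BN'), activation='ReLU')
dw = DepthwiseConvModule(32, 64, 3, padding=1)  # depthwise 3x3 + pointwise 1x1, BN + ReLU after each
y = dw(conv(torch.randn(1, 16, 40, 40)))        # -> (1, 64, 40, 40)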
models/module/fpn.py
ADDED
@@ -0,0 +1,165 @@
# -*- coding:utf-8 –*-
"""
from MMDetection
"""

import torch.nn as nn
import torch.nn.functional as F
from models.module.conv import ConvModule
from models.module.init_weights import xavier_init


class FPN(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 conv_cfg=None,
                 norm_cfg=None,
                 activation=None
                 ):
        super(FPN, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.fp16_enabled = False

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level < inputs, no extra level is allowed
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.lateral_convs = nn.ModuleList()

        for i in range(self.start_level, self.backbone_end_level):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                activation=activation,
                inplace=False)

            self.lateral_convs.append(l_conv)
        self.init_weights()

    # default init_weights for conv(msra) and norm in ConvModule
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] += F.interpolate(
                laterals[i], size=prev_shape, mode='bilinear', align_corners=False)

        # build outputs
        outs = [
            # self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
            laterals[i] for i in range(used_backbone_levels)
        ]
        return tuple(outs)


class PAN(FPN):
    """Path Aggregation Network for Instance Segmentation.

    This is an implementation of the `PAN in Path Aggregation Network
    <https://arxiv.org/abs/1803.01534>`_.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale)
        num_outs (int): Number of output scales.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        add_extra_convs (bool): Whether to add conv layers on top of the
            original feature maps. Default: False.
        extra_convs_on_inputs (bool): Whether to apply extra conv on
            the original feature from the backbone. Default: False.
        relu_before_extra_convs (bool): Whether to apply relu before the extra
            conv. Default: False.
        no_norm_on_lateral (bool): Whether to apply norm on lateral.
            Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (str): Config dict for activation layer in ConvModule.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 conv_cfg=None,
                 norm_cfg=None,
                 activation=None):
        super(PAN, self).__init__(in_channels, out_channels, num_outs, start_level,
                                  end_level, conv_cfg, norm_cfg, activation)
        self.init_weights()

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] += F.interpolate(
                laterals[i], size=prev_shape, mode='bilinear', align_corners=False)

        # build outputs
        # part 1: from original levels
        inter_outs = [
            laterals[i] for i in range(used_backbone_levels)
        ]

        # part 2: add bottom-up path
        for i in range(0, used_backbone_levels - 1):
            prev_shape = inter_outs[i + 1].shape[2:]
            inter_outs[i + 1] += F.interpolate(inter_outs[i], size=prev_shape, mode='bilinear', align_corners=False)

        outs = []
        outs.append(inter_outs[0])
        outs.extend([
            inter_outs[i] for i in range(1, used_backbone_levels)
        ])
        return tuple(outs)


# if __name__ == '__main__':
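
A shape sanity check for the lateral-only FPN above (a sketch, not part of the commit; the channel/resolution values are arbitrary):

import torch
from models.module.fpn import FPN

fpn = FPN(in_channels=[48, 96, 192], out_channels=96, num_outs=3)
feats = [torch.randn(1, c, s, s) for c, s in [(48, 40), (96, 20), (192, 10)]]
outs = fpn(feats)  # three maps, each with 96 channels, at 40/20/10 resolution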
models/module/init_weights.py
ADDED
@@ -0,0 +1,44 @@
"""
from MMDetection
"""
import torch.nn as nn


def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def xavier_init(module, gain=1, bias=0, distribution='normal'):
    assert distribution in ['uniform', 'normal']
    if distribution == 'uniform':
        nn.init.xavier_uniform_(module.weight, gain=gain)
    else:
        nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def normal_init(module, mean=0, std=1, bias=0):
    nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
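
A minimal sketch of how these helpers are typically applied (not part of the commit):

import torch.nn as nn
from models.module.init_weights import kaiming_init, constant_init, xavier_init

conv = nn.Conv2d(3, 16, 3, padding=1)
kaiming_init(conv, nonlinearity='relu')   # MSRA init for conv weights
bn = nn.BatchNorm2d(16)
constant_init(bn, 1, bias=0)              # gamma = 1, beta = 0
fc = nn.Linear(16, 4)
xavier_init(fc, distribution='uniform')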
models/module/norm.py
ADDED
@@ -0,0 +1,55 @@
import torch.nn as nn

norm_cfg = {
    # format: layer_type: (abbreviation, module)
    'BN': ('bn', nn.BatchNorm2d),
    'SyncBN': ('bn', nn.SyncBatchNorm),
    'GN': ('gn', nn.GroupNorm),
    # and potentially 'SN'
}


def build_norm_layer(cfg, num_features, postfix=''):
    """ Build normalization layer

    Args:
        cfg (dict): cfg should contain:
            type (str): identify norm layer type.
            layer args: args needed to instantiate a norm layer.
            requires_grad (bool): [optional] whether to stop gradient updates
        num_features (int): number of channels from input.
        postfix (int, str): appended into norm abbreviation to
            create named layer.

    Returns:
        name (str): abbreviation + postfix
        layer (nn.Module): created norm layer
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
    else:
        abbr, norm_layer = norm_cfg[layer_type]
        if norm_layer is None:
            raise NotImplementedError

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    requires_grad = cfg_.pop('requires_grad', True)
    cfg_.setdefault('eps', 1e-5)
    if layer_type != 'GN':
        layer = norm_layer(num_features, **cfg_)
        if layer_type == 'SyncBN':
            layer._specify_ddp_gpu_num(1)
    else:
        assert 'num_groups' in cfg_
        layer = norm_layer(num_channels=num_features, **cfg_)

    for param in layer.parameters():
        param.requires_grad = requires_grad

    return name, layer
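
A minimal sketch of the config-dict interface (not part of the commit):

from models.module.norm import build_norm_layer

name, bn = build_norm_layer(dict(type='BN'), 64)                 # -> ('bn', BatchNorm2d(64, eps=1e-05))
name, gn = build_norm_layer(dict(type='GN', num_groups=8), 64)   # GN requires num_groups in the cfg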
models/shufflenet2_att_m.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding:utf-8 –*-
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from models.module.conv import DepthwiseConvModule
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ShuffleV2Block(nn.Module):
|
| 8 |
+
def __init__(self, inp, oup, mid_channels, ksize, stride, attention='', ratio=16, loc='side', onnx=False):
|
| 9 |
+
super(ShuffleV2Block, self).__init__()
|
| 10 |
+
self.onnx = onnx
|
| 11 |
+
self.stride = stride
|
| 12 |
+
assert stride in [1, 2]
|
| 13 |
+
|
| 14 |
+
self.mid_channels = mid_channels
|
| 15 |
+
self.ksize = ksize
|
| 16 |
+
pad = ksize // 2
|
| 17 |
+
self.pad = pad
|
| 18 |
+
self.inp = inp
|
| 19 |
+
|
| 20 |
+
outputs = oup - inp
|
| 21 |
+
|
| 22 |
+
branch_main = [
|
| 23 |
+
# pw
|
| 24 |
+
nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
|
| 25 |
+
nn.BatchNorm2d(mid_channels),
|
| 26 |
+
nn.ReLU(inplace=True),
|
| 27 |
+
# dw
|
| 28 |
+
nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),
|
| 29 |
+
nn.BatchNorm2d(mid_channels),
|
| 30 |
+
# pw-linear
|
| 31 |
+
nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
|
| 32 |
+
nn.BatchNorm2d(outputs),
|
| 33 |
+
nn.ReLU(inplace=True),
|
| 34 |
+
]
|
| 35 |
+
self.branch_main = nn.Sequential(*branch_main)
|
| 36 |
+
|
| 37 |
+
if stride == 2:
|
| 38 |
+
branch_proj = [
|
| 39 |
+
# dw
|
| 40 |
+
nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
|
| 41 |
+
nn.BatchNorm2d(inp),
|
| 42 |
+
# pw-linear
|
| 43 |
+
nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
|
| 44 |
+
nn.BatchNorm2d(inp),
|
| 45 |
+
nn.ReLU(inplace=True),
|
| 46 |
+
]
|
| 47 |
+
self.branch_proj = nn.Sequential(*branch_proj)
|
| 48 |
+
else:
|
| 49 |
+
self.branch_proj = None
|
| 50 |
+
|
| 51 |
+
def forward(self, old_x):
|
| 52 |
+
if self.stride == 1:
|
| 53 |
+
x_proj, x = self.channel_shuffle(old_x)
|
| 54 |
+
else:
|
| 55 |
+
x_proj = old_x
|
| 56 |
+
x_proj = self.branch_proj(x_proj)
|
| 57 |
+
x = old_x
|
| 58 |
+
x = self.branch_main(x)
|
| 59 |
+
x = torch.cat((x_proj, x), 1)
|
| 60 |
+
return x
|
| 61 |
+
|
| 62 |
+
def channel_shuffle(self, x):
|
| 63 |
+
batchsize, num_channels, height, width = x.data.size()
|
| 64 |
+
if self.onnx:
|
| 65 |
+
# 由于需要将onnx模型转换为ifx模型,ifx引擎以nchw(n=1)的格式存储数据,因此做shape变换时,尽量保证按nchw(n=1)来操作
|
| 66 |
+
x = x.reshape(1, batchsize * num_channels // 2, 2, height * width)
|
| 67 |
+
x = x.permute(0, 2, 1, 3)
|
| 68 |
+
z = num_channels // 2
|
| 69 |
+
x = x.reshape(1, -1, height, width)
|
| 70 |
+
# split时避免使用x[0]、x[1]的操作,尽量使用torch的算子来实现
|
| 71 |
+
x1, x2 = torch.split(x, split_size_or_sections=[z, z], dim=1)
|
| 72 |
+
return x1, x2
|
| 73 |
+
else:
|
| 74 |
+
x = x.reshape(batchsize * num_channels // 2, 2, height * width)
|
| 75 |
+
x = x.permute(1, 0, 2)
|
| 76 |
+
x = x.reshape(2, -1, num_channels // 2, height, width)
|
| 77 |
+
return x[0], x[1]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ShuffleNetV2(nn.Module):
|
| 81 |
+
def __init__(self, num_tasks=0, task_types=0, num_classes=[], out_channel=1024, model_size='0.5x', with_last_conv=True,
|
| 82 |
+
attention='', loc='side', onnx=False, shuffle_block=None,
|
| 83 |
+
stack_lite_head=0, lite_head_channels=-1):
|
| 84 |
+
super(ShuffleNetV2, self).__init__()
|
| 85 |
+
# print('model size is ', model_size)
|
| 86 |
+
assert len(num_classes) == num_tasks, f"num task must equal to length of classes list for every task"
|
| 87 |
+
|
| 88 |
+
self.num_tasks = num_tasks
|
| 89 |
+
self.use_last_conv = with_last_conv
|
| 90 |
+
if isinstance(task_types, int):
|
| 91 |
+
task_types = [task_types] * num_tasks
|
| 92 |
+
if isinstance(stack_lite_head, int):
|
| 93 |
+
stack_lite_head = [stack_lite_head] * num_tasks
|
| 94 |
+
if isinstance(lite_head_channels, int):
|
| 95 |
+
lite_head_channels = [lite_head_channels] * num_tasks
|
| 96 |
+
self.task_types = task_types
|
| 97 |
+
self.stack_lite_head = stack_lite_head
|
| 98 |
+
self.lite_head_channels = lite_head_channels
|
| 99 |
+
self.onnx = onnx
|
| 100 |
+
self.stage_repeats = [4, 8, 4]
|
| 101 |
+
self.model_size = model_size
|
| 102 |
+
if model_size == '0.5x':
|
| 103 |
+
self.stage_out_channels = [-1, 24, 48, 96, 192] + [out_channel]
|
| 104 |
+
elif model_size == '1.0x':
|
| 105 |
+
self.stage_out_channels = [-1, 24, 116, 232, 464] + [out_channel]
|
| 106 |
+
elif model_size == '1.5x':
|
| 107 |
+
self.stage_out_channels = [-1, 24, 176, 352, 704] + [out_channel]
|
| 108 |
+
elif model_size == '2.0x':
|
| 109 |
+
self.stage_out_channels = [-1, 24, 244, 488, 976] + [out_channel]
|
| 110 |
+
else:
|
| 111 |
+
raise NotImplementedError
|
| 112 |
+
|
| 113 |
+
shuffle_block = ShuffleV2Block if shuffle_block is None else shuffle_block
|
| 114 |
+
# building first layer
|
| 115 |
+
input_channel = self.stage_out_channels[1]
|
| 116 |
+
self.first_conv = nn.Sequential(
|
| 117 |
+
nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),
|
| 118 |
+
nn.BatchNorm2d(input_channel),
|
| 119 |
+
nn.ReLU(inplace=True),
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 123 |
+
|
| 124 |
+
self.features = []
|
| 125 |
+
for idxstage in range(len(self.stage_repeats)):
|
| 126 |
+
numrepeat = self.stage_repeats[idxstage]
|
| 127 |
+
output_channel = self.stage_out_channels[idxstage + 2]
|
| 128 |
+
|
| 129 |
+
for i in range(numrepeat):
|
| 130 |
+
if i == 0:
|
| 131 |
+
self.features.append(
|
| 132 |
+
shuffle_block(input_channel, output_channel, mid_channels=output_channel // 2, ksize=3,
|
| 133 |
+
stride=2, attention=attention, ratio=16, loc=loc, onnx=onnx))
|
| 134 |
+
else:
|
| 135 |
+
self.features.append(
|
| 136 |
+
shuffle_block(input_channel//2, output_channel, mid_channels=output_channel//2, ksize=3,
|
| 137 |
+
stride=1, attention=attention, ratio=16, loc=loc, onnx=onnx))
|
| 138 |
+
|
| 139 |
+
input_channel = output_channel
|
| 140 |
+
|
| 141 |
+
self.features = nn.Sequential(*self.features)
|
| 142 |
+
|
| 143 |
+
self.classifier_inchannels = self.stage_out_channels[-2]
|
| 144 |
+
if with_last_conv:
|
| 145 |
+
self.classifier_inchannels = self.stage_out_channels[-1]
|
| 146 |
+
self.conv_last = nn.Sequential(
|
| 147 |
+
nn.Conv2d(input_channel, self.classifier_inchannels, 1, 1, 0, bias=False),
|
| 148 |
+
nn.BatchNorm2d(self.stage_out_channels[-1]),
|
| 149 |
+
nn.ReLU(inplace=True))
|
| 150 |
+
|
| 151 |
+
self.globalpool = nn.AdaptiveAvgPool2d(output_size=1)
|
| 152 |
+
|
| 153 |
+
self.lite_head_channels = [self.classifier_inchannels if v == -1 else v for v in self.lite_head_channels]
|
| 154 |
+
for ti in range(self.num_tasks):
|
| 155 |
+
if self.stack_lite_head[ti]:
|
| 156 |
+
lite_head = []
|
| 157 |
+
for j in range(self.stack_lite_head[ti]):
|
| 158 |
+
ins = self.classifier_inchannels if j == 0 else self.lite_head_channels[ti]
|
| 159 |
+
outs = self.classifier_inchannels if j == self.stack_lite_head[ti]-1 else self.lite_head_channels[ti]
|
| 160 |
+
lite_head.append(DepthwiseConvModule(ins, outs, 3, 1, 1))
|
| 161 |
+
lite_head = nn.Sequential(*lite_head)
|
| 162 |
+
self.add_module(f"lite_head{ti}", lite_head)
|
| 163 |
+
classifier = nn.Sequential(nn.Linear(self.classifier_inchannels, num_classes[ti], bias=False))
|
| 164 |
+
self.add_module(f"classifier{ti}", classifier)
|
| 165 |
+
|
| 166 |
+
# self.loss_weights = nn.Parameter(torch.ones(num_tasks), requires_grad=True)
|
| 167 |
+
|
| 168 |
+
self._initialize_weights()
|
| 169 |
+
|
| 170 |
+
def _forward_impl(self, x):
|
| 171 |
+
x = self.first_conv(x)
|
| 172 |
+
x = self.maxpool(x)
|
| 173 |
+
x = self.features(x)
|
| 174 |
+
|
| 175 |
+
if self.use_last_conv:
|
| 176 |
+
x = self.conv_last(x)
|
| 177 |
+
|
| 178 |
+
output = []
|
| 179 |
+
for ti in range(self.num_tasks):
|
| 180 |
+
if self.stack_lite_head[ti]:
|
| 181 |
+
c_x = getattr(self, f"lite_head{ti}")(x)
|
| 182 |
+
c_x = self.globalpool(c_x)
|
| 183 |
+
c_x = c_x.contiguous().view(-1, self.classifier_inchannels)
|
| 184 |
+
c_x = getattr(self, f"classifier{ti}")(c_x)
|
| 185 |
+
else:
|
| 186 |
+
c_x = self.globalpool(x)
|
| 187 |
+
c_x = c_x.contiguous().view(-1, self.classifier_inchannels)
|
| 188 |
+
c_x = getattr(self, f"classifier{ti}")(c_x)
|
| 189 |
+
|
| 190 |
+
if self.onnx:
|
| 191 |
+
if self.task_types[ti] == 0:
|
| 192 |
+
c_x = torch.softmax(c_x, dim=1)
|
| 193 |
+
elif self.task_types[ti] == 1:
|
| 194 |
+
c_x *= (0.05/3)
|
| 195 |
+
elif self.task_types[ti] == 2:
|
| 196 |
+
c_x *= (0.7/3)
|
| 197 |
+
elif self.task_types[ti] == 3:
|
| 198 |
+
c_x *= (0.1/2)
|
| 199 |
+
else:
|
| 200 |
+
raise NotImplementedError(f"task_type only support [0, 1, 2, 3], current {self.task_types[ti]}")
|
| 201 |
+
|
| 202 |
+
output.append(c_x)
|
| 203 |
+
return output
|
| 204 |
+
|
| 205 |
+
def forward(self, x):
|
| 206 |
+
output = self._forward_impl(x)
|
| 207 |
+
return output
|
| 208 |
+
|
| 209 |
+
def _initialize_weights(self):
|
| 210 |
+
for name, m in self.named_modules():
|
| 211 |
+
if isinstance(m, nn.Conv2d):
|
| 212 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_in')
|
| 213 |
+
if m.bias is not None:
|
| 214 |
+
nn.init.constant_(m.bias, 0)
|
| 215 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 216 |
+
nn.init.constant_(m.weight, 1)
|
| 217 |
+
if m.bias is not None:
|
| 218 |
+
nn.init.constant_(m.bias, 0.0001)
|
| 219 |
+
nn.init.constant_(m.running_mean, 0)
|
| 220 |
+
elif isinstance(m, nn.BatchNorm1d):
|
| 221 |
+
nn.init.constant_(m.weight, 1)
|
| 222 |
+
if m.bias is not None:
|
| 223 |
+
nn.init.constant_(m.bias, 0.0001)
|
| 224 |
+
nn.init.constant_(m.running_mean, 0)
|
| 225 |
+
elif isinstance(m, nn.Linear):
|
| 226 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_in')
|
| 227 |
+
if m.bias is not None:
|
| 228 |
+
nn.init.constant_(m.bias, 0)
|
| 229 |
+
|
| 230 |
+
def get_last_share_layer(self):
|
| 231 |
+
assert self.use_last_conv, "Current implementation requires 'with_last_conv=True'"
|
| 232 |
+
return self.conv_last[0]
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
if __name__ == '__main__':
|
| 236 |
+
from models.create_model import speed, model_info
|
| 237 |
+
from torchstat import stat
|
| 238 |
+
|
| 239 |
+
# a = se_resnet101()
|
| 240 |
+
# b = shufflenet_v2_x0_5()
|
| 241 |
+
input_size = (160, 160)
|
| 242 |
+
shufflenet2 = ShuffleNetV2(out_channel=192,
|
| 243 |
+
num_tasks=10, task_types=1, num_classes=[5, 3, 3, 3, 3, 1, 1, 1, 1, 1], model_size='0.5x', with_last_conv=0,
|
| 244 |
+
stack_lite_head=1, lite_head_channels=-1, onnx=True)
|
| 245 |
+
#
|
| 246 |
+
model_info(shufflenet2, img_size=input_size, verbose=False)
|
| 247 |
+
speed(shufflenet2, 'shufflenet2', size=input_size, device_type='cpu')
|
| 248 |
+
stat(shufflenet2, input_size=(3, 160, 160))
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# shufflenet2.eval()
|
| 253 |
+
#
|
| 254 |
+
# example = torch.randn(1, 3, input_size[1], input_size[0])
|
| 255 |
+
# torch.onnx.export(
|
| 256 |
+
#     shufflenet2,  # model being run
|
| 257 |
+
#     example,  # model input (or a tuple for multiple inputs)
|
| 258 |
+
#     '1.onnx',
|
| 259 |
+
#     verbose=False,
|
| 260 |
+
#     # store the trained parameter weights inside the model file
|
| 261 |
+
#     training=False,
|
| 262 |
+
#     input_names=['input'],
|
| 263 |
+
#     output_names=['output'],
|
| 264 |
+
#     do_constant_folding=True
|
| 265 |
+
# )
|
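For reference, a runnable sketch of the export outlined in the commented block above, assuming the ShuffleNetV2 class in this file and mirroring the __main__ configuration; the output file name, the 160x160 input size, and the per-task output names are illustrative choices:

import torch

net = ShuffleNetV2(out_channel=192, num_tasks=10, task_types=1,
                   num_classes=[5, 3, 3, 3, 3, 1, 1, 1, 1, 1],
                   model_size='0.5x', with_last_conv=0,
                   stack_lite_head=1, lite_head_channels=-1, onnx=True)
net.eval()  # onnx=True bakes softmax / fixed-scale post-processing into each task head
dummy = torch.randn(1, 3, 160, 160)  # NCHW dummy input
torch.onnx.export(net, dummy, 'dms3_mtl.onnx',  # illustrative output path
                  input_names=['input'],
                  output_names=[f'task{i}' for i in range(10)],
                  do_constant_folding=True)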
utils/__pycache__/common.cpython-38.pyc
ADDED
|
Binary file (7.83 kB).
|
|
|
utils/__pycache__/images.cpython-38.pyc
ADDED
|
Binary file (15 kB).
|
|
|
utils/__pycache__/labels.cpython-38.pyc
ADDED
|
Binary file (27.5 kB).
|
|
|
utils/__pycache__/multiprogress.cpython-38.pyc
ADDED
|
Binary file (2.73 kB).
|
|
|
utils/__pycache__/os_util.cpython-38.pyc
ADDED
|
Binary file (12.8 kB).
|
|
|
utils/__pycache__/plt_util.cpython-38.pyc
ADDED
|
Binary file (19.3 kB).
|
|
|
utils/__pycache__/time_util.cpython-38.pyc
ADDED
|
Binary file (4.12 kB).
|
|
|
utils/common.py
ADDED
|
@@ -0,0 +1,227 @@
| 1 |
+
# -*-coding:utf-8-*-
|
| 2 |
+
import colorsys
|
| 3 |
+
import datetime
|
| 4 |
+
import functools
|
| 5 |
+
import signal
|
| 6 |
+
import sys
|
| 7 |
+
from contextlib import contextmanager
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Generate colors
|
| 15 |
+
def gen_color(n):
|
| 16 |
+
hsv_tuples = [(x / n, 1., 1.)
|
| 17 |
+
for x in range(n)]
|
| 18 |
+
colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
|
| 19 |
+
colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
|
| 20 |
+
# np.random.seed(10101) # Fixed seed for consistent colors across runs.
|
| 21 |
+
np.random.shuffle(colors) # Shuffle colors to decorrelate adjacent classes.
|
| 22 |
+
# np.random.seed(None) # Reset seed to default.
|
| 23 |
+
return colors
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# pylint: disable=W0232
|
| 27 |
+
class Color:
|
| 28 |
+
GRAY = 30
|
| 29 |
+
RED = 31
|
| 30 |
+
GREEN = 32
|
| 31 |
+
YELLOW = 33
|
| 32 |
+
BLUE = 34
|
| 33 |
+
MAGENTA = 35
|
| 34 |
+
CYAN = 36
|
| 35 |
+
WHITE = 67
|
| 36 |
+
CRIMSON = 38
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Return the ANSI escape code for a string, setting foreground/background color, bold, etc.
|
| 40 |
+
def colorize(num, string, bold=False, highlight=False):
|
| 41 |
+
assert isinstance(num, int)
|
| 42 |
+
attr = []
|
| 43 |
+
if bold:
|
| 44 |
+
attr.append('1')
|
| 45 |
+
if highlight and num == 67:
|
| 46 |
+
num += 30
|
| 47 |
+
if highlight and num != 67:
|
| 48 |
+
num += 60
|
| 49 |
+
attr.append(str(num))
|
| 50 |
+
# \x1b[style;foreground;backgroundm + "text" + \x1b[0m
|
| 51 |
+
# the order of the ';'-separated codes may vary
|
| 52 |
+
return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def colorprint(colorcode, text, o=sys.stdout, bold=False, highlight=False, end='\n'):
|
| 56 |
+
o.write(colorize(colorcode, text, bold=bold, highlight=highlight) + end)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def cprint(text, colorcode=67, bold=False, highlight=False, end='\n', prefix=None,
|
| 60 |
+
pre_color=34, pre_bold=True, pre_high=True, pre_end=': '):
|
| 61 |
+
prefix = str(prefix).rstrip() if prefix is not None else prefix
|
| 62 |
+
if prefix is not None:
|
| 63 |
+
prefix = prefix.rstrip(':') if ':' in pre_end else prefix
|
| 64 |
+
prefix += pre_end
|
| 65 |
+
colorprint(pre_color, prefix, sys.stdout, pre_bold, pre_high, end='')
|
| 66 |
+
colorprint(colorcode, text, bold=bold, highlight=highlight, end=end)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def log_warn(msg):
|
| 70 |
+
cprint(msg, colorcode=33, prefix='Warning', pre_color=33, highlight=True)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def log_error(msg):
|
| 74 |
+
cprint(msg, colorcode=31, prefix='Error', pre_color=31, highlight=True)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def now_time():
|
| 78 |
+
return datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# http://stackoverflow.com/questions/366682/how-to-limit-execution-time-of-a-function-call-in-python
|
| 82 |
+
class TimeoutException(Exception):
|
| 83 |
+
def __init__(self, msg):
|
| 84 |
+
self.msg = msg
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Context manager: decorates func, which must contain a yield. Usage:
|
| 88 |
+
# with func as f:
|
| 89 |
+
#     do()
|
| 90 |
+
# The with statement runs the code before the yield, then f and do(), and finally the code after the yield.
|
| 91 |
+
@contextmanager
|
| 92 |
+
def time_limit(seconds):
|
| 93 |
+
# Limit how long the enclosed block may run; a TimeoutException is raised if the preset time is exceeded
|
| 94 |
+
def signal_handler(signum, frame):  # a signal handler must take these two parameters
|
| 95 |
+
raise TimeoutException(colorize(Color.RED, "Timed out! Retry again ...", highlight=True))
|
| 96 |
+
signal.signal(signal.SIGALRM, signal_handler)  # register signal_handler for the SIGALRM signal
|
| 97 |
+
signal.alarm(seconds)  # if seconds is nonzero, SIGALRM is delivered to this process after `seconds` seconds
|
| 98 |
+
try:
|
| 99 |
+
yield
|
| 100 |
+
finally:
|
| 101 |
+
signal.alarm(0)  # cancel any pending alarm
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def clock(func):  # timing decorator
|
| 105 |
+
@functools.wraps(func)  # copy func's metadata onto clocked so its __name__ is preserved
|
| 106 |
+
def clocked(*args, **kwargs):  # the inner function accepts any positional and keyword arguments
|
| 107 |
+
t0 = time.perf_counter()
|
| 108 |
+
result = func(*args, **kwargs)  # func is a free variable captured by the clocked closure
|
| 109 |
+
elapsed = time.perf_counter() - t0
|
| 110 |
+
name = func.__name__
|
| 111 |
+
arg_lst = []
|
| 112 |
+
if args:
|
| 113 |
+
arg_lst.append(', '.join(repr(arg) for arg in args))
|
| 114 |
+
if kwargs:
|
| 115 |
+
pairs = ['%s=%r' % (k, w) for k, w in sorted(kwargs.items())]
|
| 116 |
+
arg_lst.append(', '.join(pairs))
|
| 117 |
+
arg_str = ', '.join(arg_lst)
|
| 118 |
+
colorprint(Color.GREEN, '[%0.8fs] %s(%s) -> %r' % (elapsed, name, arg_str, result))
|
| 119 |
+
return result
|
| 120 |
+
return clocked  # return clocked, which replaces the decorated function
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
DEFAULT_FMT = '[{elapsed:0.8f}s] {name}({args}) -> {result}'
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def clock_custom(fmt=DEFAULT_FMT, color=Color.GREEN):
|
| 127 |
+
def decorate(func):  # timing decorator with a customizable report format
|
| 128 |
+
@functools.wraps(func)  # copy func's metadata onto clocked so its __name__ is preserved
|
| 129 |
+
def clocked(*_args, **_kwargs):  # the inner function accepts any positional and keyword arguments
|
| 130 |
+
t0 = time.perf_counter()
|
| 131 |
+
_result = func(*_args, **_kwargs)  # func is a free variable captured by the clocked closure
|
| 132 |
+
elapsed = time.perf_counter() - t0
|
| 133 |
+
name = func.__name__
|
| 134 |
+
_arg_lst = []
|
| 135 |
+
if _args:
|
| 136 |
+
_arg_lst.append(', '.join(repr(arg) for arg in _args))
|
| 137 |
+
if _kwargs:
|
| 138 |
+
_pairs = ['%s=%r' % (k, w) for k, w in sorted(_kwargs.items())]
|
| 139 |
+
_arg_lst.append(', '.join(_pairs))
|
| 140 |
+
args = ', '.join(_arg_lst)
|
| 141 |
+
result = repr(_result)  # string form for the report
|
| 142 |
+
colorprint(color, fmt.format(**locals()))  # format using clocked's local variables
|
| 143 |
+
return _result  # return the original function's result
|
| 144 |
+
return clocked  # return clocked, which replaces the decorated function
|
| 145 |
+
return decorate  # a decorator factory must return a decorator
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def colorstr(*input):
|
| 149 |
+
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
|
| 150 |
+
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
|
| 151 |
+
colors = {'black': '\033[30m', # basic colors
|
| 152 |
+
'red': '\033[31m',
|
| 153 |
+
'green': '\033[32m',
|
| 154 |
+
'yellow': '\033[33m',
|
| 155 |
+
'blue': '\033[34m',
|
| 156 |
+
'magenta': '\033[35m',
|
| 157 |
+
'cyan': '\033[36m',
|
| 158 |
+
'white': '\033[37m',
|
| 159 |
+
'bright_black': '\033[90m', # bright colors
|
| 160 |
+
'bright_red': '\033[91m',
|
| 161 |
+
'bright_green': '\033[92m',
|
| 162 |
+
'bright_yellow': '\033[93m',
|
| 163 |
+
'bright_blue': '\033[94m',
|
| 164 |
+
'bright_magenta': '\033[95m',
|
| 165 |
+
'bright_cyan': '\033[96m',
|
| 166 |
+
'bright_white': '\033[97m',
|
| 167 |
+
'end': '\033[0m', # misc
|
| 168 |
+
'bold': '\033[1m',
|
| 169 |
+
'underline': '\033[4m'}
|
| 170 |
+
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class Logger:
|
| 174 |
+
def __init__(self, file_path):
|
| 175 |
+
self.log_file_path = file_path
|
| 176 |
+
self.log_file = None
|
| 177 |
+
self.color = {
|
| 178 |
+
'R': Color.RED,
|
| 179 |
+
'B': Color.BLUE,
|
| 180 |
+
'G': Color.GREEN,
|
| 181 |
+
'Y': Color.YELLOW
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
def start(self):
|
| 185 |
+
self.log_file = open(self.log_file_path, 'w', encoding='utf-8')
|
| 186 |
+
|
| 187 |
+
def close(self):
|
| 188 |
+
if self.log_file is not None:
|
| 189 |
+
self.log_file.close()
|
| 190 |
+
return
|
| 191 |
+
|
| 192 |
+
def info(self, text, color='W', prefix=None, pre_color='B', pre_end=': ', prints=True):
|
| 193 |
+
assert self.log_file is not None, "Please call 'logger.start()' first"
|
| 194 |
+
color_code = self.color.get(color, Color.WHITE)
|
| 195 |
+
pre_color = self.color.get(pre_color, Color.BLUE)
|
| 196 |
+
prefix = str(prefix).rstrip() if prefix is not None else prefix
|
| 197 |
+
if prefix is not None:
|
| 198 |
+
prefix = prefix.rstrip(':') if ':' in pre_end else prefix
|
| 199 |
+
prefix += pre_end
|
| 200 |
+
self.log_file.write(f"{prefix if prefix is not None else ''}{text}\n")
|
| 201 |
+
if prints:
|
| 202 |
+
cprint(text, color_code, prefix=prefix, pre_color=pre_color)
|
| 203 |
+
|
| 204 |
+
def error(self, text):
|
| 205 |
+
assert self.log_file is not None, "Please call 'logger.start()' first"
|
| 206 |
+
# self.log_file.write(text + '\n')
|
| 207 |
+
log_error(text)
|
| 208 |
+
|
| 209 |
+
def warn(self, text):
|
| 210 |
+
assert self.log_file is not None, "Please call 'logger.start()' first"
|
| 211 |
+
# self.log_file.write(text + '\n')
|
| 212 |
+
log_warn(text)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def r(val):
|
| 216 |
+
return int(np.random.random() * val)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# if __name__ == '__main__':
|
| 220 |
+
# #print(colorize(31, 'I am fine!', bold=True, highlight=True))
|
| 221 |
+
# #colorprint(35, 'I am fine!')
|
| 222 |
+
# #error('get out')
|
| 223 |
+
# import time
|
| 224 |
+
#
|
| 225 |
+
# # ends after 5 seconds
|
| 226 |
+
# with time_limit(2):
|
| 227 |
+
# time.sleep(3)
|
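A minimal usage sketch for the timing decorator factory above; the `snooze` function is made up for illustration:

import time

@clock_custom(fmt=DEFAULT_FMT, color=Color.CYAN)
def snooze(seconds):  # hypothetical example function
    time.sleep(seconds)
    return seconds

snooze(0.1)  # prints something like: [0.10034100s] snooze(0.1) -> 0.1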
utils/export_util.py
ADDED
|
@@ -0,0 +1,118 @@
| 1 |
+
# -*- coding:utf-8 -*-
|
| 2 |
+
# tensorflow=2.5.0 onnx-tf=1.9.0
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
import onnxruntime
|
| 6 |
+
# import tensorflow as tf
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def onnx_export_pipeline(onnx_path, export_func, simplify=True, onnx_sim_path=None, pack=True, view_net=True):
|
| 10 |
+
if not os.path.exists(os.path.dirname(onnx_path)):
|
| 11 |
+
os.makedirs(os.path.dirname(onnx_path))
|
| 12 |
+
|
| 13 |
+
# export
|
| 14 |
+
export_func(export_name=onnx_path)  # ONNX conversion callback, specific to each model's configuration
|
| 15 |
+
if not simplify:
|
| 16 |
+
print(f"Onnx model export in '{onnx_path}'")
|
| 17 |
+
|
| 18 |
+
if simplify:  # operator simplification
|
| 19 |
+
time.sleep(1)
|
| 20 |
+
onnx_sim_path = onnx_sim_path if onnx_sim_path else onnx_path
|
| 21 |
+
os.system(f"python -m onnxsim {onnx_path} {onnx_sim_path}")
|
| 22 |
+
print(f"Simplify onnx model export in '{onnx_sim_path}'")
|
| 23 |
+
|
| 24 |
+
if pack:  # compress and package
|
| 25 |
+
time.sleep(1)
|
| 26 |
+
src_path = onnx_sim_path if simplify else onnx_path
|
| 27 |
+
src_dir, src_name = os.path.dirname(src_path), os.path.basename(src_path)
|
| 28 |
+
os.system(f'cd {src_dir} && tar -zcf {src_name}.tgz {src_name}')
|
| 29 |
+
print(f"TGZ file save in '{src_path}.tgz'")
|
| 30 |
+
|
| 31 |
+
if view_net:  # inspect the network structure
|
| 32 |
+
import netron
|
| 33 |
+
src_path = onnx_sim_path if simplify else onnx_path
|
| 34 |
+
netron.start(src_path)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ONNXModel:
|
| 38 |
+
def __init__(self, onnx_path):
|
| 39 |
+
"""
|
| 40 |
+
:param onnx_path:
|
| 41 |
+
"""
|
| 42 |
+
self.onnx_session = onnxruntime.InferenceSession(onnx_path)
|
| 43 |
+
self.input_name = self.get_input_name(self.onnx_session)
|
| 44 |
+
self.output_name = self.get_output_name(self.onnx_session)
|
| 45 |
+
print(f"loading {onnx_path}")
|
| 46 |
+
|
| 47 |
+
@staticmethod
|
| 48 |
+
def get_output_name(onnx_session):
|
| 49 |
+
"""
|
| 50 |
+
output_name = onnx_session.get_outputs()[0].name
|
| 51 |
+
:param onnx_session:
|
| 52 |
+
:return:
|
| 53 |
+
"""
|
| 54 |
+
output_name = []
|
| 55 |
+
for node in onnx_session.get_outputs():
|
| 56 |
+
output_name.append(node.name)
|
| 57 |
+
return output_name
|
| 58 |
+
|
| 59 |
+
@staticmethod
|
| 60 |
+
def get_input_name(onnx_session):
|
| 61 |
+
"""
|
| 62 |
+
input_name = onnx_session.get_inputs()[0].name
|
| 63 |
+
:param onnx_session:
|
| 64 |
+
:return:
|
| 65 |
+
"""
|
| 66 |
+
input_name = []
|
| 67 |
+
for node in onnx_session.get_inputs():
|
| 68 |
+
input_name.append(node.name)
|
| 69 |
+
return input_name
|
| 70 |
+
|
| 71 |
+
@staticmethod
|
| 72 |
+
def get_input_feed(input_name, image_numpy):
|
| 73 |
+
"""
|
| 74 |
+
input_feed={self.input_name: image_numpy}
|
| 75 |
+
:param input_name:
|
| 76 |
+
:param image_numpy:
|
| 77 |
+
:return:
|
| 78 |
+
"""
|
| 79 |
+
input_feed = {}
|
| 80 |
+
for name in input_name:
|
| 81 |
+
input_feed[name] = image_numpy
|
| 82 |
+
return input_feed
|
| 83 |
+
|
| 84 |
+
def forward(self, image_numpy):
|
| 85 |
+
'''
|
| 86 |
+
# image_numpy = image.transpose(2, 0, 1)
|
| 87 |
+
# image_numpy = image_numpy[np.newaxis, :]
|
| 88 |
+
# onnx_session.run([output_name], {input_name: x})
|
| 89 |
+
# :param image_numpy:
|
| 90 |
+
# :return:
|
| 91 |
+
'''
|
| 92 |
+
input_feed = self.get_input_feed(self.input_name, image_numpy)
|
| 93 |
+
outputs = self.onnx_session.run(self.output_name, input_feed=input_feed)
|
| 94 |
+
return outputs
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def to_numpy(tensor):
|
| 98 |
+
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def export_tflite(onnx_path, tflite_path):
    import tensorflow as tf  # lazy import: TF is only needed for TFLite export (the module-level import above is commented out)
|
| 102 |
+
save_model_dir = os.path.splitext(tflite_path)[0]
|
| 103 |
+
onnx_to_pb_cmd = f'onnx-tf convert -i {onnx_path} -o {save_model_dir}'
|
| 104 |
+
os.system(onnx_to_pb_cmd)
|
| 105 |
+
print(f"Convert onnx model to pb in '{save_model_dir}'")
|
| 106 |
+
|
| 107 |
+
time.sleep(1)
|
| 108 |
+
# make a converter object from the saved tensorflow file
|
| 109 |
+
converter = tf.lite.TFLiteConverter.from_saved_model(save_model_dir)
|
| 110 |
+
# tell converter which type of optimization techniques to use
|
| 111 |
+
# to view the best option for optimization read documentation of tflite about optimization
|
| 112 |
+
# go to this link https://www.tensorflow.org/lite/guide/get_started#4_optimize_your_model_optional
|
| 113 |
+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
| 114 |
+
# convert the model
|
| 115 |
+
tf_lite_model = converter.convert()
|
| 116 |
+
# save the converted model
|
| 117 |
+
open(tflite_path, 'wb').write(tf_lite_model)
|
| 118 |
+
print(f"TFlite model save in '{tflite_path}'")
|
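A minimal inference sketch for the ONNXModel wrapper above; the path points at the ONNX file shipped in this repo, and the 1x3x160x160 float32 input shape is an assumption based on the 160x160 input used elsewhere in the project:

import numpy as np

model = ONNXModel('data/dms3/v1.0/dms3_mtl_v1.0.onnx')
x = np.random.rand(1, 3, 160, 160).astype(np.float32)  # NCHW float32 dummy input (shape is an assumption)
for name, out in zip(model.output_name, model.forward(x)):
    print(name, out.shape)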
utils/images.py
ADDED
|
@@ -0,0 +1,539 @@
| 1 |
+
# -*- coding:utf-8 -*-
|
| 2 |
+
import random
|
| 3 |
+
from math import *
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
from PIL import ImageDraw, Image, ImageFont
|
| 7 |
+
|
| 8 |
+
from utils.common import *
|
| 9 |
+
from utils.multiprogress import MultiThreading
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def show(image, name="Press 'q' exit"):
|
| 13 |
+
if not hasattr(image, 'shape'):
|
| 14 |
+
img_array = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
| 15 |
+
else:
|
| 16 |
+
img_array = image.copy()
|
| 17 |
+
cv2.imshow(name, img_array)
|
| 18 |
+
cv2.moveWindow(name, 0, 0)
|
| 19 |
+
key = cv2.waitKey(0) & 0xEFFFFF
|
| 20 |
+
if key == ord('q'):
|
| 21 |
+
cv2.destroyWindow(name)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def crop_face_square(img, box, pad=5):
|
| 25 |
+
img_h, img_w, img_c = img.shape
|
| 26 |
+
if box != [0, 0, 0, 0]:
|
| 27 |
+
x0, y0, x1, y1 = box
|
| 28 |
+
w, h = x1-x0, y1-y0
|
| 29 |
+
# 扩充为正方形
|
| 30 |
+
long_side = min(max(h, w) + int(2*pad), img_h)
|
| 31 |
+
w_add, h_add = (long_side - w) // 2, (long_side - h) // 2
|
| 32 |
+
crop_x0, crop_y0 = max(x0 - w_add, 0), max(y0 - h_add, 0)
|
| 33 |
+
crop_x1, crop_y1 = min(crop_x0 + long_side, img_w), min(crop_y0 + long_side, img_h)
|
| 34 |
+
crop_x0, crop_y0 = crop_x1 - long_side, crop_y1 - long_side
|
| 35 |
+
return img[crop_y0:crop_y1, crop_x0:crop_x1], [crop_x0, crop_y0, crop_x1, crop_y1]
|
| 36 |
+
else:
|
| 37 |
+
# print('No detected box, crop right rect.')
|
| 38 |
+
if img_h == 960:
|
| 39 |
+
img = img[60:780, :]
|
| 40 |
+
img_h, img_w = img.shape[:2]
|
| 41 |
+
return img[:, img_w - img_h:img_w], [0, 0, 0, 0]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def crop_face_square_rate(img, box, rate=0.1):
|
| 45 |
+
img_h, img_w, img_c = img.shape
|
| 46 |
+
if box != [0, 0, 0, 0]:
|
| 47 |
+
x0, y0, x1, y1 = box
|
| 48 |
+
w, h = x1-x0, y1-y0
|
| 49 |
+
# 扩充为正方形
|
| 50 |
+
pad = max(w, h) * rate
|
| 51 |
+
long_side = min(max(h, w) + int(2*pad), img_h)
|
| 52 |
+
w_add, h_add = (long_side - w) // 2, (long_side - h) // 2
|
| 53 |
+
crop_x0, crop_y0 = max(x0 - w_add, 0), max(y0 - h_add, 0)
|
| 54 |
+
crop_x1, crop_y1 = min(crop_x0 + long_side, img_w), min(crop_y0 + long_side, img_h)
|
| 55 |
+
crop_x0, crop_y0 = crop_x1 - long_side, crop_y1 - long_side
|
| 56 |
+
return img[crop_y0:crop_y1, crop_x0:crop_x1], [crop_x0, crop_y0, crop_x1, crop_y1]
|
| 57 |
+
else:
|
| 58 |
+
if img_h == 960:
|
| 59 |
+
crop_x0, crop_x1 = img_w - 720, img_w
|
| 60 |
+
crop_y0, crop_y1 = 60, 780
|
| 61 |
+
else:
|
| 62 |
+
crop_x0, crop_x1 = img_w - img_h, img_w
|
| 63 |
+
crop_y0, crop_y1 = 0, img_h
|
| 64 |
+
return img[crop_y0:crop_y1, crop_x0:crop_x1], [crop_x0, crop_y0, crop_x1, crop_y1]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def expand_box_rate(img, box, rate=0.1):
|
| 68 |
+
img_h, img_w, img_c = img.shape
|
| 69 |
+
if box != [0, 0, 0, 0]:
|
| 70 |
+
x0, y0, x1, y1 = box
|
| 71 |
+
w, h = x1-x0, y1-y0
|
| 72 |
+
# 扩充为正方形
|
| 73 |
+
pad = max(w, h) * rate
|
| 74 |
+
long_side = min(max(h, w) + int(2*pad), img_h)
|
| 75 |
+
w_add, h_add = (long_side - w) // 2, (long_side - h) // 2
|
| 76 |
+
crop_x0, crop_y0 = x0 - w_add, y0 - h_add
|
| 77 |
+
crop_x1, crop_y1 = crop_x0 + long_side, crop_y0 + long_side
|
| 78 |
+
return [crop_x0, crop_y0, crop_x1, crop_y1]
|
| 79 |
+
else:
|
| 80 |
+
if img_h == 960:
|
| 81 |
+
crop_x0, crop_x1 = img_w - 720, img_w
|
| 82 |
+
crop_y0, crop_y1 = 60, 780
|
| 83 |
+
else:
|
| 84 |
+
crop_x0, crop_x1 = img_w - img_h, img_w
|
| 85 |
+
crop_y0, crop_y1 = 0, img_h
|
| 86 |
+
return [crop_x0, crop_y0, crop_x1, crop_y1]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def crop_with_pad(img, crop_box):
|
| 90 |
+
img_h, img_w, img_c = img.shape
|
| 91 |
+
x0, y0, x1, y1 = crop_box
|
| 92 |
+
w, h = x1 - x0, y1 - y0
|
| 93 |
+
if tuple(crop_box) == (0, 0, 0, 0) or w <= 50 or h <= 50: # 背景,裁右半图
|
| 94 |
+
if img_h == 960:
|
| 95 |
+
crop_x0, crop_x1 = img_w - 720, img_w
|
| 96 |
+
crop_y0, crop_y1 = 60, 780
|
| 97 |
+
else:
|
| 98 |
+
crop_x0, crop_x1 = img_w - img_h, img_w
|
| 99 |
+
crop_y0, crop_y1 = 0, img_h
|
| 100 |
+
crop_img = img[crop_y0:crop_y1, crop_x0:crop_x1]
|
| 101 |
+
return crop_img
|
| 102 |
+
else:
|
| 103 |
+
crop_x0, crop_y0 = max(x0, 0), max(y0, 0)
|
| 104 |
+
crop_x1, crop_y1 = min(x1, img_w), min(y1, img_h)
|
| 105 |
+
left, top, right, bottom = crop_x0 - x0, crop_y0 - y0, x1 - crop_x1, y1 - crop_y1
|
| 106 |
+
crop_img = img[crop_y0:crop_y1, crop_x0:crop_x1]
|
| 107 |
+
crop_img = cv2.copyMakeBorder(crop_img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128))
|
| 108 |
+
return crop_img
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def clip_paste_rect(bg_img, fg_img, loc, refine=True):
|
| 112 |
+
"""
|
| 113 |
+
Select the paste region without exceeding the background bounds.
|
| 114 |
+
loc: center point (cx, cy)
|
| 115 |
+
"""
|
| 116 |
+
bg_h, bg_w = bg_img.shape[:2]
|
| 117 |
+
fg_h, fg_w = fg_img.shape[:2]
|
| 118 |
+
fg_h = fg_h - 1 if fg_h % 2 else fg_h
|
| 119 |
+
fg_w = fg_w - 1 if fg_w % 2 else fg_w
|
| 120 |
+
cx, cy = loc
|
| 121 |
+
left, top, right, bottom = cx - fg_w // 2, cy - fg_h // 2, cx + fg_w // 2, cy + fg_h // 2
|
| 122 |
+
|
| 123 |
+
if refine:
|
| 124 |
+
right, bottom = min(right, bg_w), min(bottom, bg_h)
|
| 125 |
+
left, top = right - fg_w, bottom - fg_h
|
| 126 |
+
left, top = max(0, left), max(0, top)
|
| 127 |
+
right, bottom = left + fg_w, top + fg_h
|
| 128 |
+
|
| 129 |
+
plot_x1, plot_y1, plot_x2, plot_y2 = left, top, right, bottom
|
| 130 |
+
use_x1, use_y1, use_x2, use_y2 = 0, 0, fg_w, fg_h
|
| 131 |
+
|
| 132 |
+
if left < 0:
|
| 133 |
+
plot_x1, use_x1 = 0, -left
|
| 134 |
+
if top < 0:
|
| 135 |
+
plot_y1, use_y1 = 0, -top
|
| 136 |
+
if right > bg_w:
|
| 137 |
+
plot_x2, use_x2 = bg_w, fg_w - (right - bg_w)
|
| 138 |
+
if bottom > bg_h:
|
| 139 |
+
plot_y2, use_y2 = bg_h, fg_h - (bottom - bg_h)
|
| 140 |
+
|
| 141 |
+
use_bg = bg_img[plot_y1:plot_y2, plot_x1:plot_x2]
|
| 142 |
+
use_fg = fg_img[use_y1:use_y2, use_x1:use_x2]
|
| 143 |
+
window = (plot_x1, plot_y1, plot_x2, plot_y2)
|
| 144 |
+
|
| 145 |
+
return use_bg, use_fg, window
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# @clock_custom(fmt=DEFAULT_FMT)
|
| 149 |
+
def paste(bg_img, fg_img, loc, trans_thresh=1, refine=True):
|
| 150 |
+
"""
|
| 151 |
+
Paste the foreground (BGRA) onto the background.
|
| 152 |
+
loc: center (cx, cy)
|
| 153 |
+
"""
|
| 154 |
+
use_bg, use_fg, window = clip_paste_rect(bg_img, fg_img, loc, refine)
|
| 155 |
+
plot_x1, plot_y1, plot_x2, plot_y2 = window
|
| 156 |
+
b, g, r, a = cv2.split(use_fg)
|
| 157 |
+
a[a > 0] = 255
|
| 158 |
+
a = np.dstack([a, a, a]) * trans_thresh
|
| 159 |
+
use_bg = use_bg * (255.0 - a) / 255
|
| 160 |
+
use_bg += use_fg[:, :, :3] * (a / 255)
|
| 161 |
+
bg_img[plot_y1:plot_y2, plot_x1:plot_x2] = use_bg.astype('uint8')
|
| 162 |
+
return bg_img
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def put_chinese_text(img, text, position, font, text_color=(0, 255, 0)):
|
| 166 |
+
if isinstance(img, np.ndarray):  # if this is an OpenCV image (ndarray), convert it to PIL
|
| 167 |
+
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
| 168 |
+
# create a drawing object for the given image
|
| 169 |
+
draw = ImageDraw.Draw(img)
|
| 170 |
+
draw.text(position, text, text_color, font=font)
|
| 171 |
+
# convert back to OpenCV format
|
| 172 |
+
return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def compute_mean_std(img_list, preprocess=None, workers=12):
|
| 176 |
+
"""
|
| 177 |
+
计算公式:
|
| 178 |
+
S^2 = sum((x - x')^2) / N = sum(x^2 + x'^2 - 2xx') / N
|
| 179 |
+
= (sum(x^2) + sum(x'^2) - 2x'*sum(x)) / N
|
| 180 |
+
= (sum(x^2) + N*(x'^2) - 2x'*(N * x')) / N
|
| 181 |
+
= (sum(x^2) - N * (x'^2)) / N
|
| 182 |
+
= sum(x^2) / N - (sum(x) / N)^2
|
| 183 |
+
= mean(x^2) - mean(x)^2 = E(x^2) - E(x)^2
|
| 184 |
+
:param img_list:
|
| 185 |
+
:param workers:
|
| 186 |
+
:param preprocess: 图像预处理函数,需要有BGR->RGB以及归一化操作
|
| 187 |
+
:return: RGB means, stds
|
| 188 |
+
"""
|
| 189 |
+
|
| 190 |
+
def cal_func(info):
|
| 191 |
+
i, img_path = info
|
| 192 |
+
img = cv2.imread(img_path)
|
| 193 |
+
if preprocess:
|
| 194 |
+
img = preprocess(img)
|
| 195 |
+
else:
|
| 196 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 197 |
+
img = img.astype(np.float32) / 255
|
| 198 |
+
img = img.reshape((-1, 3))
|
| 199 |
+
channel_mean = np.mean(img, axis=0)
|
| 200 |
+
channel_var = np.mean(img**2, axis=0)
|
| 201 |
+
if i % 1000 == 0:
|
| 202 |
+
print(f"{i}/{len(img_list)}")
|
| 203 |
+
return channel_mean, channel_var
|
| 204 |
+
|
| 205 |
+
exe = MultiThreading(list(enumerate(img_list)), workers=min(workers, len(img_list)))
|
| 206 |
+
res = exe.run(cal_func)
|
| 207 |
+
all_means = np.array([r[0] for r in res])
|
| 208 |
+
all_vars = np.array([r[1] for r in res])
|
| 209 |
+
means = np.mean(all_means, axis=0)
|
| 210 |
+
vars = np.mean(all_vars, axis=0)
|
| 211 |
+
stds = np.sqrt(vars - means ** 2)
|
| 212 |
+
|
| 213 |
+
return means, stds
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def draw_province(f, val):
|
| 217 |
+
img = Image.new("RGB", (45, 70), (255, 255, 255))
|
| 218 |
+
draw = ImageDraw.Draw(img)
|
| 219 |
+
draw.text((0, 3), val, (0, 0, 0), font=f)
|
| 220 |
+
img = img.resize((23, 70))
|
| 221 |
+
char = np.array(img)
|
| 222 |
+
return char
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def draw_chars(f, val):
|
| 226 |
+
img = Image.new("RGB", (23, 70), (255, 255, 255))
|
| 227 |
+
draw = ImageDraw.Draw(img)
|
| 228 |
+
draw.text((0, 2), val, (0, 0, 0), font=f)
|
| 229 |
+
A = np.array(img)
|
| 230 |
+
return A
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# double-row license plates
|
| 234 |
+
def draw_province_double(f, val):
|
| 235 |
+
img = Image.new("RGB", (60, 60), (255, 255, 255))
|
| 236 |
+
draw = ImageDraw.Draw(img)
|
| 237 |
+
draw.text((0, -12), val, (0, 0, 0), font=f)
|
| 238 |
+
img = img.resize((80, 60))
|
| 239 |
+
# img.show()
|
| 240 |
+
char = np.array(img)
|
| 241 |
+
return char
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def draw_chars_ceil(f, val):
|
| 245 |
+
img = Image.new("RGB", (30, 45), (255, 255, 255))
|
| 246 |
+
draw = ImageDraw.Draw(img)
|
| 247 |
+
draw.text((1, -12), val, (0, 0, 0), font=f)
|
| 248 |
+
img = img.resize((80, 60))
|
| 249 |
+
# img.show()
|
| 250 |
+
char = np.array(img)
|
| 251 |
+
return char
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def draw_chars_floor(f, val):
|
| 255 |
+
img = Image.new("RGB", (30, 45), (255, 255, 255))
|
| 256 |
+
draw = ImageDraw.Draw(img)
|
| 257 |
+
draw.text((1, -12), val, (0, 0, 0), font=f)
|
| 258 |
+
img = img.resize((65, 110))
|
| 259 |
+
# img.show()
|
| 260 |
+
char = np.array(img)
|
| 261 |
+
return char
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# apply a weathered/smudged look to the image
|
| 265 |
+
def add_smudginess(img):
|
| 266 |
+
smu = cv2.imread('data/bgp/smu.jpg')
|
| 267 |
+
img_h, img_w = img.shape[:2]
|
| 268 |
+
rows = r(smu.shape[0] - img_h)
|
| 269 |
+
cols = r(smu.shape[1] - img_w)
|
| 270 |
+
adder = smu[rows:rows + img_h, cols:cols + img_w]
|
| 271 |
+
adder = cv2.resize(adder, (img_w, img_h))
|
| 272 |
+
adder = cv2.bitwise_not(adder)
|
| 273 |
+
val = random.random() * 0.5
|
| 274 |
+
img = cv2.addWeighted(img, 1 - val, adder, val, 0.0)
|
| 275 |
+
return img
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def rot(img, angel, shape, max_angel):
|
| 279 |
+
""" 使图像轻微的畸变
|
| 280 |
+
img 输入图像
|
| 281 |
+
factor 畸变的参数
|
| 282 |
+
size 为图片的目标尺寸
|
| 283 |
+
"""
|
| 284 |
+
size_o = [shape[1], shape[0]]
|
| 285 |
+
size = (shape[1] + int(shape[0] * sin((float(max_angel)/180) * 3.14)), shape[0])
|
| 286 |
+
interval = abs(int(sin((float(angel) / 180) * 3.14) * shape[0]))
|
| 287 |
+
pts1 = np.float32([[0, 0], [0, size_o[1]], [size_o[0], 0], [size_o[0], size_o[1]]])
|
| 288 |
+
if angel > 0:
|
| 289 |
+
pts2 = np.float32([[interval, 0], [0, size[1]], [size[0], 0], [size[0]-interval, size_o[1]]])
|
| 290 |
+
else:
|
| 291 |
+
pts2 = np.float32([[0, 0], [interval, size[1]], [size[0]-interval, 0], [size[0], size_o[1]]])
|
| 292 |
+
M = cv2.getPerspectiveTransform(pts1, pts2)
|
| 293 |
+
dst = cv2.warpPerspective(img, M, size)
|
| 294 |
+
return dst
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def random_rot(img, factor, size):
|
| 298 |
+
shape = size
|
| 299 |
+
pts1 = np.float32([[0, 0], [0, shape[0]], [shape[1], 0], [shape[1], shape[0]]])
|
| 300 |
+
pts2 = np.float32([[r(factor), r(factor)],
|
| 301 |
+
[r(factor), shape[0] - r(factor)],
|
| 302 |
+
[shape[1] - r(factor), r(factor)],
|
| 303 |
+
[shape[1] - r(factor), shape[0] - r(factor)]])
|
| 304 |
+
M = cv2.getPerspectiveTransform(pts1, pts2)
|
| 305 |
+
dst = cv2.warpPerspective(img, M, size)
|
| 306 |
+
return dst
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def random_rot_expend(img, factor, size):
|
| 310 |
+
height, width = size
|
| 311 |
+
degree = factor
|
| 312 |
+
# canvas size after rotation
|
| 313 |
+
n_h = int(width * fabs(sin(radians(degree))) + height * fabs(cos(radians(degree))))
|
| 314 |
+
n_w = int(height * fabs(sin(radians(degree))) + width * fabs(cos(radians(degree))))
|
| 315 |
+
M = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
|
| 316 |
+
M[0, 2] += (n_w - width) / 2  # key step: shift the rotation center so the whole image fits the enlarged canvas
|
| 317 |
+
M[1, 2] += (n_h - height) / 2  # same shift along the vertical axis
|
| 318 |
+
rot_img = cv2.warpAffine(img, M, (n_w, n_h), borderValue=(0, 0, 0))
|
| 319 |
+
return rot_img
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def random_rot_keep_size_left_right(img, factor1, factor2, size):
|
| 323 |
+
# perspective transform: factor1 > 0; factor2 may be positive or negative
|
| 324 |
+
shape = size
|
| 325 |
+
width = shape[0]
|
| 326 |
+
height = shape[1]
|
| 327 |
+
pts1 = np.float32([[0, 0], [width - 1, 0], [0, height-1], [width - 1, height-1]])
|
| 328 |
+
point1_x = 0
|
| 329 |
+
point1_y = 0
|
| 330 |
+
point2_x = width - r(factor1)
|
| 331 |
+
point2_y = r(factor2) # shape[0] - r(factor)
|
| 332 |
+
point3_x = 0 # shape[1] - r(factor)
|
| 333 |
+
point3_y = height # r(factor)
|
| 334 |
+
point4_x = width - r(factor1) # shape[1]
|
| 335 |
+
point4_y = height - r(factor2) # shape[0]
|
| 336 |
+
max_x = max(point2_x, point4_x)
|
| 337 |
+
max_y = point3_y
|
| 338 |
+
if factor2 < 0:
|
| 339 |
+
point1_x = 0
|
| 340 |
+
point1_y = 0 - r(factor2)
|
| 341 |
+
point2_x = width - r(factor1)
|
| 342 |
+
point2_y = 0
|
| 343 |
+
point3_x = 0
|
| 344 |
+
point3_y = height + r(factor2)
|
| 345 |
+
point4_x = width - r(factor1)
|
| 346 |
+
point4_y = height
|
| 347 |
+
max_x = max(point2_x, point4_x)
|
| 348 |
+
max_y = point4_y
|
| 349 |
+
pts2 = np.float32([[point1_x, point1_y],
|
| 350 |
+
[point2_x, point2_y],
|
| 351 |
+
[point3_x, point3_y],
|
| 352 |
+
[point4_x, point4_y]])
|
| 353 |
+
M = cv2.getPerspectiveTransform(pts1, pts2)
|
| 354 |
+
size2 = (max_x, max_y)
|
| 355 |
+
dst = cv2.warpPerspective(img, M, size2) # cv2.warpPerspective(img, M, size)
|
| 356 |
+
return dst
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
# affine transform
|
| 360 |
+
def affine_transform(img, factor, size):
|
| 361 |
+
shape = size
|
| 362 |
+
pts1 = np.float32([[0, shape[0]], [shape[1], 0], [shape[1], shape[0]]])
|
| 363 |
+
pts2 = np.float32([[r(factor), shape[0] - r(factor)],
|
| 364 |
+
[shape[1] - r(factor), r(factor)],
|
| 365 |
+
[shape[1], shape[0]]]) # [shape[1] - r(factor), shape[0] - r(factor)]])
|
| 366 |
+
M = cv2.getAffineTransform(pts1, pts2)
|
| 367 |
+
dst = cv2.warpAffine(img, M, size)
|
| 368 |
+
return dst
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
# erosion
|
| 372 |
+
def cv_erode(img, factor):
|
| 373 |
+
value = r(factor)+1
|
| 374 |
+
kernel = np.ones((value, value), np.uint8) * factor
|
| 375 |
+
erosion = cv2.erode(img, kernel, iterations=1)
|
| 376 |
+
return erosion
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# dilation
|
| 380 |
+
def cv_dilate(img, factor):
|
| 381 |
+
value = r(factor)+1
|
| 382 |
+
kernel = np.ones((value, value), np.uint8)
|
| 383 |
+
dilate = cv2.dilate(img, kernel, iterations=1)
|
| 384 |
+
return dilate
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def add_random_noise(img, factor):
|
| 388 |
+
value = r(factor)
|
| 389 |
+
for k in range(value):  # create `value` noisy pixels
|
| 390 |
+
i = random.randint(0, img.shape[0] - 1)
|
| 391 |
+
j = random.randint(0, img.shape[1] - 1)
|
| 392 |
+
color = (random.randrange(256), random.randrange(256), random.randrange(256))
|
| 393 |
+
img[i, j] = color
|
| 394 |
+
return img
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# generate the motion-blur kernel and its anchor point
|
| 398 |
+
def gen_kernel_anchor(length, angle):
|
| 399 |
+
half = length / 2
|
| 400 |
+
EPS = np.finfo(float).eps
|
| 401 |
+
alpha = (angle - floor(angle / 180) * 180) / 180 * pi
|
| 402 |
+
cosalpha = cos(alpha)
|
| 403 |
+
sinalpha = sin(alpha)
|
| 404 |
+
if cosalpha < 0:
|
| 405 |
+
xsign = -1
|
| 406 |
+
elif angle == 90:
|
| 407 |
+
xsign = 0
|
| 408 |
+
else:
|
| 409 |
+
xsign = 1
|
| 410 |
+
psfwdt = 1
|
| 411 |
+
# 模糊核大小
|
| 412 |
+
sx = int(fabs(length * cosalpha + psfwdt * xsign - length * EPS))
|
| 413 |
+
sy = int(fabs(length * sinalpha + psfwdt - length * EPS))
|
| 414 |
+
psf1 = np.zeros((sy, sx))
|
| 415 |
+
# psf1是左上角的权值较大,越往右下角权值越小的核。
|
| 416 |
+
# 这时运动像是从右下角到左上角移动
|
| 417 |
+
for i in range(0, sy):
|
| 418 |
+
for j in range(0, sx):
|
| 419 |
+
psf1[i][j] = i * fabs(cosalpha) - j * sinalpha
|
| 420 |
+
rad = sqrt(i * i + j * j)
|
| 421 |
+
if rad >= half and fabs(psf1[i][j]) <= psfwdt:
|
| 422 |
+
temp = half - fabs((j + psf1[i][j] * sinalpha) / cosalpha)
|
| 423 |
+
psf1[i][j] = sqrt(psf1[i][j] * psf1[i][j] + temp * temp)
|
| 424 |
+
psf1[i][j] = psfwdt + EPS - fabs(psf1[i][j])
|
| 425 |
+
if psf1[i][j] < 0:
|
| 426 |
+
psf1[i][j] = 0
|
| 427 |
+
# motion toward the top-left: anchor at (0, 0)
|
| 428 |
+
anchor = (0, 0)
|
| 429 |
+
# motion toward the top-right: anchor in the top-right corner;
|
| 430 |
+
# also flip the kernel left-right so the weights grow toward the anchor
|
| 431 |
+
if 0 < angle < 90:
|
| 432 |
+
psf1 = np.fliplr(psf1)
|
| 433 |
+
anchor = (psf1.shape[1] - 1, 0)
|
| 434 |
+
elif -90 < angle < 0:  # likewise: motion toward the bottom-right
|
| 435 |
+
psf1 = np.flipud(psf1)
|
| 436 |
+
psf1 = np.fliplr(psf1)
|
| 437 |
+
anchor = (psf1.shape[1] - 1, psf1.shape[0] - 1)
|
| 438 |
+
elif angle < -90:  # likewise: motion toward the bottom-left
|
| 439 |
+
psf1 = np.flipud(psf1)
|
| 440 |
+
anchor = (0, psf1.shape[0] - 1)
|
| 441 |
+
psf1 = psf1 / psf1.sum()
|
| 442 |
+
return psf1, anchor
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def hsv_transform(img):
|
| 446 |
+
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
| 447 |
+
# hsv[:, :, 0] = hsv[:, :, 0] * (0.2 + np.random.random() * 0.8)
|
| 448 |
+
hsv[:, :, 1] = hsv[:, :, 1] * (0.5 + np.random.random()*0.5)
|
| 449 |
+
hsv[:, :, 2] = hsv[:, :, 2] * (0.5 + np.random.random()*0.5)
|
| 450 |
+
img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
|
| 451 |
+
return img
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def kth_diag_indices(array, k=0, n=-1):
|
| 455 |
+
"""
|
| 456 |
+
Indices of the k-th diagonal.
|
| 457 |
+
:param array: input 2-D matrix
|
| 458 |
+
:param k: k=0 main diagonal; k>0 shifts right; k<0 shifts left
|
| 459 |
+
:param n: n=-1 returns all elements on the diagonal, otherwise only the first n
|
| 460 |
+
:return: indices
|
| 461 |
+
"""
|
| 462 |
+
if n == -1:
|
| 463 |
+
rows, cols = np.diag_indices_from(array)
|
| 464 |
+
else:
|
| 465 |
+
rows, cols = np.diag_indices(n)
|
| 466 |
+
if k < 0:
|
| 467 |
+
return rows[-k:], cols[:k]
|
| 468 |
+
elif k > 0:
|
| 469 |
+
return rows[:-k], cols[k:]
|
| 470 |
+
else:
|
| 471 |
+
return rows, cols
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def slash_mask(mask, start_index, end_index, set_value=0.0, mod='l', length=-1):
|
| 475 |
+
"""
|
| 476 |
+
Build a mask with diagonal stripes.
|
| 477 |
+
:param mask:
|
| 478 |
+
:param start_index:
|
| 479 |
+
:param end_index:
|
| 480 |
+
:param set_value:
|
| 481 |
+
:param mod: 'l' for left-leaning stripes, 'r' for right-leaning
|
| 482 |
+
:param length: stripe length
|
| 483 |
+
:return:
|
| 484 |
+
"""
|
| 485 |
+
h, w = mask.shape[:2]
|
| 486 |
+
assert length <= min(h, w)
|
| 487 |
+
if mod == 'r':
|
| 488 |
+
mask = np.fliplr(mask)
|
| 489 |
+
for i in range(start_index, end_index+1):
|
| 490 |
+
mask[kth_diag_indices(mask, i, length)] = set_value
|
| 491 |
+
if mod == 'r':
|
| 492 |
+
mask = np.fliplr(mask)
|
| 493 |
+
return mask
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def line_mask(mask, start_index, end_index, set_value=0.0, mod='h', length=-1):
|
| 497 |
+
"""
|
| 498 |
+
Build a mask with straight stripes.
|
| 499 |
+
:param mask:
|
| 500 |
+
:param start_index:
|
| 501 |
+
:param end_index:
|
| 502 |
+
:param set_value:
|
| 503 |
+
:param mod: 'h' horizontal, 'v' vertical
|
| 504 |
+
:param length: stripe length
|
| 505 |
+
:return:
|
| 506 |
+
"""
|
| 507 |
+
h, w = mask.shape[:2]
|
| 508 |
+
if mod == 'h':
|
| 509 |
+
assert length <= w
|
| 510 |
+
assert 0 <= start_index < end_index < h
|
| 511 |
+
if length == -1 or length == w:
|
| 512 |
+
mask[start_index: end_index+1, :] = set_value
|
| 513 |
+
else:
|
| 514 |
+
left = random.randint(0, w-length-1)
|
| 515 |
+
right = left + length
|
| 516 |
+
mask[start_index: end_index + 1, left:right] = set_value
|
| 517 |
+
else:
|
| 518 |
+
assert length <= h
|
| 519 |
+
assert 0 <= start_index < end_index < w
|
| 520 |
+
if length == -1 or length == h:
|
| 521 |
+
mask[:, start_index: end_index + 1] = set_value
|
| 522 |
+
else:
|
| 523 |
+
top = random.randint(0, h - length - 1)
|
| 524 |
+
bottom = top + length
|
| 525 |
+
mask[top:bottom + 1, start_index: end_index] = set_value
|
| 526 |
+
return mask
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def read_binary_images(file_path):
|
| 530 |
+
img = np.fromfile(file_path, dtype=np.uint8)
|
| 531 |
+
img = np.reshape(img, (720, 1280, 3))
|
| 532 |
+
cv2.imshow('img', img)
|
| 533 |
+
cv2.waitKey(0)
|
| 534 |
+
print()
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
if __name__ == '__main__':
|
| 539 |
+
read_binary_images('/Users/didi/Desktop/座次/pic/bgr1641010923242.bgr')
|
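A quick sketch of the E(x^2) - E(x)^2 statistics helper above; the image paths are placeholders:

img_paths = ['imgs/a.jpg', 'imgs/b.jpg']  # placeholder image paths
means, stds = compute_mean_std(img_paths, workers=2)
print('RGB means:', means, 'RGB stds:', stds)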