diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..932d5c5c033bb233d03f65f72fa3b8d77d91f294 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+data/dms3/v1.0/dms3_mtl_v1.0.pth filter=lfs diff=lfs merge=lfs -text
+data/dms3/v1.0/dms3_mtl_v1.0.onnx filter=lfs diff=lfs merge=lfs -text
+head_detect/models/HeadDetectorv1.6.onnx filter=lfs diff=lfs merge=lfs -text
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..08db57f89a53f13c99eadbb5b70d3c2ee5d80aca
--- /dev/null
+++ b/app.py
@@ -0,0 +1,179 @@
+import os
+
+import gradio as gr
+
+from inference_mtl import inference_xyl
+from inference_video_mtl import inference_videos
+
+
+SAVE_DIR = "/nfs/volume-236-2/eval_model/model_res"
+# SAVE_DIR = "/Users/didi/Desktop/model_res"
+CURRENT_DIR = ''
+os.makedirs(SAVE_DIR, exist_ok=True)
+
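+# Shared state for the plot-browsing callbacks below (PLOT_DIR is set per video, TASK_IDX tracks the current task).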
+PLOT_DIR = ""
+TASKS = ["ems", "eye", 'mouth', 'glass', 'mask', 'smoke', 'phone', "eyelid_r", "eyelid_l", 'shift_x', 'shift_y', 'expand']
+TASK_IDX = 0
+
+
+def last_img():
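+    # Show the previous task's plot; wraps around at the front and returns None if the image is missing.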
+ global TASK_IDX
+ TASK_IDX -= 1
+ if TASK_IDX < 0:
+ TASK_IDX = len(TASKS) - 1
+ plt_path = f"{PLOT_DIR}/{TASKS[TASK_IDX]}.jpg"
+
+ if not os.path.exists(plt_path):
+ return
+
+ return plt_path
+
+
+def next_img():
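+    # Show the next task's plot; wraps around at the end and returns None if the image is missing.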
+ global TASK_IDX
+ TASK_IDX += 1
+ if TASK_IDX >= len(TASKS):
+ TASK_IDX = 0
+ plt_path = f"{PLOT_DIR}/{TASKS[TASK_IDX]}.jpg"
+
+ if not os.path.exists(plt_path):
+ return
+
+ return plt_path
+
+
+def inference_img(inp, engine):
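+    # Gradio passes images as RGB; flip the channel order to BGR for the OpenCV-based pipeline.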
+ inp = inp[:, :, ::-1]
+
+ if engine == "DMS3":
+ drawn = inference_xyl(inp, vis=False, return_drawn=True)
+ drawn = drawn[:, :, ::-1]
+ return drawn
+
+
+def inference_txt(inp, engine):
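+    # Same preprocessing as inference_img, but return the per-task text messages instead of the drawn frame.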
+ inp = inp[:, :, ::-1]
+
+ if engine == "DMS3":
+ msg_list = inference_xyl(inp, vis=False, return_drawn=False)[-1]
+ return "\n".join(msg_list)
+
+
+def inference_video(video, engine, device_type_, show_score_, syn_plot_):
+ global CURRENT_DIR
+ CURRENT_DIR = os.path.join(SAVE_DIR, engine)
+ os.makedirs(CURRENT_DIR, exist_ok=True)
+
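+    # Strip the last 8 characters of the uploaded file's stem before building output names.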
+ base_name = os.path.splitext(os.path.basename(video))[0][:-8]
+
+ video_info = []
+ if device_type_ == "B100":
+ resize = (1280, 720)
+ video_info.append("720")
+ elif device_type_ == "B200":
+ resize = (1280, 960)
+ video_info.append("960")
+ else:
+ resize = None
+
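+    # The radio options are Chinese: "是" = yes, "否" = no.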
+ is_syn_plot = syn_plot_ == "是"
+ is_show_score = show_score_ == "是"
+
+ video_info.append("1" if is_syn_plot else "0")
+ video_info.append(base_name + "_res.mp4")
+
+ save_video_name = "_".join(video_info)
+ save_plt_dir = base_name + "_plots"
+ save_csv_name = base_name + "_res.csv"
+
+ save_video_path = f"{CURRENT_DIR}/{save_video_name}"
+ save_plt_path = f"{CURRENT_DIR}/{save_plt_dir}"
+ save_csv_path = f"{CURRENT_DIR}/{save_csv_name}"
+
+ global PLOT_DIR, TASK_IDX
+ PLOT_DIR = save_plt_path
+ TASK_IDX = 0
+
+ if os.path.exists(save_video_path) and os.path.exists(save_plt_path):
+ if not is_show_score:
+ return save_video_path, f"{save_plt_path}/{TASKS[TASK_IDX]}.jpg", None
+ elif os.path.exists(save_csv_path):
+ return save_video_path, f"{save_plt_path}/{TASKS[TASK_IDX]}.jpg", save_csv_path
+
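+    # No cached result: run full inference and save the annotated video, per-task plots and (optionally) the score CSV.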
+ inference_videos(
+ video, save_dir=SAVE_DIR, detect_mode='second', frequency=0.2, plot_score=True, save_score=is_show_score,
+ syn_plot=is_syn_plot, save_vdo=True, save_img=False, continuous=False, show_res=False, resize=resize,
+ time_delta=1, save_vdo_path=save_video_path, save_plt_path=save_plt_path, save_csv_path=save_csv_path)
+
+ if not is_show_score:
+ return save_video_path, f"{save_plt_path}/{TASKS[TASK_IDX]}.jpg", None
+
+ return save_video_path, f"{save_plt_path}/{TASKS[TASK_IDX]}.jpg", save_csv_path
+
+
+def reset_state():
+ return None, None, None, None
+
+
+with gr.Blocks() as demo:
+ gr.HTML("""
+算法可视化Demo
+""")
+
+ with gr.Tab("Text"): # 标签页1
+ with gr.Row(): # 并行显示,可开多列
+ with gr.Column(): # 并列显示,可开多行
+ input_img1 = gr.Image(label="", show_label=False)
+ engine1 = gr.Dropdown(["DMS3", "EMS", "TURN2"], label="算法引擎", value="DMS3")
+ btn_txt = gr.Button(value="Submit", label="Submit Image", variant="primary")
+ btn_clear1 = gr.Button("Clear")
+
+ out1 = gr.Text(label="")
+
+        btn_txt.click(inference_txt, inputs=[input_img1, engine1], outputs=out1, show_progress=True)  # trigger inference
+ btn_clear1.click(reset_state, inputs=[], outputs=[input_img1, out1])
+
+ with gr.Tab("Image"): # 标签页1
+ with gr.Row(): # 并行显示,可开多列
+ with gr.Column(): # 并列显示,可开多行
+ input_img2 = gr.Image(label="", show_label=False)
+ engine2 = gr.Dropdown(["DMS3", "EMS", "TURN2"], label="算法引擎", value="DMS3")
+ btn_img = gr.Button(value="Submit", label="Submit Image", variant="primary")
+ btn_clear2 = gr.Button("Clear")
+
+ out2 = gr.Image(label="", show_label=False)
+
+        btn_img.click(inference_img, inputs=[input_img2, engine2], outputs=out2, show_progress=True)  # trigger inference
+ btn_clear2.click(reset_state, inputs=[], outputs=[input_img2, out2])
+
+ with gr.Tab("Video"): # 标签页2
+ with gr.Row(): # 并行显示,可开多列
+
+            with gr.Column():  # stack children vertically (multiple rows)
+ input_vdo = gr.Video(label="", show_label=False)
+ engine3 = gr.Dropdown(["DMS3", "EMS", "TURN2"], label="算法引擎", value="DMS3")
+                device_type = gr.Radio(["原始", "B100", "B200"], value="原始", label="Device Type")  # single-choice radio
+
+ with gr.Row():
+                    show_score = gr.Radio(["是", "否"], value="是", label="分数明细")  # single-choice radio
+                    syn_plot = gr.Radio(["是", "否"], value="是", label="Syn Plot")  # single-choice radio
+
+ btn_vdo = gr.Button(value="Submit", label="Submit Video", variant="primary")
+ btn_clear3 = gr.Button("Clear")
+
+ with gr.Column():
+ out3 = gr.PlayableVideo(label="", show_label=False)
+ out3_plt = gr.Image(label=f"{TASKS[TASK_IDX]}", show_label=False)
+
+ with gr.Row():
+ btn_before = gr.Button(value="上一张", label="before")
+ btn_next = gr.Button(value="下一张", label="next", variant="primary")
+
+ out3_df = gr.DataFrame()
+
+ btn_vdo.click(inference_video, inputs=[input_vdo, engine3, device_type, show_score, syn_plot],
+                  outputs=[out3, out3_plt, out3_df], show_progress=True)  # trigger inference
+ btn_clear3.click(reset_state, inputs=[], outputs=[input_vdo, out3, out3_plt, out3_df])
+
+ btn_before.click(last_img, inputs=[], outputs=out3_plt)
+ btn_next.click(next_img, inputs=[], outputs=out3_plt)
+
+demo.queue().launch(share=True, inbrowser=True, favicon_path="files/dms3_icon.png")
diff --git a/data/dms3/v1.0/dms3_mtl_v1.0.onnx b/data/dms3/v1.0/dms3_mtl_v1.0.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..a0552f079526f791891d40a764cbd6570e339a3c
--- /dev/null
+++ b/data/dms3/v1.0/dms3_mtl_v1.0.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f4f06fd029196afea2522b194b1ba292a10a31f2792ca5c3d12f35ffd5530f02
+size 2176811
diff --git a/data/dms3/v1.0/dms3_mtl_v1.0.pth b/data/dms3/v1.0/dms3_mtl_v1.0.pth
new file mode 100755
index 0000000000000000000000000000000000000000..18000b6b36f41b6a44b4a54b1216b1e3ffc59aef
--- /dev/null
+++ b/data/dms3/v1.0/dms3_mtl_v1.0.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:483c1b16abc39462272a35006cc92993d98c385f8d05cc2ad869ad5d041b17e7
+size 2444005
diff --git "a/data/dms3/v1.0/\345\216\273\346\216\211\345\215\261\351\231\251\345\212\250\344\275\234litehead.png" "b/data/dms3/v1.0/\345\216\273\346\216\211\345\215\261\351\231\251\345\212\250\344\275\234litehead.png"
new file mode 100644
index 0000000000000000000000000000000000000000..cfef3b46b60e49b4ef353b16e625c3aa4ef90b04
Binary files /dev/null and "b/data/dms3/v1.0/\345\216\273\346\216\211\345\215\261\351\231\251\345\212\250\344\275\234litehead.png" differ
diff --git a/files/dms3_icon.png b/files/dms3_icon.png
new file mode 100644
index 0000000000000000000000000000000000000000..6a94c8e8522908c3d376b2ce7b4b182c6a4e69c4
Binary files /dev/null and b/files/dms3_icon.png differ
diff --git a/files/frame_b.png b/files/frame_b.png
new file mode 100644
index 0000000000000000000000000000000000000000..58a6582b5fad935927d0399b3164bc141bbd3978
Binary files /dev/null and b/files/frame_b.png differ
diff --git a/files/info_box.png b/files/info_box.png
new file mode 100644
index 0000000000000000000000000000000000000000..cd973ec8659ca7e2aef1421bb00533acedfc9a80
Binary files /dev/null and b/files/info_box.png differ
diff --git a/head_detect/__pycache__/demo.cpython-38.pyc b/head_detect/__pycache__/demo.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e04f6497f7346b2527f05d600c2701c4fa54ebd
Binary files /dev/null and b/head_detect/__pycache__/demo.cpython-38.pyc differ
diff --git a/head_detect/demo.py b/head_detect/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfe233c0166102a0705c1ace419ef7253fcc3639
--- /dev/null
+++ b/head_detect/demo.py
@@ -0,0 +1,59 @@
+# -*- coding:utf-8 -*-
+import os
+import sys
+
+import cv2
+
+sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
+from head_detect.head_detector.head_detectorv4 import HeadDetector
+from head_detect.utils_quailty_assurance.utils_quailty_assurance import find_files, write_json
+from utils.images import crop_face_square_rate
+import numpy as np
+
+ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
+model_path = f"{ROOT_DIR}/models/HeadDetectorv1.6.onnx"
+
+headDetector = HeadDetector(onnx_path=model_path)
+
+
+def detect_driver_face(img, show_res=False, show_crop=False):
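+    # Detect the driver's head in the right half of the frame; returns ([x0, y0, x1, y1], score), or zeros if nothing is found.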
+ if isinstance(img, str):
+ img = cv2.imread(img)
+ else:
+ img = img.copy()
+    image_height, image_width = img.shape[:2]
+    # origin_image = cv2.resize(img, (1280, 720))
+    # origin_image = img[:, w//2:, :]
+    # short_side = min(image_height, image_width)
+    width_shift, height_shift = image_width // 2, 0
+    # cv2.imwrite("square.jpg", img[height_shift:, width_shift:, :])
+    bboxes = headDetector.run(img[:, width_shift:, :], get_largest=True)  # head detection, keep the largest face
+    if not bboxes:
+        return [0, 0, 0, 0], 0
+    box = [int(width_shift + bboxes[0][0]), int(height_shift + bboxes[0][1]), int(width_shift + bboxes[0][2]), int(height_shift + bboxes[0][3])]
+
+ # box[0], box[1] = max(0, box[0]), max(0, box[1])
+ # box[2], box[3] = min(w, box[2]), min(h, box[3])
+ score = bboxes[0][-1]
+
+ if (box[2] - box[0]) == 0 or (box[3] - box[1]) == 0:
+ return [0, 0, 0, 0], 0
+
+ # print(box, pred, score)
+ if show_res:
+ x0, y0, x1, y1 = box
+ print(box)
+ cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), thickness=2)
+ if show_crop:
+ _, crop_box = crop_face_square_rate(img, box, rate=-0.07)
+ cx0, cy0, cx1, cy1 = crop_box
+ print(crop_box)
+ cv2.rectangle(img, (cx0, cy0), (cx1, cy1), (0, 255, 0), thickness=2)
+ cv2.imshow('res', img)
+ cv2.waitKey(0)
+
+ return box, score
+
+
+if __name__ == "__main__":
+ detect_driver_face('../../test_visual/look_down.jpg', show_res=True, show_crop=True)
diff --git a/head_detect/demo_ori.py b/head_detect/demo_ori.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c1d7f49c0f3b61fd408c9bd95f5593580d24014
--- /dev/null
+++ b/head_detect/demo_ori.py
@@ -0,0 +1,58 @@
+# -*- coding:utf-8 -*-
+import os
+import sys
+
+import cv2
+
+sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
+from head_detect.head_detector.head_detectorv2 import HeadDetector
+from head_detect.utils_quailty_assurance.utils_quailty_assurance import find_files, write_json
+from utils.images import crop_face_square_rate
+import numpy as np
+
+ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
+model_path = f"{ROOT_DIR}/models/Mona_HeadDetector_v1_straght.onnx"
+
+headDetector = HeadDetector(onnx_path=model_path)
+
+
+def detect_driver_face(img, show_res=False, show_crop=False):
+ if isinstance(img, str):
+ img = cv2.imread(img)
+ else:
+ img = img.copy()
+ h, w = img.shape[:2]
+ # origin_image = cv2.resize(img, (1280, 720))
+ # origin_image = img[:, w//2:, :]
+    bboxes = headDetector.run(img[:, w//2:, :], get_largest=True)  # head detection, keep the largest face
+ if not bboxes:
+ return [0, 0, 0, 0], 0
+
+ box = [int(b) for b in bboxes[0][:4]]
+ box[0] += w // 2
+ box[2] += w // 2
+ # box[0], box[1] = max(0, box[0]), max(0, box[1])
+ # box[2], box[3] = min(w, box[2]), min(h, box[3])
+ score = bboxes[0][-1]
+
+ if (box[2] - box[0]) == 0 or (box[3] - box[1]) == 0:
+ return [0, 0, 0, 0], 0
+
+ # print(box, pred, score)
+ if show_res:
+ x0, y0, x1, y1 = box
+ print(box)
+ cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), thickness=2)
+ if show_crop:
+ _, crop_box = crop_face_square_rate(img, box, rate=-0.07)
+ cx0, cy0, cx1, cy1 = crop_box
+ print(crop_box)
+ cv2.rectangle(img, (cx0, cy0), (cx1, cy1), (0, 255, 0), thickness=2)
+ cv2.imshow('res', img)
+ cv2.waitKey(0)
+
+ return box, score
+
+
+if __name__ == "__main__":
+ detect_driver_face('../../test_visual/look_down.jpg', show_res=True, show_crop=True)
diff --git a/head_detect/head_detector/__pycache__/face_utils.cpython-38.pyc b/head_detect/head_detector/__pycache__/face_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bdac2773485433939fb5e0b08ff1ff76c2adff46
Binary files /dev/null and b/head_detect/head_detector/__pycache__/face_utils.cpython-38.pyc differ
diff --git a/head_detect/head_detector/__pycache__/head_detectorv4.cpython-38.pyc b/head_detect/head_detector/__pycache__/head_detectorv4.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b3fadfd682cb6d12b1158111e0e4f07e7ff32b3
Binary files /dev/null and b/head_detect/head_detector/__pycache__/head_detectorv4.cpython-38.pyc differ
diff --git a/head_detect/head_detector/face_detector.py b/head_detect/head_detector/face_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa6c2eda5d5154d4dc1a1710eee06e768f77e584
--- /dev/null
+++ b/head_detect/head_detector/face_detector.py
@@ -0,0 +1,258 @@
+# For 1280x720 input images
+import os
+import cv2
+import copy
+import onnxruntime
+import numpy as np
+from face_detector.face_utils import letterbox
+#from face_utils import letterbox
+
+class FaceDetector:
+ def __init__(self, onnx_path="models/smoke_phone_mosaic_v2.onnx"):
+
+ self.onnx_path = onnx_path
+ self.onnx_session = onnxruntime.InferenceSession(self.onnx_path)
+ self.input_name = self.get_input_name(self.onnx_session)
+ self.output_name = self.get_output_name(self.onnx_session)
+        self.multi_class = 3  # change 1: the extra outputs correspond to the four classes [normal, smoke, phone, drink]
+
+ def get_output_name(self, onnx_session):
+ output_name = []
+ for node in onnx_session.get_outputs():
+ output_name.append(node.name)
+ return output_name
+
+ def get_input_name(self, onnx_session):
+ input_name = []
+ for node in onnx_session.get_inputs():
+ input_name.append(node.name)
+ return input_name
+
+ def get_input_feed(self, input_name, image_tensor):
+
+ input_feed = {}
+ for name in input_name:
+ input_feed[name] = image_tensor
+ return input_feed
+
+ def after_process(self,pred):
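+        # Decode the raw YOLOv5-style outputs (boxes, 5 landmarks, class scores) into image-scale predictions.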
+        # Input 320x192; strides 8/16/32 give feature-map widths of 40, 20 and 10
+ stride = np.array([8., 16., 32.])
+ x=[pred[0],pred[1],pred[2]]
+        # ============ yolov5 parameters: start ============
+ nc=1
+ no=16 + self.multi_class
+ nl=3
+ na=3
+ #grid=[torch.zeros(1).to(device)] * nl
+ grid=[np.zeros(1)]*nl
+ anchor_grid=np.array([[[[[[ 4., 5.]]],
+ [[[ 8., 10.]]],
+ [[[ 13., 16.]]]]],
+ [[[[[ 23., 29.]]],
+ [[[ 43., 55.]]],
+ [[[ 73., 105.]]]]],
+ [[[[[146., 217.]]],
+ [[[231., 300.]]],
+ [[[335., 433.]]]]]])
+        # ============ yolov5-0.5 parameters: end ============
+ z = []
+ for i in range(len(x)):
+
+ bs,ny, nx = x[i].shape[0],x[i].shape[2] ,x[i].shape[3]
+
+ if grid[i].shape[2:4] != x[i].shape[2:4]:
+ grid[i] = self._make_grid(nx, ny)
+
+ y = np.full_like(x[i],0)
+
+ #y[..., [0,1,2,3,4,15]] = self.sigmoid_v(x[i][..., [0,1,2,3,4,15]])
+ y[..., [0,1,2,3,4,15,16,17,18]] = self.sigmoid_v(x[i][..., [0,1,2,3,4,15,16,17,18]])
+            # apply sigmoid to the face confidence and the dangerous-action confidences at the same time
+
+ y[..., 5:15] = x[i][..., 5:15]
+
+
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid[i]) * stride[i] # xy
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i] # wh
+
+ y[..., 5:7] = y[..., 5:7] * anchor_grid[i] + grid[i] * stride[i] # landmark x1 y1
+ y[..., 7:9] = y[..., 7:9] * anchor_grid[i] + grid[i] * stride[i]# landmark x2 y2
+ y[..., 9:11] = y[..., 9:11] * anchor_grid[i] + grid[i] * stride[i]# landmark x3 y3
+ y[..., 11:13] = y[..., 11:13] * anchor_grid[i] + grid[i] * stride[i]# landmark x4 y4
+ y[..., 13:15] = y[..., 13:15] * anchor_grid[i] + grid[i] * stride[i]# landmark x5 y5
+
+ z.append(y.reshape((bs, -1, no)))
+ return np.concatenate(z, 1)
+
+
+ def _make_grid(self, nx, ny):
+ yv, xv = np.meshgrid(np.arange(ny), np.arange(nx),indexing = 'ij')
+ return np.stack((xv, yv), 2).reshape((1, 1, ny, nx, 2)).astype(float)
+
+ def sigmoid_v(self, array):
+ return np.reciprocal(np.exp(-array) + 1.0)
+
+ def img_process(self,orgimg,long_side=320,stride_max=32):
+
+ #orgimg=cv2.imread(img_path)
+ img0 = copy.deepcopy(orgimg)
+ h0, w0 = orgimg.shape[:2] # orig hw
+ r = long_side/ max(h0, w0) # resize image to img_size
+ if r != 1: # always resize down, only resize up if training with augmentation
+ interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+ img0 = cv2.resize(img0, (int(w0 * r), int(h0 * r)), interpolation=interp)
+        img = letterbox(img0, new_shape=(192,320),auto=False)[0]  # auto=True: minimal rectangle; auto=False: fixed size
+ # cv2.imwrite("convert.jpg",img=img)
+ # Convert
+ img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB).transpose(2, 0, 1).copy() # BGR to RGB, to 3x416x416
+ img = img.astype("float32") # uint8 to fp16/32
+ img /= 255.0 # 0 - 255 to 0.0 - 1.0
+ img = img[np.newaxis,:]
+
+ return img,orgimg
+
+ def scale_coords(self,img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+ coords[:, [0, 2, 5, 7, 9, 11, 13]] -= pad[0] # x padding
+ coords[:, [1, 3, 6, 8, 10,12, 14]] -= pad[1] # y padding
+
+ coords[:, [0,1,2,3,5,6,7,8,9,10,11,12,13,14]] /= gain
+
+ return coords
+
+
+ def non_max_suppression(self, boxes,confs, iou_thres=0.6):
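+        # Greedy NMS: keep the highest-confidence box, drop boxes that overlap it above iou_thres, repeat.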
+
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = confs.flatten().argsort()[::-1]
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+ inds = np.where( ovr <= iou_thres)[0]
+ order = order[inds + 1]
+
+ return boxes[keep]
+
+ def nms(self, pred, conf_thres=0.1,iou_thres=0.5):
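+        # Filter by objectness, scale class scores by objectness, convert xywh to xyxy, then run NMS.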
+ xc = pred[..., 4] > conf_thres
+ pred = pred[xc]
+ pred[:, 15:] *= pred[:, 4:5]
+
+ # best class only
+ confs = np.amax(pred[:, 15:16], 1, keepdims=True)
+ pred[..., 0:4] = self.xywh2xyxy(pred[..., 0:4])
+ return self.non_max_suppression(pred, confs, iou_thres)
+
+ def xywh2xyxy(self, x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = np.zeros_like(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
+
+ def get_largest_face(self,pred):
+ """[获取图片中最大的人脸]
+
+ Args:
+ object ([dict]): [人脸数据]
+
+ Returns:
+ [int]: [最大人脸的坐标]
+ """
+ max_index = 0
+ max_value = 0
+ for index in range(len(pred)):
+ xmin,ymin,xmax,ymax = pred[index][:4]
+ w = xmax - xmin
+ h = ymax - ymin
+ if w*h > max_value:
+ max_value = w*h
+ max_index = index
+ return max_index
+
+ def run(self, ori_image,get_largest=True):
+ #detial_dict = {}
+
+        img,orgimg=self.img_process(ori_image,long_side=320)  # [1, 3, 192, 320]
+ #print(img.shape)
+ input_feed = self.get_input_feed(self.input_name, img)
+ pred = self.onnx_session.run(self.output_name, input_feed=input_feed)
+ #detial_dict["before_decoder"] = [i.tolist() for i in pred]
+        pred=self.after_process(pred)  # decode outputs (numpy port of the torch post-processing)
+ #detial_dict["after_decoder"] = copy.deepcopy(pred.tolist())
+
+ pred=self.nms(pred[0],0.3,0.5)
+ #detial_dict["after_nms"] = copy.deepcopy(pred.tolist())
+ pred=self.scale_coords(img.shape[2:], pred, orgimg.shape)
+ #detial_dict["after_nms"] = copy.deepcopy(pred.tolist())
+
+ if get_largest and pred.shape[0]!=0 :
+ pred_index = self.get_largest_face(pred)
+ pred = pred[[pred_index]]
+ bboxes = pred[:,[0,1,2,3,15]]
+ landmarks = pred[:,5:15]
+        # change: get the dangerous-action label and its confidence
+ multi_class = np.argmax(pred[:,16:],axis=1)
+ multi_conf = np.amax(pred[:, 16:], axis=1)
+ landmarks = np.reshape(landmarks,(-1,5,2))
+ return bboxes,landmarks,multi_class,multi_conf
+
+
+if __name__ == '__main__':
+ onnxmodel = FaceDetector()
+ image_path = "/tmp-data/QACode/yolov5_quality_assurance/ONNX_Inference/smoke_phone/03ajqj5uqb013bc5_1614529990_1614530009.jpg"
+ #image = np.ones(shape=[1,3,192,320], dtype=np.float32)
+ img = cv2.imread(image_path)
+ img_resize = cv2.resize(img,(320,180))
+ img_resize = cv2.copyMakeBorder(img_resize, 6, 6, 0, 0, cv2.BORDER_CONSTANT, cv2.BORDER_CONSTANT, value=(114, 114, 114))
+ #img_resize = cv2.cvtColor(img_resize,cv2.COLOR_BGR2RGB)
+    img_resize = cv2.cvtColor(img_resize,cv2.COLOR_BGR2RGB).transpose(2, 0, 1).copy() # BGR to RGB, and transpose (h, w, c) to (c, h, w)
+ img_resize = img_resize.astype("float32") # uint8 to fp16/32
+ img_resize /= 255.0 # 0 - 255 to 0.0 - 1.0
+    img_resize = img_resize[np.newaxis,:] # add batch dimension
+ detial_dict = {}
+ input_feed = onnxmodel.get_input_feed(onnxmodel.input_name, img_resize)
+ pred = onnxmodel.onnx_session.run(onnxmodel.output_name, input_feed=input_feed)
+ detial_dict["before_decoder"] = [i.tolist() for i in pred]
+    pred=onnxmodel.after_process(pred) # decode outputs (numpy port of the torch post-processing)
+ detial_dict["after_decoder"] = copy.deepcopy(pred.tolist())
+ print(pred)
+ write_json("test.json",detial_dict)
+ pred=onnxmodel.nms(pred[0],0.3,0.5)
+ #detial_dict["after_nms"] = copy.deepcopy(pred.tolist())
+ #pred=onnxmodel.scale_coords(img.shape[2:], pred, img.shape)
+ detial_dict["after_nms"] = copy.deepcopy(pred.tolist())
+ write_json("test.json",detial_dict)
+
+
+
+
+
+
+
+
+
+
+
diff --git a/head_detect/head_detector/face_utils.py b/head_detect/head_detector/face_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9460ca4eca5a1313ef4f81d9e4b0a2f85e557665
--- /dev/null
+++ b/head_detect/head_detector/face_utils.py
@@ -0,0 +1,49 @@
+import cv2
+import math
+import numpy as np
+
+
+def make_divisible(x, divisor):
+ # Returns x evenly divisible by divisor
+ return math.ceil(x / divisor) * divisor
+
+def check_img_size(img_size, s=32):
+ # Verify img_size is a multiple of stride s
+ new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
+ if new_size != img_size:
+ print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
+ return new_size
+
+
+
+def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
+ # Resize and pad image while meeting stride-multiple constraints
+ shape = im.shape[:2] # current shape [height, width]
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+ if auto: # minimum rectangle
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
+ elif scaleFill: # stretch
+ dw, dh = 0.0, 0.0
+ new_unpad = (new_shape[1], new_shape[0])
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+
+ if shape[::-1] != new_unpad: # resize
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ return im, ratio, (dw, dh)
\ No newline at end of file
diff --git a/head_detect/head_detector/head_detectorv2.py b/head_detect/head_detector/head_detectorv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..73febf2c8f612ae0b9442a459624762455f5c368
--- /dev/null
+++ b/head_detect/head_detector/head_detectorv2.py
@@ -0,0 +1,168 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   head_detectorv2.py
+@Time : 2022/06/23 18:51:01
+@Author : Xie WenZhen
+@Version : 1.0
+@Contact : xiewenzhen@didiglobal.com
+@Desc    :   [head detection, decode-stripped version v2]
+'''
+
+# here put the import lib
+import os
+import cv2
+import copy
+import onnxruntime
+import numpy as np
+from head_detect.head_detector.face_utils import letterbox
+
+class HeadDetector:
+ def __init__(self, onnx_path="models/Mona_HeadDetector_v1_straght.onnx"):
+ self.onnx_path = onnx_path
+ self.onnx_session = onnxruntime.InferenceSession(self.onnx_path)
+ self.input_name = self.get_input_name(self.onnx_session)
+ self.output_name = self.get_output_name(self.onnx_session)
+
+ def get_output_name(self, onnx_session):
+ output_name = []
+ for node in onnx_session.get_outputs():
+ output_name.append(node.name)
+ return output_name
+
+ def get_input_name(self, onnx_session):
+ input_name = []
+ for node in onnx_session.get_inputs():
+ input_name.append(node.name)
+ return input_name
+
+ def get_input_feed(self, input_name, image_tensor):
+
+ input_feed = {}
+ for name in input_name:
+ input_feed[name] = image_tensor
+ return input_feed
+
+ def img_process(self,orgimg,long_side=320,stride_max=32):
+
+ #orgimg=cv2.imread(img_path)
+ img0 = copy.deepcopy(orgimg)
+ h0, w0 = orgimg.shape[:2] # orig hw
+ r = long_side/ max(h0, w0) # resize image to img_size
+ if r != 1: # always resize down, only resize up if training with augmentation
+ # interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+ interp = cv2.INTER_LINEAR
+
+ img0 = cv2.resize(img0, (int(w0 * r), int(h0 * r)), interpolation=interp)
+        img = letterbox(img0, new_shape=(192,320),auto=False)[0]  # auto=True: minimal rectangle; auto=False: fixed size
+ # cv2.imwrite("convert.jpg",img=img)
+ img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB).transpose(2, 0, 1).copy() # BGR to RGB, to 3x416x416
+ img = img.astype("float32") # uint8 to fp16/32
+ img /= 255.0 # 0 - 255 to 0.0 - 1.0
+ img = img[np.newaxis,:]
+
+ return img,orgimg
+
+ def scale_coords(self,img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+ coords[:, [0, 2]] -= pad[0] # x padding
+ coords[:, [1, 3]] -= pad[1] # y padding
+
+ coords[:, [0,1,2,3]] /= gain
+
+ return coords
+
+
+ def non_max_suppression(self, boxes,confs, iou_thres=0.6):
+
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = confs.flatten().argsort()[::-1]
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+ inds = np.where( ovr <= iou_thres)[0]
+ order = order[inds + 1]
+
+ return boxes[keep]
+
+ def nms(self, pred, conf_thres=0.1,iou_thres=0.5):
+ xc = pred[..., 4] > conf_thres
+ pred = pred[xc]
+ #pred[:, 15:] *= pred[:, 4:5]
+
+ # best class only
+ confs = np.amax(pred[:, 4:5], 1, keepdims=True)
+ pred[..., 0:4] = self.xywh2xyxy(pred[..., 0:4])
+ return self.non_max_suppression(pred, confs, iou_thres)
+
+ def xywh2xyxy(self, x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = np.zeros_like(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
+
+ def get_largest_face(self,pred):
+ """[获取图片中最大的人脸]
+
+ Args:
+ object ([dict]): [人脸数据]
+
+ Returns:
+ [int]: [最大人脸的坐标]
+ """
+ max_index = 0
+ max_value = 0
+ for index in range(len(pred)):
+ xmin,ymin,xmax,ymax = pred[index][:4]
+ w = xmax - xmin
+ h = ymax - ymin
+ if w*h > max_value:
+ max_value = w*h
+ max_index = index
+ return max_index
+
+ def run(self, ori_image,get_largest=True):
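+        # Run the ONNX head detector and return the detected box(es) as [[x0, y0, x1, y1, conf], ...], largest-only by default.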
+        img,orgimg=self.img_process(ori_image,long_side=320)  # [1, 3, 192, 320]
+ #print(img.shape)
+ input_feed = self.get_input_feed(self.input_name, img)
+ pred = self.onnx_session.run(self.output_name, input_feed=input_feed)
+
+        # pred=self.after_process(pred)  # torch post-processing
+ pred=self.nms(pred[0],0.3,0.5)
+ pred=self.scale_coords(img.shape[2:], pred, orgimg.shape)
+ if get_largest and pred.shape[0]!=0 :
+ pred_index = self.get_largest_face(pred)
+ pred = pred[[pred_index]]
+ bboxes = pred[:,[0,1,2,3,4]]
+ return bboxes.tolist()
+
+
+
+
+
+
+
+
+
+
diff --git a/head_detect/head_detector/head_detectorv3.py b/head_detect/head_detector/head_detectorv3.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c8e8f4455b00313b966095c2f9dbedf657fe617
--- /dev/null
+++ b/head_detect/head_detector/head_detectorv3.py
@@ -0,0 +1,216 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   head_detectorv3.py
+@Time : 2022/06/23 18:51:01
+@Author : Xie WenZhen
+@Version : 1.0
+@Contact : xiewenzhen@didiglobal.com
+@Desc    :   [head detection, decode-stripped version]
+'''
+
+# here put the import lib
+import os
+import cv2
+import copy
+import onnxruntime
+import numpy as np
+from head_detect.head_detector.face_utils import letterbox
+
+class HeadDetector:
+ def __init__(self, onnx_path="models/head_models/HeadDetectorv1.1.onnx"):
+ self.onnx_path = onnx_path
+ self.onnx_session = onnxruntime.InferenceSession(self.onnx_path)
+ self.input_name = self.get_input_name(self.onnx_session)
+ self.output_name = self.get_output_name(self.onnx_session)
+
+ def get_output_name(self, onnx_session):
+ output_name = []
+ for node in onnx_session.get_outputs():
+ output_name.append(node.name)
+ return output_name
+
+ def get_input_name(self, onnx_session):
+ input_name = []
+ for node in onnx_session.get_inputs():
+ input_name.append(node.name)
+ return input_name
+
+ def get_input_feed(self, input_name, image_tensor):
+
+ input_feed = {}
+ for name in input_name:
+ input_feed[name] = image_tensor
+ return input_feed
+
+ def after_process(self,pred):
+        # Feature maps are downsampled by strides 8/16/32 relative to the letterboxed input
+ stride = np.array([8., 16., 32.])
+ x=[pred[0],pred[1],pred[2]]
+        # ============ yolov5 parameters: start ============
+ nl=3
+
+ #grid=[torch.zeros(1).to(device)] * nl
+ grid=[np.zeros(1)]*nl
+ anchor_grid=np.array([[[[[[ 4., 5.]]],
+ [[[ 8., 10.]]],
+ [[[ 13., 16.]]]]],
+ [[[[[ 23., 29.]]],
+ [[[ 43., 55.]]],
+ [[[ 73., 105.]]]]],
+ [[[[[146., 217.]]],
+ [[[231., 300.]]],
+ [[[335., 433.]]]]]])
+        # ============ yolov5-0.5 parameters: end ============
+ z = []
+ for i in range(len(x)):
+
+ ny, nx = x[i].shape[1],x[i].shape[2]
+ if grid[i].shape[2:4] != x[i].shape[2:4]:
+ grid[i] = self._make_grid(nx, ny)
+
+ y = np.full_like(x[i],0)
+
+ #y[..., [0,1,2,3,4,15]] = self.sigmoid_v(x[i][..., [0,1,2,3,4,15]])
+ y[..., [0,1,2,3,4]] = self.sigmoid_v(x[i][..., [0,1,2,3,4]])
+            # sigmoid over the box outputs and the face confidence
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid[i]) * stride[i] # xy
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i] # wh
+
+ z.append(y.reshape((1, -1, 6)))
+ return np.concatenate(z, 1)
+
+
+ def _make_grid(self, nx, ny):
+ yv, xv = np.meshgrid(np.arange(ny), np.arange(nx),indexing = 'ij')
+ return np.stack((xv, yv), 2).reshape((1, ny, nx, 2)).astype(float)
+
+ def sigmoid_v(self, array):
+ return np.reciprocal(np.exp(-array) + 1.0)
+
+ def img_process(self,orgimg,long_side=320,stride_max=32):
+
+ #orgimg=cv2.imread(img_path)
+ img0 = copy.deepcopy(orgimg)
+ h0, w0 = orgimg.shape[:2] # orig hw
+ r = long_side/ max(h0, w0) # resize image to img_size
+ if r != 1: # always resize down, only resize up if training with augmentation
+ # interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+ interp = cv2.INTER_LINEAR
+
+ img0 = cv2.resize(img0, (int(w0 * r), int(h0 * r)), interpolation=interp)
+        img = letterbox(img0, new_shape=(320,320),auto=False)[0]  # auto=True: minimal rectangle; auto=False: fixed size
+ # cv2.imwrite("convert1.jpg",img=img)
+ # Convert
+ img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB).transpose(2, 0, 1).copy() # BGR to RGB, to 3x416x416
+ img = img.astype("float32") # uint8 to fp16/32
+ img /= 255.0 # 0 - 255 to 0.0 - 1.0
+ img = img[np.newaxis,:]
+
+ return img,orgimg
+
+ def scale_coords(self,img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+ coords[:, [0, 2]] -= pad[0] # x padding
+ coords[:, [1, 3]] -= pad[1] # y padding
+
+ coords[:, [0,1,2,3]] /= gain
+
+ return coords
+
+
+ def non_max_suppression(self, boxes,confs, iou_thres=0.6):
+
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = confs.flatten().argsort()[::-1]
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+ inds = np.where( ovr <= iou_thres)[0]
+ order = order[inds + 1]
+
+ return boxes[keep]
+
+ def nms(self, pred, conf_thres=0.1,iou_thres=0.5):
+ xc = pred[..., 4] > conf_thres
+ pred = pred[xc]
+ #pred[:, 15:] *= pred[:, 4:5]
+
+ # best class only
+ confs = np.amax(pred[:, 4:5], 1, keepdims=True)
+ pred[..., 0:4] = self.xywh2xyxy(pred[..., 0:4])
+ return self.non_max_suppression(pred, confs, iou_thres)
+
+ def xywh2xyxy(self, x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = np.zeros_like(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
+
+ def get_largest_face(self,pred):
+ """[获取图片中最大的人脸]
+
+ Args:
+ object ([dict]): [人脸数据]
+
+ Returns:
+ [int]: [最大人脸的坐标]
+ """
+ max_index = 0
+ max_value = 0
+ for index in range(len(pred)):
+ xmin,ymin,xmax,ymax = pred[index][:4]
+ w = xmax - xmin
+ h = ymax - ymin
+ if w*h > max_value:
+ max_value = w*h
+ max_index = index
+ return max_index
+
+ def run(self, ori_image,get_largest=True):
+        img,orgimg=self.img_process(ori_image,long_side=320)  # [1, 3, 320, 320]
+ #print(img.shape)
+ input_feed = self.get_input_feed(self.input_name, img)
+ pred = self.onnx_session.run(self.output_name, input_feed=input_feed)
+        pred=self.after_process(pred)  # decode outputs (numpy port of the torch post-processing)
+ pred=self.nms(pred[0],0.3,0.5)
+ #detial_dict["after_nms"] = copy.deepcopy(pred.tolist())
+ pred=self.scale_coords(img.shape[2:], pred, orgimg.shape)
+ #detial_dict["after_nms"] = copy.deepcopy(pred.tolist())
+
+ if get_largest and pred.shape[0]!=0 :
+ pred_index = self.get_largest_face(pred)
+ pred = pred[[pred_index]]
+ bboxes = pred[:,[0,1,2,3,4]]
+ return bboxes.tolist()
+
+
+
+
+
+
+
+
+
+
diff --git a/head_detect/head_detector/head_detectorv4.py b/head_detect/head_detector/head_detectorv4.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b2a553d00a85c33e4797d3beeb22a93b35e29d3
--- /dev/null
+++ b/head_detect/head_detector/head_detectorv4.py
@@ -0,0 +1,216 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   head_detectorv4.py
+@Time : 2022/06/23 18:51:01
+@Author : Xie WenZhen
+@Version : 1.0
+@Contact : xiewenzhen@didiglobal.com
+@Desc    :   [head detection, decode-stripped version]
+'''
+
+# here put the import lib
+import os
+import cv2
+import copy
+import onnxruntime
+import numpy as np
+from head_detect.head_detector.face_utils import letterbox
+
+class HeadDetector:
+ def __init__(self, onnx_path="models/head_models/HeadDetectorv1.3.onnx"):
+ self.onnx_path = onnx_path
+ self.onnx_session = onnxruntime.InferenceSession(self.onnx_path)
+ self.input_name = self.get_input_name(self.onnx_session)
+ self.output_name = self.get_output_name(self.onnx_session)
+
+ def get_output_name(self, onnx_session):
+ output_name = []
+ for node in onnx_session.get_outputs():
+ output_name.append(node.name)
+ return output_name
+
+ def get_input_name(self, onnx_session):
+ input_name = []
+ for node in onnx_session.get_inputs():
+ input_name.append(node.name)
+ return input_name
+
+ def get_input_feed(self, input_name, image_tensor):
+
+ input_feed = {}
+ for name in input_name:
+ input_feed[name] = image_tensor
+ return input_feed
+
+ def after_process(self,pred):
+        # Feature maps are downsampled by strides 8/16/32 relative to the letterboxed input
+ stride = np.array([8., 16., 32.])
+ x=[pred[0],pred[1],pred[2]]
+        # ============ yolov5 parameters: start ============
+ nl=3
+
+ #grid=[torch.zeros(1).to(device)] * nl
+ grid=[np.zeros(1)]*nl
+ anchor_grid=np.array([[[[[[ 4., 5.]]],
+ [[[ 8., 10.]]],
+ [[[ 13., 16.]]]]],
+ [[[[[ 23., 29.]]],
+ [[[ 43., 55.]]],
+ [[[ 73., 105.]]]]],
+ [[[[[146., 217.]]],
+ [[[231., 300.]]],
+ [[[335., 433.]]]]]])
+        # ============ yolov5-0.5 parameters: end ============
+ z = []
+ for i in range(len(x)):
+
+ ny, nx = x[i].shape[1],x[i].shape[2]
+ if grid[i].shape[2:4] != x[i].shape[2:4]:
+ grid[i] = self._make_grid(nx, ny)
+
+ y = np.full_like(x[i],0)
+
+ #y[..., [0,1,2,3,4,15]] = self.sigmoid_v(x[i][..., [0,1,2,3,4,15]])
+ y[..., [0,1,2,3,4]] = self.sigmoid_v(x[i][..., [0,1,2,3,4]])
+            # sigmoid over the box outputs and the face confidence
+ y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid[i]) * stride[i] # xy
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i] # wh
+
+ z.append(y.reshape((1, -1, 6)))
+ return np.concatenate(z, 1)
+
+
+ def _make_grid(self, nx, ny):
+ yv, xv = np.meshgrid(np.arange(ny), np.arange(nx),indexing = 'ij')
+ return np.stack((xv, yv), 2).reshape((1, ny, nx, 2)).astype(float)
+
+ def sigmoid_v(self, array):
+ return np.reciprocal(np.exp(-array) + 1.0)
+
+ def img_process(self,orgimg,long_side=320,stride_max=32):
+
+ #orgimg=cv2.imread(img_path)
+ img0 = copy.deepcopy(orgimg)
+ h0, w0 = orgimg.shape[:2] # orig hw
+ r = long_side/ max(h0, w0) # resize image to img_size
+ if r != 1: # always resize down, only resize up if training with augmentation
+ # interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+ interp = cv2.INTER_LINEAR
+
+ img0 = cv2.resize(img0, (int(w0 * r), int(h0 * r)), interpolation=interp)
+        img = letterbox(img0, new_shape=(320,288),auto=False)[0]  # auto=True: minimal rectangle; auto=False: fixed size
+ # cv2.imwrite("convert1.jpg",img=img)
+ # Convert
+ img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB).transpose(2, 0, 1).copy() # BGR to RGB, to 3x416x416
+ img = img.astype("float32") # uint8 to fp16/32
+ img /= 255.0 # 0 - 255 to 0.0 - 1.0
+ img = img[np.newaxis,:]
+
+ return img,orgimg
+
+ def scale_coords(self,img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+ coords[:, [0, 2]] -= pad[0] # x padding
+ coords[:, [1, 3]] -= pad[1] # y padding
+
+ coords[:, [0,1,2,3]] /= gain
+
+ return coords
+
+
+ def non_max_suppression(self, boxes,confs, iou_thres=0.6):
+
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = confs.flatten().argsort()[::-1]
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+ w = np.maximum(0.0, xx2 - xx1 + 1)
+ h = np.maximum(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
+ inds = np.where( ovr <= iou_thres)[0]
+ order = order[inds + 1]
+
+ return boxes[keep]
+
+ def nms(self, pred, conf_thres=0.1,iou_thres=0.5):
+ xc = pred[..., 4] > conf_thres
+ pred = pred[xc]
+ #pred[:, 15:] *= pred[:, 4:5]
+
+ # best class only
+ confs = np.amax(pred[:, 4:5], 1, keepdims=True)
+ pred[..., 0:4] = self.xywh2xyxy(pred[..., 0:4])
+ return self.non_max_suppression(pred, confs, iou_thres)
+
+ def xywh2xyxy(self, x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = np.zeros_like(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
+
+ def get_largest_face(self,pred):
+ """[获取图片中最大的人脸]
+
+ Args:
+ object ([dict]): [人脸数据]
+
+ Returns:
+ [int]: [最大人脸的坐标]
+ """
+ max_index = 0
+ max_value = 0
+ for index in range(len(pred)):
+ xmin,ymin,xmax,ymax = pred[index][:4]
+ w = xmax - xmin
+ h = ymax - ymin
+ if w*h > max_value:
+ max_value = w*h
+ max_index = index
+ return max_index
+
+ def run(self, ori_image,get_largest=True):
+        img,orgimg=self.img_process(ori_image,long_side=320)  # [1, 3, 320, 288]
+ #print(img.shape)
+ input_feed = self.get_input_feed(self.input_name, img)
+ pred = self.onnx_session.run(self.output_name, input_feed=input_feed)
+        pred=self.after_process(pred)  # decode outputs (numpy port of the torch post-processing)
+ pred=self.nms(pred[0],0.3,0.5)
+ #detial_dict["after_nms"] = copy.deepcopy(pred.tolist())
+ pred=self.scale_coords(img.shape[2:], pred, orgimg.shape)
+ #detial_dict["after_nms"] = copy.deepcopy(pred.tolist())
+
+ if get_largest and pred.shape[0]!=0 :
+ pred_index = self.get_largest_face(pred)
+ pred = pred[[pred_index]]
+ bboxes = pred[:,[0,1,2,3,4]]
+ return bboxes.tolist()
+
+
+
+
+
+
+
+
+
+
diff --git a/head_detect/head_detector/pose.py b/head_detect/head_detector/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae8564c166de2e9e4b03db718167e7719c77bb7c
--- /dev/null
+++ b/head_detect/head_detector/pose.py
@@ -0,0 +1,55 @@
+
+import numpy as np
+import math
+
+class Pose:
+ def __init__(self):
+
+ pt3d = np.zeros((3, 5))
+ pt3d[0, :] = [-0.3207, 0.3101, -0.0011, -0.2578, 0.2460]
+ pt3d[1, :] = [0.2629, 0.2631, -0.0800, -0.4123, -0.4127]
+ pt3d[2, :] = [0.9560, 0.9519, 1.3194, 0.9921, 0.9899]
+ self.pt3d = pt3d * 1e5
+
+ def __call__(self, pt2d):
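+        # Estimate (pitch, yaw, roll) in degrees from 5 facial landmarks using a weak-perspective model.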
+
+ #pt2d = np.asarray(pt2d, np.float)
+ pt2d = np.reshape(pt2d, (5, 2)).transpose()
+ pt3d = self.pt3d
+        # Weak-perspective projection, following the paper "Optimum Fiducials Under Weak Perspective Projection"
+        # Subtract the means to remove the translation t, which makes R easier to solve for
+ pt2dm = np.zeros(pt2d.shape)
+ pt3dm = np.zeros(pt3d.shape)
+ pt2dm[0, :] = pt2d[0, :] - np.mean(pt2d[0, :])
+ pt2dm[1, :] = pt2d[1, :] - np.mean(pt2d[1, :])
+ pt3dm[0, :] = pt3d[0, :] - np.mean(pt3d[0, :])
+ pt3dm[1, :] = pt3d[1, :] - np.mean(pt3d[1, :])
+ pt3dm[2, :] = pt3d[2, :] - np.mean(pt3d[2, :])
+        # Solve for R by least squares
+ R1 = np.dot(np.dot(np.mat(np.dot(pt3dm, pt3dm.T)).I, pt3dm), pt2dm[0, :].T)
+ R2 = np.dot(np.dot(np.mat(np.dot(pt3dm, pt3dm.T)).I, pt3dm), pt2dm[1, :].T)
+        # Compute the focal scale f
+ f = (math.sqrt(R1[0, 0] ** 2 + R1[0, 1] ** 2 + R1[0, 2] ** 2) + math.sqrt(
+ R2[0, 0] ** 2 + R2[0, 1] ** 2 + R2[0, 2] ** 2)) / 2
+ R1 = R1 / f
+
+ R2 = R2 / f
+ R3 = np.cross(R1, R2)
+        # Recover the rotation angles from the rotation matrix R
+ phi = math.atan(R2[0, 2] / R3[0, 2])
+ gamma = math.atan(-R1[0, 2] / (math.sqrt(R1[0, 0] ** 2 + R1[0, 1] ** 2)))
+ theta = math.atan(R1[0, 1] / R1[0, 0])
+
+        # Recompute the rotation-translation matrix with R to solve for t
+ pt3d = np.row_stack((pt3d, np.ones((1, pt3d.shape[1]))))
+ R1_orig = np.dot(np.dot(np.mat(np.dot(pt3d, pt3d.T)).I, pt3d), pt2d[0, :].T)
+ R2_orig = np.dot(np.dot(np.mat(np.dot(pt3d, pt3d.T)).I, pt3d), pt2d[1, :].T)
+
+ t3d = np.array([R1_orig[0, 3], R2_orig[0, 3], 0]).reshape((3, 1))
+ pitch = phi * 180 / np.pi
+ yaw = gamma * 180 / np.pi
+ roll = theta * 180 / np.pi
+
+ return pitch, yaw, roll
+
+
diff --git a/head_detect/models/HeadDetectorv1.6.onnx b/head_detect/models/HeadDetectorv1.6.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..c4f4ba7239e8caedeb31e6bbf2837cb01e1f6885
--- /dev/null
+++ b/head_detect/models/HeadDetectorv1.6.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fb959a1126c460b7becda05ba121c6ebd0b2a2fcc116e14af065aec69430068
+size 1227728
diff --git a/head_detect/utils_quailty_assurance/__pycache__/utils_quailty_assurance.cpython-38.pyc b/head_detect/utils_quailty_assurance/__pycache__/utils_quailty_assurance.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f9fda8a0970afcfbee44cea5c87fa2347a2fb12
Binary files /dev/null and b/head_detect/utils_quailty_assurance/__pycache__/utils_quailty_assurance.cpython-38.pyc differ
diff --git a/head_detect/utils_quailty_assurance/draw_tools.py b/head_detect/utils_quailty_assurance/draw_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c41b8f8832f245694b83e429f8fe5c483ef97ca
--- /dev/null
+++ b/head_detect/utils_quailty_assurance/draw_tools.py
@@ -0,0 +1,253 @@
+import os
+import cv2
+import math
+import numpy as np
+import matplotlib.pyplot as plt
+
+def show_results(img, bboxe,landmark,mask_label=None,emotion_label=None):
+ h,w,c = img.shape
+ tl = 1 or round(0.002 * (h + w) / 2) + 1 # line/font thickness
+ x1,y1,x2,y2,confidence = bboxe
+ cv2.rectangle(img, (int(x1),int(y1)), (int(x2), int(y2)), (0,255,0), thickness=tl, lineType=cv2.LINE_AA)
+ clors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)]
+
+ for i in range(5):
+ point_x = int(landmark[i][0])
+ point_y = int(landmark[i][1])
+ cv2.circle(img, (point_x, point_y), tl+1, clors[i], -1)
+
+ tf = max(tl - 1, 1) # font thickness
+ label = str(confidence)[:5]
+ cv2.putText(img, label, (int(x2), int(y2) - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+ if mask_label!=None:
+ labels_dict = {0: 'Mask', 1: 'NoMask'}
+ color_dict = {0:[0,0,255],1:[255,255,0]}
+ cv2.putText(img, labels_dict[mask_label], (int(x1), int(y1)-5), 0, tl, color_dict[mask_label], thickness=tf, lineType=cv2.LINE_AA)
+ if emotion_label!=None:
+ emotion_str = "smile:{:.2f}".format(emotion_label)
+ cv2.putText(img, emotion_str, (int(x1), int(y1)-30), 0, tl, [255,153,255], thickness=tf, lineType=cv2.LINE_AA)
+
+ return img
+
+def draw_bboxes_landmarks(img, bboxe,landmark,multi_label=None,multi_conf=None):
+ h,w,c = img.shape
+ tl = 1 or round(0.002 * (h + w) / 2) + 1 # line/font thickness
+ x1,y1,x2,y2,confidence = bboxe
+ cv2.rectangle(img, (int(x1),int(y1)), (int(x2), int(y2)), (0,255,0), thickness=tl, lineType=cv2.LINE_AA)
+ clors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)]
+
+ for i in range(5):
+ point_x = int(landmark[i][0])
+ point_y = int(landmark[i][1])
+ cv2.circle(img, (point_x, point_y), tl+1, clors[i], -1)
+
+ tf = max(tl - 1, 1) # font thickness
+ label = f"face:{confidence:.3f}"
+ cv2.putText(img, label, (int(x1), int(y1) - 2), 0, tl / 2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+ if multi_label!=None:
+ labels_dict = {0:"normal",1:"smoke",2:"phone",3:"drink"}
+ multi_str = f"{labels_dict[multi_label]}:{multi_conf:.3f}"
+ if y1>10:
+ cv2.putText(img, multi_str, (int(x1), int(y1)-17), 0, tl/2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+ else:
+ cv2.putText(img, multi_str, (int(x1), int(y2)+5), 0, tl/2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+ return img
+
+def draw_bboxes(img, bboxe,multi_label=None,multi_conf=None):
+ h,w,c = img.shape
+ tl = 1 or round(0.002 * (h + w) / 2) + 1 # line/font thickness
+ x1,y1,x2,y2,confidence = bboxe
+ cv2.rectangle(img, (int(x1),int(y1)), (int(x2), int(y2)), (0,255,0), thickness=tl, lineType=cv2.LINE_AA)
+ tf = max(tl - 1, 1) # font thickness
+ label = f"face:{confidence:.3f}"
+ cv2.putText(img, label, (int(x1), int(y1) - 2), 0, tl / 2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+ if multi_label!=None:
+ labels_dict = {0:"normal",1:"smoke",2:"phone",3:"drink"}
+ multi_str = f"{labels_dict[multi_label]}:{multi_conf:.3f}"
+ if y1>10:
+ cv2.putText(img, multi_str, (int(x1), int(y1)-17), 0, tl/2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+ else:
+ cv2.putText(img, multi_str, (int(x1), int(y2)+5), 0, tl/2, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+ return img
+
+def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size=80):
+ height, width = img.shape[:2]
+ tl = 1 or round(0.002 * (height + width) / 2) + 1 # line/font thickness
+ pitch = pitch * np.pi / 180
+ yaw = -(yaw * np.pi / 180)
+ roll = roll * np.pi / 180
+
+ if tdx != None and tdy != None:
+ tdx = tdx
+ tdy = tdy
+ else:
+ tdx = width / 2
+ tdy = height / 2
+
+ # X-Axis pointing to right. drawn in red
+ x1 = size * (math.cos(yaw) * math.cos(roll)) + tdx
+ y1 = size * (math.cos(pitch) * math.sin(roll) + math.cos(roll)
+ * math.sin(pitch) * math.sin(yaw)) + tdy
+
+ # Y-Axis | drawn in green
+ # v
+ x2 = size * (-math.cos(yaw) * math.sin(roll)) + tdx
+ y2 = size * (math.cos(pitch) * math.cos(roll) - math.sin(pitch)
+ * math.sin(yaw) * math.sin(roll)) + tdy
+
+ # Z-Axis (out of the screen) drawn in blue
+ x3 = size * (math.sin(yaw)) + tdx
+ y3 = size * (-math.cos(yaw) * math.sin(pitch)) + tdy
+
+ cv2.line(img, (int(tdx), int(tdy)), (int(x1), int(y1)), (0, 0, 255), 3)
+ cv2.line(img, (int(tdx), int(tdy)), (int(x2), int(y2)), (0, 255, 0), 3)
+ cv2.line(img, (int(tdx), int(tdy)), (int(x3), int(y3)), (255, 0, 0), 2)
+
+ return img
+
+def split_num(sort_lst):
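+    # Group a sorted list of indices into runs of consecutive numbers.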
+ if not sort_lst:
+ return []
+ len_lst = len(sort_lst)
+ i = 0
+ split_lst = []
+ tmp_lst = [sort_lst[i]]
+ while True:
+ if i + 1 == len_lst:
+ break
+ next_n = sort_lst[i+1]
+ if sort_lst[i] + 1 == next_n:
+ tmp_lst.append(next_n)
+ else:
+ split_lst.append(tmp_lst)
+ tmp_lst = [next_n]
+ i += 1
+ split_lst.append(tmp_lst)
+ return split_lst
+def expand_Scope(nums, length):
+    # Widen a run of indices by one on each side; at a boundary, take two extra indices on the opposite side.
+    if nums[0] == 0:
+        start = nums[0]
+        end = nums[-1] + 2
+    elif nums[-1] == length - 1:
+        start = nums[0] - 2
+        end = nums[-1]
+    else:
+        start = nums[0] - 1
+        end = nums[-1] + 1
+    return start, end
+def draw_plot(xs, ys=None, dt=None, title ="Orign Yaw angle",c='C0', label='Yaw', time=None,y_line=None,std=None,outline_dict=None,**kwargs):
+ """ plot result of KF with color `c`, optionally displaying the variance
+ of `xs`. Returns the list of lines generated by plt.plot()"""
+
+ if ys is None and dt is not None:
+ ys = xs
+ xs = np.arange(0, len(ys) * dt, dt)
+ if ys is None:
+ ys = xs
+ xs = range(len(ys))
+
+ lines = plt.title(title)
+ lines = plt.plot(xs, ys, color=c, label=label, **kwargs)
+ if time is None:
+ return lines
+ x0 = time*dt
+ y0 = ys[time]
+ lines = plt.scatter(time*dt,ys[time],s=20,color='b')
+
+ lines = plt.plot([x0,x0],[y0,0],'r-',lw=2)
+ if y_line is None:
+ return lines
+
+ lines = plt.axhline(y=y_line, color='g', linestyle='--')
+ y_line_list = np.full(len(ys),y_line)
+ std_top = y_line_list+std
+ std_btm = y_line_list-std
+ plt.plot(xs, std_top, linestyle=':', color='k', lw=2)
+ plt.plot(xs, std_btm, linestyle=':', color='k', lw=2)
+ up_outline_lst = outline_dict["up"]
+ down_outline_lst = outline_dict["down"]
+
+ for nums in up_outline_lst:
+ start,end = expand_Scope(nums,len(ys))
+ plt.fill_between(xs[start:end+1], std_btm[start:end+1], np.array(ys)[start:end+1], facecolor='red', alpha=0.3)
+ for nums in down_outline_lst:
+ start,end = expand_Scope(nums,len(ys))
+ plt.fill_between(xs[start:end+1], np.array(ys)[start:end+1],std_top[start:end+1], facecolor='red', alpha=0.3)
+
+ plt.fill_between(xs, std_btm, std_top,
+ facecolor='yellow', alpha=0.2)
+
+ return lines
+
+def draw_plot_old(xs, ys=None, dt=None, title ="Orign Yaw angle",c='C0', label='Yaw', time=None,y_line=None,std=None,outline_dict=None,**kwargs):
+ """ plot result of KF with color `c`, optionally displaying the variance
+ of `xs`. Returns the list of lines generated by plt.plot()"""
+
+ if ys is None and dt is not None:
+ ys = xs
+ xs = np.arange(0, len(ys) * dt, dt)
+ if ys is None:
+ ys = xs
+ xs = range(len(ys))
+
+ lines = plt.title(title)
+ lines = plt.plot(xs, ys, color=c, label=label, **kwargs)
+ if time is None:
+ return lines
+ x0 = time*dt
+ y0 = ys[time]
+ lines = plt.scatter(time*dt,ys[time],s=20,color='b')
+
+ lines = plt.plot([x0,x0],[y0,0],'r-',lw=2)
+ if y_line is None:
+ return lines
+
+ lines = plt.axhline(y=y_line, color='g', linestyle='--')
+ y_line_list = np.full(len(ys),y_line)
+ std_top = y_line_list+std[0]
+ std_btm = y_line_list-std[1]
+ plt.plot(xs, std_top, linestyle=':', color='k', lw=2)
+ plt.plot(xs, std_btm, linestyle=':', color='k', lw=2)
+ up_outline_lst = outline_dict["up"]
+ down_outline_lst = outline_dict["down"]
+
+ for nums in up_outline_lst:
+ start,end = expand_Scope(nums,len(ys))
+ plt.fill_between(xs[start:end+1], std_btm[start:end+1], np.array(ys)[start:end+1], facecolor='red', alpha=0.3)
+ for nums in down_outline_lst:
+ start,end = expand_Scope(nums,len(ys))
+ plt.fill_between(xs[start:end+1], np.array(ys)[start:end+1],std_top[start:end+1], facecolor='red', alpha=0.3)
+
+ plt.fill_between(xs, std_btm, std_top,
+ facecolor='yellow', alpha=0.2)
+
+ return lines
+def draw_sticker(src, offset, pupils, landmarks,
+ blink_thd=0.22,
+ arrow_color=(0, 125, 255), copy=False):
+ if copy:
+ src = src.copy()
+
+ left_eye_hight = landmarks[33, 1] - landmarks[40, 1]
+ left_eye_width = landmarks[39, 0] - landmarks[35, 0]
+
+ right_eye_hight = landmarks[87, 1] - landmarks[94, 1]
+ right_eye_width = landmarks[93, 0] - landmarks[89, 0]
+
+ # for mark in landmarks.reshape(-1, 2).astype(int):
+ # cv2.circle(src, tuple(mark), radius=1,
+ # color=(0, 0, 255), thickness=-1)
+
+ if left_eye_hight / left_eye_width > blink_thd:
+ cv2.arrowedLine(src, tuple(pupils[0].astype(int)),
+ tuple((offset+pupils[0]).astype(int)), arrow_color, 2)
+
+ if right_eye_hight / right_eye_width > blink_thd:
+ cv2.arrowedLine(src, tuple(pupils[1].astype(int)),
+ tuple((offset+pupils[1]).astype(int)), arrow_color, 2)
+
+ return src
+
diff --git a/head_detect/utils_quailty_assurance/metrics.py b/head_detect/utils_quailty_assurance/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b529a0b820fc7f3d426b982354891d5fb42a563
--- /dev/null
+++ b/head_detect/utils_quailty_assurance/metrics.py
@@ -0,0 +1,123 @@
+# Model validation metrics
+from pathlib import Path
+import matplotlib.pyplot as plt
+import numpy as np
+
+def getLabel2idx(labels):
+ label2idx = dict()
+ for i in labels:
+ if i not in label2idx:
+ label2idx[i] = len(label2idx)
+ return label2idx
+
+def calculate_all_prediction(confMatrix):
+ '''
+    Overall accuracy: the sum of the diagonal entries divided by the total count.
+ '''
+ total_sum = confMatrix.sum()
+ correct_sum = (np.diag(confMatrix)).sum()
+ prediction = round(100*float(correct_sum)/float(total_sum),2)
+ return prediction
+
+def calculate_label_prediction(confMatrix,labelidx):
+ '''
+    Per-class precision: the number of correct predictions for the class divided by the column total for that class.
+ '''
+ label_total_sum = confMatrix.sum(axis=0)[labelidx]
+ label_correct_sum = confMatrix[labelidx][labelidx]
+ prediction = 0
+ if label_total_sum != 0:
+ prediction = round(100*float(label_correct_sum)/float(label_total_sum),2)
+ return prediction
+
+def calculate_label_recall(confMatrix,labelidx):
+ '''
+    Per-class recall: the number of correct predictions for the class divided by the row total for that class.
+ '''
+ label_total_sum = confMatrix.sum(axis=1)[labelidx]
+ label_correct_sum = confMatrix[labelidx][labelidx]
+ recall = 0
+ if label_total_sum != 0:
+ recall = round(100*float(label_correct_sum)/float(label_total_sum),2)
+ return recall
+
+def calculate_f1(prediction,recall):
+ if (prediction+recall)==0:
+ return 0
+ return round(2*prediction*recall/(prediction+recall),2)
+
+def ap_per_class(tp, conf, pred_cls, target_cls):
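+    # Compute per-class precision, recall, AP and F1 from TP flags sorted by detection confidence.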
+ # tp = np.squeeze(tp,axis = 1)
+ conf = np.squeeze(conf,axis=1)
+ pred_cls = np.squeeze(pred_cls,axis=1)
+ # target_cls = target_cls.reshape((-1,1))
+
+ # Sort by objectness
+ i = np.argsort(-conf)
+ tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
+ #print(f"labels face:{len(target_cls)},pred face:{len(pred_cls)}")
+
+ # Find unique classes
+ unique_classes = np.unique(target_cls)
+
+ # Create Precision-Recall curve and compute AP for each class
+ px, py = np.linspace(0, 1, 1000), [] # for plotting
+ pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
+ s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
+ ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
+ for ci, c in enumerate(unique_classes):
+ i = pred_cls == c
+ n_l = (target_cls == c).sum() # number of labels
+ n_p = i.sum() # number of predictions
+
+ if n_p == 0 or n_l == 0:
+ continue
+ else:
+ # Accumulate FPs and TPs
+ fpc = (1 - tp[i]).cumsum(0)
+ tpc = tp[i].cumsum(0)
+
+ # Recall
+ recall = tpc / (n_l + 1e-16) # recall curve
+ r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0]) # r at pr_score, negative x, xp because xp decreases
+
+ # Precision
+ precision = tpc / (tpc + fpc) # precision curve
+ p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # p at pr_score
+
+ # AP from recall-precision curve
+ for j in range(tp.shape[1]):
+ ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
+ # Compute F1 score (harmonic mean of precision and recall)
+ f1 = 2 * p * r / (p + r + 1e-16)
+ return p, r, ap, f1, unique_classes.astype('int32')
+
+
+def compute_ap(recall, precision):
+ """ Compute the average precision, given the recall and precision curves
+ # Arguments
+ recall: The recall curve (list)
+ precision: The precision curve (list)
+ # Returns
+ Average precision, precision curve, recall curve
+ """
+
+ # Append sentinel values to beginning and end
+ mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01]))
+ mpre = np.concatenate(([1.], precision, [0.]))
+
+ # Compute the precision envelope
+ mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
+
+ # Integrate area under curve
+ method = 'interp' # methods: 'continuous', 'interp'
+ if method == 'interp':
+ x = np.linspace(0, 1, 101) # 101-point interp (COCO)
+ ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
+ else: # 'continuous'
+ i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
+
+ return ap, mpre, mrec
+
+
diff --git a/head_detect/utils_quailty_assurance/result_to_coco.py b/head_detect/utils_quailty_assurance/result_to_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d8a38966a778ef28a3611e547422d2cba37b0e7
--- /dev/null
+++ b/head_detect/utils_quailty_assurance/result_to_coco.py
@@ -0,0 +1,108 @@
+# -*- encoding: utf-8 -*-
+
+
+'''
+@File : result_to_coco.py
+@Time : 2022/04/27 15:54:13
+@Author : Xie WenZhen
+@Version : 1.0
+@Contact : xiewenzhen@didiglobal.com
+@Desc : None
+'''
+
+# here put the import lib
+import os
+import shutil
+from tqdm import tqdm
+from copy import deepcopy
+
+def parse_label_resultv1(label_path,img_width,img_higth):
+ with open(label_path, 'r') as fr:
+ labelList = fr.readlines()
+ face_list = []
+ for label in labelList:
+ label = label.strip().split()
+ x = float(label[1])
+ y = float(label[2])
+ w = float(label[3])
+ h = float(label[4])
+ x1 = (x - w / 2) * img_width
+ y1 = (y - h / 2) * img_higth
+ x2 = (x + w / 2) * img_width
+ y2 = (y + h / 2) * img_higth
+ face_list.append(deepcopy([x1,y1,x2-x1,y2-y1]))
+ return face_list
+
+def parse_label_resultv2(label_path,img_width,img_higth):
+ with open(label_path, 'r') as fr:
+ labelList = fr.readlines()
+ face_list = []
+ mark_list = []
+ category_id_list = []
+ for label in labelList:
+ label = label.strip().split()
+ x = float(label[1])
+ y = float(label[2])
+ w = float(label[3])
+ h = float(label[4])
+ x1 = (x - w / 2) * img_width
+ y1 = (y - h / 2) * img_higth
+ x2 = (x + w / 2) * img_width
+ y2 = (y + h / 2) * img_higth
+ face_list.append(deepcopy([x1,y1,x2-x1,y2-y1]))
+ ######
+ mx0_ = float(label[5]) * img_width
+ my0_ = float(label[6]) * img_higth
+ mx1_ = float(label[7]) * img_width
+ my1_ = float(label[8]) * img_higth
+ mx2_ = float(label[9]) * img_width
+ my2_ = float(label[10]) * img_higth
+ mx3_ = float(label[11]) * img_width
+ my3_ = float(label[12]) * img_higth
+ mx4_ = float(label[13]) * img_width
+ my4_ = float(label[14]) * img_higth
+ mark_list.append(deepcopy([mx0_,my0_,mx1_,my1_,mx2_,my2_,mx3_,my3_,mx4_,my4_]))
+ #####
+ category_id = int(label[15])
+ category_id_list.append(deepcopy(category_id))
+
+ return face_list,mark_list,category_id_list
+def generate_coco_labels(bbox,img_height,img_width,keypoint,filename,category_id):
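+    # Assemble a minimal COCO-style record: a single image entry plus a single annotation
+    # with the bbox in [x, y, w, h] and the flattened keypoint list.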
+ images_info = {}
+ images_info["file_name"] = filename
+ images_info['id'] = 0
+ images_info['height'] = img_height
+ images_info['width'] = img_width
+
+
+ anno = {}
+ anno['keypoints'] = keypoint
+ anno['image_id'] = 0
+ anno['id'] = 0
+ anno['num_keypoints'] = 13 # all keypoints are labelled
+ anno['bbox'] = bbox
+ anno['iscrowd'] = 0
+ anno['area'] = anno['bbox'][2] * anno['bbox'][3]
+ anno['category_id'] = category_id
+ final_output = {"images":images_info,
+ "annotations":anno}
+ return final_output
+def get_largest_face(face_dict_list):
+ if len(face_dict_list)==1:
+ return face_dict_list[0]["bbox"],face_dict_list[0]["label"]
+ max_id = 0
+ max_area = 0
+ for idx,face_dict in enumerate(face_dict_list):
+ recent_area = face_dict["bbox"][2] * face_dict["bbox"][3]
+ if recent_area>max_area:
+ max_id = idx
+ max_area = recent_area
+ return face_dict_list[max_id]["bbox"],face_dict_list[max_id]["label"]
+
+def main():
+ print("Hello, World!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/head_detect/utils_quailty_assurance/utils_quailty_assurance.py b/head_detect/utils_quailty_assurance/utils_quailty_assurance.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9c67f794dc63af6493d4df4dab38ea208c734f
--- /dev/null
+++ b/head_detect/utils_quailty_assurance/utils_quailty_assurance.py
@@ -0,0 +1,79 @@
+## Check the two parts of the data
+import os
+import json
+import numpy as np
+import torch
+from copy import deepcopy
+
+def write_json(json_save_path,result_dict):
+ with open(json_save_path,"w") as f:
+ f.write(json.dumps(result_dict))
+
+def read_json(json_path):
+ with open(json_path, 'r') as f:
+ result_dict = json.load(f)
+ return result_dict
+
+def parse_label(label_path):
+ with open(label_path, 'r') as fr:
+ labelList = fr.readlines()
+ label_list = []
+ for label in labelList:
+ label = label.strip().split()
+ l = int(label[-1])
+ label_list.append(deepcopy(l))
+ return label_list
+
+def find_files(file_path,file_type='mp4'):
+ if not os.path.exists(file_path):
+ return []
+ result = []
+ for root, dirs, files in os.walk(file_path, topdown=False):
+ for name in files:
+ if name.split(".")[-1]==file_type:
+ result.append(os.path.join(root, name))
+ return result
+def find_images(file_path):
+ if not os.path.exists(file_path):
+ return []
+ result = []
+ for root, dirs, files in os.walk(file_path, topdown=False):
+ for name in files:
+ if name.split(".")[-1] in ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng']:
+ result.append(os.path.join(root, name))
+ return result
+
+def box_iou(box1, box2):
+ def box_area(box):
+ # box = 4xn
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+ area1 = box_area(box1.T)
+ area2 = box_area(box2.T)
+
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) -
+ torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+ # iou = inter / (area1 + area2 - inter)
+ return inter / (area1[:, None] + area2 - inter)
+
+def xywh2xyxy(x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
+
+def match_Iou(label5_list,label16_list):
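+    # For each 5-field label box, find the 16-field label box with the highest IoU and
+    # overwrite the class id stored in column 15 of the matched rows.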
+ label15_torch = torch.tensor(label5_list)
+ label16_torch = torch.tensor(label16_list)
+ bbox5 = xywh2xyxy(label15_torch[:,1:5])
+ bbox16 = xywh2xyxy(label16_torch[:,1:5])
+ label5 = label15_torch[:,0]
+ label16 = label16_torch[:,15]
+ ious, i = box_iou(bbox5, bbox16).max(1)
+ final_result = label16_torch[i]
+ final_result[:,15] = label5[i]
+ return final_result
+
diff --git a/head_detect/utils_quailty_assurance/video2imglist.py b/head_detect/utils_quailty_assurance/video2imglist.py
new file mode 100644
index 0000000000000000000000000000000000000000..11a6d6e76b6ccf36d6530f633b7ff35ed57c9c23
--- /dev/null
+++ b/head_detect/utils_quailty_assurance/video2imglist.py
@@ -0,0 +1,57 @@
+# Extract key frames from a video at time-based intervals
+# -*- coding: utf-8 -*-
+
+import cv2
+import math
+import os
+
+
+def video2img(flv="ceshivedio.flv",rate=0.3,start=1,end=100):
+
+ list_all=[]
+ vc=cv2.VideoCapture(flv)
+ c=1
+ fps=dps(flv)
+ if rate>fps:
+ print("the fps is %s, set the rate=fps"%(fps))
+ rate=fps
+
+ if vc.isOpened():
+ rval=True
+ else:
+ rval=False
+
+ j=1.0
+ count=0.0
+ while rval:
+
+ count=fps/rate*(j-1)
+ rval,frame=vc.read()
+ if (c-1)==int(count):
+ j+=1
+            if (math.floor(c/fps))>=start and (math.floor(c/fps))<end:
+                list_all.append(frame)
+        c=c+1
+        if (math.floor(c/fps))>=end:
+            break
+ print("[ %d ] pictures from '%s' "%(len(list_all),flv))
+ vc.release()
+
+ return list_all
+
+def dps(video_path):
+    video = cv2.VideoCapture(video_path)
+ #(major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
+ fps = video.get(cv2.CAP_PROP_FPS)
+ video.release()
+ return fps
+
+if __name__=="__main__":
+ video_path = "/tmp-data/QACode/QAMaterial/2022-02-26video/B200C视频/低头抬头+打电话+打哈欠+抽烟.mp4"
+ imglist = video2img(video_path,rate=0.1,start=0,end=100000)
+ print(len(imglist))
+ os.makedirs("tmp",exist_ok=True)
+ for idx,image in enumerate(imglist):
+ cv2.imwrite("{}/sample_{}.jpg".format("tmp",idx),image)
\ No newline at end of file
diff --git a/inference_mtl.py b/inference_mtl.py
new file mode 100644
index 0000000000000000000000000000000000000000..4627e1f85a6c0215b239c232ed06975dc1b49bf6
--- /dev/null
+++ b/inference_mtl.py
@@ -0,0 +1,266 @@
+# -*- coding:utf-8 -*-
+import os
+
+import cv2
+import numpy as np
+import torch
+
+from head_detect.demo import detect_driver_face
+from models.shufflenet2_att_m import ShuffleNetV2
+from utils.images import expand_box_rate, crop_with_pad, show
+from utils.os_util import get_file_paths
+from utils.plt_util import DrawMTL
+
+project = 'dms3'
+version = 'v1.0'
+
+if version in ["v0.1", "v0.2"]:
+ input_size = [160, 160]
+ model_size = '0.5x'
+ stack_lite_head = 1
+ resume = 'data/dms3/v0.2/epoch100_11_0.910_0.865_0.983_0.984_0.964_0.923_0.952_0.831_0.801_0.948_0.948_0.833.pth'
+ onnx_export_path = 'data/dms3/v0.2/dms3_mtl_v0.2.onnx'
+ tasks = ["ems", "eye", 'mouth', 'glass', 'mask', 'smoke', 'phone', "eyelid_r", "eyelid_l", 'shift_x', 'shift_y', 'expand']
+ classes = [['normal', 'left', 'down', 'right', 'ind'], ['normal', 'close', 'ind'], ['normal', 'yawn', 'ind'],
+ ['normal', 'glass', 'ind'], ['normal', 'mask', 'ind'], ['normal', 'smoke'], ['normal', 'phone'],
+ ['distance'], ['distance'], ['distance'], ['distance'], ['distance']]
+ task_types = [0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3]
+ num_classes = [5, 3, 3, 3, 3, 2, 2, 1, 1, 1, 1, 1]
+ reg_relative_max = [-1.] * 7 + [0.05, 0.05, 0.7, 0.7, 0.1]
+ reg_dimension = [-1.] * 7 + [3, 3, 3, 3, 2]
+ expand_rate = -0.075
+
+elif version == 'v1.0':
+ input_size = [160, 160]
+ model_size = '0.5x'
+ stack_lite_head = [1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1]
+ resume = 'data/dms3/v1.0/dms3_mtl_v1.0.pth'
+ onnx_export_path = 'data/dms3/v1.0/dms3_mtl_v1.0.onnx'
+ tasks = ["ems", "eye", 'mouth', 'glass', 'mask', 'smoke', 'phone', "eyelid_r", "eyelid_l", 'shift_x', 'shift_y',
+ 'expand']
+ classes = [['normal', 'left', 'down', 'right', 'ind'], ['normal', 'close', 'ind'], ['normal', 'yawn', 'ind'],
+ ['normal', 'glass', 'ind'], ['normal', 'mask', 'ind'], ['normal', 'smoke'], ['normal', 'phone'],
+ ['distance'], ['distance'], ['distance'], ['distance'], ['distance']]
+ task_types = [0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3]
+ num_classes = [5, 3, 3, 3, 3, 2, 2, 1, 1, 1, 1, 1]
+ reg_relative_max = [-1.] * 7 + [0.05, 0.05, 0.7, 0.7, 0.1]
+ reg_dimension = [-1.] * 7 + [3, 3, 3, 3, 2]
+ expand_rate = -0.075
+
+else:
+ raise NotImplementedError
+
+# device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+device = "cpu"
+model = ShuffleNetV2(num_tasks=len(tasks), task_types=task_types, num_classes=num_classes, model_size=model_size,
+ with_last_conv=0, stack_lite_head=stack_lite_head, lite_head_channels=-1, onnx=False)
+model.load_state_dict(torch.load(resume, map_location=device)['state_dict'])
+model.to(device)
+print(f"loading {resume}")
+model.eval()
+
+drawer = DrawMTL(project, tasks, task_types, classes)
+
+
+def onnx_export(export_name='test.onnx'):
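+    # Rebuild the network with onnx=True (ONNX-friendly ops) and export it with one named
+    # output node per task.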
+ model_export = ShuffleNetV2(num_tasks=len(tasks), task_types=task_types, num_classes=num_classes, model_size=model_size,
+ with_last_conv=0, stack_lite_head=stack_lite_head, lite_head_channels=-1, onnx=True)
+ model_export.load_state_dict(torch.load(resume, map_location=device)['state_dict'])
+ model_export.eval()
+
+ example = torch.randn(1, 3, input_size[1], input_size[0])
+ torch.onnx.export(
+ model_export, # model being run
+ example, # model input (or a tuple for multiple inputs)
+ export_name,
+ verbose=False,
+ # store the trained parameter weights inside the model file
+ training=False,
+ input_names=['input'],
+ output_names=tasks,
+ do_constant_folding=True
+ )
+
+
+def inference_xyl(input_img, vis=False, base_box=None, dx=None, dy=None, dl=None, return_drawn=False):
+ """
+    Single-image inference: head detection + multi-task classification (with x/y offsets)
+    :param input_img: input image path/cv_array
+    :param vis: visualize the result
+    :param base_box: base crop box
+    :param dx: offset in the x direction
+    :param dy: offset in the y direction
+    :param dl: relative expansion of the box
+    :param return_drawn: whether to return the drawn image instead of the predictions
+ :return: preds->list [ems_pred, eye_pred, ..., dist_r, dist_l],
+ probs->list [[ems_probs], [eye_probs], ..., [1], [1]]
+ """
+ if isinstance(input_img, str) and os.path.isfile(input_img):
+ img = cv2.imread(input_img)
+ img_name = os.path.basename(input_img)
+ else:
+ img = input_img
+ img_name = 'image'
+
+ # img = cv2.resize(img, dsize=None, fx=0.5, fy=0.5)
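+    # Re-run head detection when there is no previous crop box or the predicted shift/expand
+    # offsets exceed the reuse thresholds; otherwise shift and rescale the previous box.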
+ if base_box is None or dx is None or dy is None or dl is None or \
+ abs(dx) >= 0.35 or abs(dy) >= 0.35 or abs(dl) >= 0.1:
+ box, score = detect_driver_face(img.copy())
+ print(box)
+ # if box == [0, 0, 0, 0]:
+ # return
+ base_box = expand_box_rate(img, box, rate=expand_rate)
+ print(base_box)
+ else:
+ box = None
+ w, h = base_box[2] - base_box[0], base_box[3] - base_box[1]
+ assert w == h
+ x0, y0, x1, y1 = base_box
+ x0, x1 = x0 + int(w * dx), x1 + int(w * dx)
+ y0, y1 = y0 + int(h * dy), y1 + int(h * dy)
+
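+        # Convert the relative expand prediction dl into a pixel margin and shrink/grow the
+        # box symmetrically by that margin.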
+ expand = int(h * dl / (1 + 2 * dl))
+ x0, y0 = x0 + expand, y0 + expand
+ x1, y1 = x1 - expand, y1 - expand
+
+ base_box = [x0, y0, x1, y1]
+
+ crop_img = crop_with_pad(img, base_box)
+ crop_img = cv2.resize(crop_img, tuple(input_size))
+ crop_img = crop_img.astype(np.float32)
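+    # Normalize pixels to roughly [-1, 1] and convert HWC -> CHW before adding the batch dim.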
+ crop_img = (crop_img - 128) / 127.
+ crop_img = crop_img.transpose([2, 0, 1])
+ crop_img = torch.from_numpy(crop_img).to(device)
+ crop_img = crop_img.view(1, *crop_img.size())
+
+ outputs = model(crop_img)
+ preds, probs = [], []
+ msg_list = []
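+    # task_types: 0 -> softmax classification; 1 -> regression rescaled by the crop width;
+    # 2/3 -> relative regression (shift / expand).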
+ for ti, outs in enumerate(outputs):
+ if task_types[ti] == 0:
+ sub_probs = torch.softmax(outs, dim=1).cpu().detach().numpy()[0]
+ sub_pred = np.argmax(sub_probs)
+ # msg_list.append(f'{tasks[ti].upper()}: {classes[ti][sub_pred]} {sub_probs[sub_pred]:.3f}')
+ msg_list.append(f'{tasks[ti].upper()}: {sub_probs}')
+ preds.append(sub_pred)
+ probs.append([round(x, 3) for x in sub_probs])
+ elif task_types[ti] == 1:
+ sub_pred = outs.cpu().detach().item() * (base_box[2] - base_box[0]) * reg_relative_max[ti] / reg_dimension[ti]
+ msg_list.append(f'{tasks[ti].upper()}: {sub_pred:.6f}')
+ preds.append(sub_pred)
+ probs.append([round(sub_pred, 3)])
+ elif task_types[ti] in [2, 3]:
+ sub_pred = outs.cpu().detach().item() * reg_relative_max[ti] / reg_dimension[ti]
+ msg_list.append(f'{tasks[ti].upper()}: {sub_pred:.6f}')
+ preds.append(sub_pred)
+ probs.append([round(sub_pred, 3)])
+
+ # print('\n'.join(msg_list))
+
+ if vis:
+ # drawn = draw_texts(img, msg_list, box, crop_box, use_mask=True)
+ # show(drawn, img_name)
+ drawn = drawer.draw_result(img, preds, probs, box, base_box)
+ # drawn = drawer.draw_ind(img)
+ show(drawn, img_name)
+
+ if return_drawn:
+ drawn = drawer.draw_result(img, preds, probs, box, base_box)
+ return drawn
+
+ return preds, probs, box, base_box, msg_list
+
+
+def inference_xyl_dir(data_dir, vis):
+ img_paths = get_file_paths(data_dir)
+ for p in img_paths:
+ print(f"\n{os.path.basename(p)}")
+ inference_xyl(p, vis=vis)
+
+
+def inference_with_basebox(input_img, base_box):
+ if isinstance(input_img, str) and os.path.isfile(input_img):
+ img = cv2.imread(input_img)
+ else:
+ img = input_img
+
+ crop_img = crop_with_pad(img, base_box)
+ crop_img = cv2.resize(crop_img, tuple(input_size))
+ crop_img = crop_img.astype(np.float32)
+ crop_img = (crop_img - 128) / 127.
+ crop_img = crop_img.transpose([2, 0, 1])
+ crop_img = torch.from_numpy(crop_img).to(device)
+ crop_img = crop_img.view(1, *crop_img.size())
+
+ outputs = model(crop_img)
+ preds, probs = [], []
+ msg_list = []
+ for ti, outs in enumerate(outputs):
+ if task_types[ti] == 0:
+ sub_probs = torch.softmax(outs, dim=1).cpu().detach().numpy()[0]
+ sub_pred = np.argmax(sub_probs)
+ # msg_list.append(f'{tasks[ti].upper()}: {classes[ti][sub_pred]} {sub_probs[sub_pred]:.3f}')
+ msg_list.append(f'{tasks[ti].upper()}: {sub_probs}')
+ preds.append(sub_pred)
+ # probs.append([round(x, 3) for x in sub_probs])
+ probs += sub_probs.tolist()
+ elif task_types[ti] == 1:
+ sub_pred = outs.cpu().detach().item() * (base_box[2] - base_box[0]) * reg_relative_max[ti] / reg_dimension[
+ ti]
+ msg_list.append(f'{tasks[ti].upper()}: {sub_pred:.6f}')
+ preds.append(sub_pred)
+ # probs.append([round(sub_pred, 3)])
+ probs.append(sub_pred)
+ elif task_types[ti] in [2, 3]:
+ sub_pred = outs.cpu().detach().item() * reg_relative_max[ti] / reg_dimension[ti]
+ msg_list.append(f'{tasks[ti].upper()}: {sub_pred:.6f}')
+ preds.append(sub_pred)
+ # probs.append([round(sub_pred, 3)])
+ probs.append(sub_pred)
+
+ return probs
+
+
+def generate_onnx_result(data_dir, label_file, save_file):
+ from utils.labels import load_labels, save_labels
+ from utils.multiprogress import MultiThreading
+
+ label_dict = load_labels(label_file)
+
+ def kernel(img_name):
+ img_path = os.path.join(data_dir, img_name)
+ box = [int(b) for b in label_dict[img_name][:4]]
+ probs = inference_with_basebox(img_path, box)
+ print(img_name, probs)
+ return img_name, probs
+
+ exe = MultiThreading(label_dict.keys(), 8)
+ res = exe.run(kernel)
+ res_dict = {r[0]: r[1] for r in res}
+ save_labels(save_file, res_dict)
+
+
+if __name__ == '__main__':
+ inference_xyl('/Users/didi/Desktop/MTL/dataset/images/20210712062707_09af469be0e74945994d0d6e9e0cbe36_209949.jpg', vis=True)
+ # inference_xyl_dir('test_mtl', vis=True)
+
+ # generate_onnx_result('/Users/didi/Desktop/MTL/dataset/images',
+ # '/Users/didi/Desktop/test_mtl_raw_out/test_mtl_raw_label.txt',
+ # '/Users/didi/Desktop/test_mtl_raw_out/test_mtl_raw_torch_prob.txt')
+
+ # onnx_export_pipeline(onnx_export_path, export_func=onnx_export, view_net=True, simplify=True)
+
+"""
+20210712062707_09af469be0e74945994d0d6e9e0cbe36_209949.jpg
+EMS: [1.1826814e-06 5.6046390e-09 9.9999881e-01 6.0338436e-09 5.0259811e-08]
+EYE: [9.999689e-01 3.101773e-05 1.121569e-07]
+MOUTH: [9.991779e-01 5.185943e-06 8.168342e-04]
+GLASS: [9.9998879e-01 1.0724665e-05 4.9968133e-07]
+MASK: [9.9702114e-01 2.5327259e-03 4.4608722e-04]
+SMOKE: [0.94698274 0.05301724]
+PHONE: [9.9996448e-01 3.5488123e-05]
+EYELID_R: [1.9863424]
+EYELID_L: [2.1488686]
+SHIFT_X: [0.00669874]
+SHIFT_Y: [-0.00329317]
+EXPAND: [-0.07628014]
+"""
\ No newline at end of file
diff --git a/inference_video_mtl.py b/inference_video_mtl.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb043fea462e83cf2571130658ace957efb79445
--- /dev/null
+++ b/inference_video_mtl.py
@@ -0,0 +1,224 @@
+# -*- coding:utf-8 -*-
+import pandas as pd
+
+from inference_mtl import *
+from utils.os_util import get_file_paths
+from utils.plt_util import plot_scores_mtl, syn_plot_scores_mtl, DrawMTL
+from utils.time_util import convert_input_time, convert_stamp
+
+
+FOURCC = 'x264' # output video codec [h264] for internet
+
+
+def inference_videos(input_video, save_dir='output/out_vdo', detect_mode='frame', frequency=1.0, continuous=True,
+ show_res=False, plot_score=False, save_score=False, save_vdo=False, save_img=False, img_save_dir=None,
+ syn_plot=True, resize=None, time_delta=1, save_vdo_path=None, save_plt_path=None, save_csv_path=None):
+ """
+    @param input_video: input video path, a directory or a single video file
+    @param save_dir: output directory for result videos
+    @param detect_mode: frame-sampling mode, 'second'/'frame', i.e. run inference once per second or per frame
+    @param frequency: sampling-frequency multiplier (0-n); multiplicative, larger values give larger intervals, fractions allowed
+    @param show_res: whether to visualize results during inference
+    @param continuous: when a directory is given, whether to run its videos as one continuous sequence
+    @param plot_score: whether to plot the score curves
+    @param save_score: whether to save the inference results as a csv file
+    @param save_vdo: whether to save the result video
+    @param save_img: whether to save intermediate frames; the saving condition must be set separately
+    @param img_save_dir: directory for saved images
+    @param syn_plot: whether to draw the synchronized score chart (not supported for person-count/seat tasks)
+    @param resize: output resize, e.g. (width, height)
+    @param time_delta: display interval passed to cv2.waitKey; 0 waits for a key press to advance
+    @param save_vdo_path: path to save the video, defaults to None
+    @param save_plt_path: path to save the score plot, defaults to None
+    @param save_csv_path: path to save the detailed scores, defaults to None
+ @return:
+ """
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ if save_img:
+ img_save_dir = os.path.join(save_dir, 'images') if img_save_dir is None else img_save_dir
+ if not os.path.exists(img_save_dir):
+ os.makedirs(img_save_dir)
+ save_count = 0
+
+ separately_save = True
+ if os.path.isfile(input_video):
+ vdo_list = [input_video]
+ elif os.path.isdir(input_video):
+ vdo_list = get_file_paths(input_video, mod='vdo')
+ if continuous:
+ title = os.path.basename(input_video)
+ save_vdo_path = os.path.join(save_dir, title + '.mp4') if save_vdo_path is None else save_vdo_path
+ save_plt_path = os.path.join(save_dir, title + '.jpg') if save_plt_path is None else save_plt_path
+ save_csv_path = os.path.join(save_dir, title + '.csv') if save_csv_path is None else save_csv_path
+ separately_save = False
+ frames, seconds = 0, 0
+ else:
+ print(f'No {input_video}')
+ return
+
+ if save_score:
+ columns = ['index']
+ for ti, task in enumerate(tasks):
+ if ti < 7:
+ sub_columns = [f"{task}-{sc}" for sc in classes[ti]]
+ else:
+ sub_columns = [task]
+ columns += sub_columns
+
+ if save_vdo and not separately_save:
+ video = cv2.VideoCapture(vdo_list[0])
+ width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) if resize is None else resize[0]
+ height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) if resize is None else resize[1]
+ fps = video.get(cv2.CAP_PROP_FPS)
+ fourcc = cv2.VideoWriter_fourcc(*FOURCC)
+ out_video = cv2.VideoWriter(save_vdo_path, fourcc, fps if detect_mode == 'frame' else 5,
+ (int(width*1.5), height) if syn_plot else (width, height))
+ print(f"result video save in '{save_vdo_path}'")
+
+ res_list = []
+ for vdo_path in vdo_list:
+ vdo_name = os.path.basename(vdo_path)
+ try:
+ start_time_str = vdo_name.split('_')[2][:14]
+ start_time_stamp = convert_input_time(start_time_str, digit=10)
+        except Exception:
+ start_time_str, start_time_stamp = '', 0
+
+ cap = cv2.VideoCapture(vdo_path)
+ cur_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ if not cur_frames:
+ continue
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ cur_seconds = int(cur_frames / (fps + 1e-6))
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ print(f"video:{vdo_name} width:{width} height:{height} fps:{fps:.1f} frames:{cur_frames} "
+ f"seconds: {cur_seconds} start_time:{start_time_str}")
+
+ if resize is not None:
+ width, height = resize
+
+ cur_res_list = []
+ if separately_save:
+ title = os.path.splitext(vdo_name)[0]
+ save_vdo_path = os.path.join(save_dir, title + '_res.mp4') if save_vdo_path is None else save_vdo_path
+ save_plt_path = os.path.join(save_dir, title + '_res.jpg') if save_plt_path is None else save_plt_path
+ save_csv_path = os.path.join(save_dir, title + '_res.csv') if save_csv_path is None else save_csv_path
+ if save_vdo:
+ fourcc = cv2.VideoWriter_fourcc(*FOURCC)
+ out_video = cv2.VideoWriter(save_vdo_path, fourcc, fps if detect_mode == 'frame' else 5,
+ (int(1.5*width), height) if syn_plot else (width, height))
+ print(f"result video save in '{save_vdo_path}'")
+ else:
+ frames += cur_frames
+ seconds += cur_seconds
+
+ base_box, dx, dy, dl = None, None, None, None
+
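+        # Sampling stride in frames: every frame in 'frame' mode or once per second in
+        # 'second' mode, scaled by the frequency multiplier.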
+ step = 1 if detect_mode == 'frame' else fps
+ step = max(1, round(step * frequency))
+ count = 0
+ for i in range(0, cur_frames, step):
+ cap.set(cv2.CAP_PROP_POS_FRAMES, i)
+ ret, frame = cap.read()
+ if not ret:
+ # print('video end!')
+ break
+ count += 1
+
+ if resize is not None and frame.shape[0] != resize[1] and frame.shape[1] != resize[0]:
+ frame = cv2.resize(frame, resize)
+
+ cur_res = inference_xyl(frame, vis=False, base_box=base_box, dx=dx, dy=dy, dl=dl)
+ if cur_res is None:
+ preds = [0] * len(tasks)
+ probs = [[0] * len(sub_classes) for sub_classes in classes]
+ msg_list = ['no driver']
+ base_box, dx, dy, dl = None, None, None, None
+ else:
+ preds, probs, box, crop_box, msg_list = cur_res
+ base_box, dx, dy, dl = crop_box.copy(), preds[-3], preds[-2], preds[-1]
+
+ if start_time_stamp:
+ time_stamp = start_time_stamp + round(i/fps)
+ cur_time = convert_stamp(time_stamp)
+ cur_res_list.append(tuple([cur_time+f'-{i}'] + [round(p, 3) for sub_probs in probs for p in sub_probs]))
+ res_list.append(tuple([cur_time+f'-{i}'] + [round(p, 3) for sub_probs in probs for p in sub_probs]))
+ else:
+ cur_time = ''
+ cur_res_list.append(tuple([count] + [round(p, 3) for sub_probs in probs for p in sub_probs]))
+ res_list.append(tuple([count] + [round(p, 3) for sub_probs in probs for p in sub_probs]))
+
+ if not count % 10 and count:
+ msg = "{} {} => {}".format(i, cur_time, '\t'.join(msg_list))
+ print(msg)
+
+            if save_img and probs[1][1] > 0.8:  # TODO: set different saving conditions as needed
+ img_name = vdo_name.replace(".mp4", f'_{i}.jpg') if not cur_time else \
+ f"{convert_stamp(convert_input_time(cur_time), '%Y%m%d%H%M%S')}_{i}.jpg"
+ img_save_path = os.path.join(img_save_dir, img_name)
+ cv2.imwrite(img_save_path, frame)
+ save_count += 1
+
+ if show_res or save_vdo:
+ drawn = drawer.draw_ind(frame) if cur_res is None else \
+ drawer.draw_result(frame, preds, probs, box, crop_box, use_mask=False, use_frame=False)
+ if syn_plot:
+ score_array = np.array([r[1:] for r in res_list])
+ if detect_mode == 'second' and cur_time:
+ indexes = [r[0][-5:] for r in res_list]
+ else:
+ indexes = list(range(len(res_list)))
+ window_length = 300 if detect_mode == 'frame' else 30
+ assert len(score_array) == len(indexes)
+ score_chart = syn_plot_scores_mtl(
+ tasks, [[0, 1, 2, 3, 4, 5, 6], [7, 8], [9, 10], [11]], classes, indexes, score_array,
+ int(0.5*width), height, window_length, width/1280)
+
+ drawn = np.concatenate([drawn, score_chart], axis=1)
+
+ if show_res:
+ cv2.namedWindow(title, 0)
+ # cv2.moveWindow(title, 0, 0)
+ # cv2.setWindowProperty(title, cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
+ cv2.imshow(title, drawn)
+ cv2.waitKey(time_delta)
+
+ # write the frame after processing
+ if save_vdo:
+ out_video.write(drawn)
+
+ if separately_save:
+ if show_res:
+ cv2.destroyWindow(title)
+ if plot_score:
+ res = np.array([r[1:] for r in cur_res_list])
+ plot_scores_mtl(tasks, task_types, classes, res, title, detect_mode, save_dir=save_dir,
+ save_path=save_plt_path, show=show_res)
+ if save_score:
+ df = pd.DataFrame(cur_res_list, columns=columns)
+ df.to_csv(save_csv_path, index=False, float_format='%.3f')
+
+ if save_img:
+ print(f"total save {save_count} images")
+
+ if not separately_save:
+ if show_res:
+ cv2.destroyWindow(title)
+ if plot_score:
+ res = np.array([r[1:] for r in res_list])
+ plot_scores_mtl(tasks, task_types, classes, res, title, detect_mode, save_dir=save_dir,
+ save_path=save_plt_path, show=show_res)
+ if save_score:
+ df = pd.DataFrame(res_list, columns=columns)
+ df.to_csv(save_csv_path, index=False, float_format='%.3f')
+
+    return save_vdo_path, save_plt_path, save_csv_path
+
+
+if __name__ == '__main__':
+ inference_videos('/Users/didi/Desktop/CARVIDEO_03afqwj7uw5d801e_20230708132305000_20230708132325000.mp4',
+ save_dir='/Users/didi/Desktop/error_res_v1.0',
+ detect_mode='second', frequency=0.2, plot_score=True, save_score=True, syn_plot=True,
+ save_vdo=True, save_img=False, continuous=False, show_res=True, resize=(1280, 720), time_delta=1)
diff --git a/models/__pycache__/shufflenet2_att_m.cpython-38.pyc b/models/__pycache__/shufflenet2_att_m.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ea737030a54293202b79ad97d1a6c8db311d1fb
Binary files /dev/null and b/models/__pycache__/shufflenet2_att_m.cpython-38.pyc differ
diff --git a/models/module/__pycache__/activation.cpython-38.pyc b/models/module/__pycache__/activation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a69025f7ac0a3f9146303484c46f37a0b1af9cda
Binary files /dev/null and b/models/module/__pycache__/activation.cpython-38.pyc differ
diff --git a/models/module/__pycache__/conv.cpython-38.pyc b/models/module/__pycache__/conv.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3354ee5f76ebd1f3e34c3a52309296c4c141bcb7
Binary files /dev/null and b/models/module/__pycache__/conv.cpython-38.pyc differ
diff --git a/models/module/__pycache__/init_weights.cpython-38.pyc b/models/module/__pycache__/init_weights.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf3e71704def26d13d870192f6178b61e1808b76
Binary files /dev/null and b/models/module/__pycache__/init_weights.cpython-38.pyc differ
diff --git a/models/module/__pycache__/norm.cpython-38.pyc b/models/module/__pycache__/norm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f48095474ba1cd7d16f9f36af9554176e0e32b5
Binary files /dev/null and b/models/module/__pycache__/norm.cpython-38.pyc differ
diff --git a/models/module/activation.py b/models/module/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..82144fa30bd8b81e0008a3a2da33d5df37f606c4
--- /dev/null
+++ b/models/module/activation.py
@@ -0,0 +1,17 @@
+import torch.nn as nn
+
+activations = {'ReLU': nn.ReLU,
+ 'LeakyReLU': nn.LeakyReLU,
+ 'ReLU6': nn.ReLU6,
+ 'SELU': nn.SELU,
+ 'ELU': nn.ELU,
+ None: nn.Identity
+ }
+
+
+def act_layers(name):
+ assert name in activations.keys()
+ if name == 'LeakyReLU':
+ return nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ else:
+ return activations[name](inplace=True)
diff --git a/models/module/blocks.py b/models/module/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8ee6bd1d96535a9d02da987255b5a3192ba6fd
--- /dev/null
+++ b/models/module/blocks.py
@@ -0,0 +1,300 @@
+# -*- coding:utf-8 -*-
+import torch
+import torch.nn as nn
+from models.attention.attention_blocks import *
+import torch.nn.functional as F
+from utils.common import log_warn
+
+
+# Insert an attention block after each shuffle block
+class ShuffleV2Block(nn.Module):
+ def __init__(self, inp, oup, mid_channels, ksize, stride, attention='', ratio=16, loc='side', onnx=False):
+ super(ShuffleV2Block, self).__init__()
+ self.onnx = onnx
+ self.stride = stride
+ assert stride in [1, 2]
+
+ self.mid_channels = mid_channels
+ self.ksize = ksize
+ pad = ksize // 2
+ self.pad = pad
+ self.inp = inp
+
+ outputs = oup - inp
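+        # The main branch produces (oup - inp) channels so that concatenating it with the
+        # inp-channel identity/projection branch yields oup channels.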
+
+ branch_main = [
+ # pw
+ nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(mid_channels),
+ nn.ReLU(inplace=True),
+ # dw
+ nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),
+ nn.BatchNorm2d(mid_channels),
+ # pw-linear
+ nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(outputs),
+ nn.ReLU(inplace=True),
+ ]
+ self.branch_main = nn.Sequential(*branch_main)
+
+ if stride == 2:
+ branch_proj = [
+ # dw
+ nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
+ nn.BatchNorm2d(inp),
+ # pw-linear
+ nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(inp),
+ nn.ReLU(inplace=True),
+ ]
+ self.branch_proj = nn.Sequential(*branch_proj)
+ else:
+ self.branch_proj = None
+
+ if attention:
+ self.loc = loc
+ att_out = outputs if loc == 'side' else oup
+ if attention.lower() == 'se':
+ self.att_block = SELayer(att_out, reduction=ratio)
+ elif attention.lower() == 'cbam':
+ self.att_block = CBAM(att_out, ratio)
+ elif attention.lower() == 'gc':
+ self.att_block = GCBlock(att_out, ratio=ratio)
+ else:
+ raise NotImplementedError
+ else:
+ self.att_block = None
+
+ def forward(self, old_x):
+ if self.stride == 1:
+ x_proj, x = self.channel_shuffle(old_x)
+ else:
+ x_proj = old_x
+ x_proj = self.branch_proj(x_proj)
+ x = old_x
+ x = self.branch_main(x)
+ if self.att_block and self.loc == 'side':
+ x = self.att_block(x)
+ x = torch.cat((x_proj, x), 1)
+ if self.att_block and self.loc == 'after':
+ x = self.att_block(x)
+ return x
+
+ def channel_shuffle(self, x):
+ batchsize, num_channels, height, width = x.data.size()
+ if self.onnx:
+            # The onnx model is later converted to an ifx model, and the ifx engine stores data
+            # in nchw (n=1) format, so shape transforms should stay in nchw (n=1) form as far as possible
+ x = x.reshape(1, batchsize * num_channels // 2, 2, height * width)
+ x = x.permute(0, 2, 1, 3)
+ z = num_channels // 2
+ x = x.reshape(1, -1, height, width)
+            # Avoid x[0]/x[1] indexing for the split; prefer torch operators instead
+ x1, x2 = torch.split(x, split_size_or_sections=[z, z], dim=1)
+ return x1, x2
+ else:
+ x = x.reshape(batchsize * num_channels // 2, 2, height * width)
+ x = x.permute(1, 0, 2)
+ x = x.reshape(2, -1, num_channels // 2, height, width)
+ return x[0], x[1]
+
+
+# Insert an attention block after each shuffle block
+class QuantizableShuffleV2Block(ShuffleV2Block):
+
+ def __init__(self, *args, **kwargs):
+ if kwargs.get('attention', ''):
+ log_warn('Quantizable model not support attention blocks')
+ kwargs['attention'] = ''
+ super(QuantizableShuffleV2Block, self).__init__(*args, **kwargs)
+ self.quantized_funcs = nn.quantized.FloatFunctional()
+
+ def forward(self, old_x):
+ if self.branch_proj is None:
+ x_proj, x = self.channel_shuffle(old_x)
+ else:
+ x_proj = old_x
+ x_proj = self.branch_proj(x_proj)
+ x = old_x
+ x = self.branch_main(x)
+ x = self.quantized_funcs.cat((x_proj, x), 1)
+ return x
+
+ def channel_shuffle(self, x):
+ batchsize, num_channels, height, width = x.data.size()
+ x = x.reshape(batchsize * num_channels // 2, 2, height * width)
+ x = x.permute(1, 0, 2)
+ x = x.reshape(2, -1, num_channels // 2, height, width)
+ return x[0], x[1]
+
+
+class ShuffleV2BlockSK(nn.Module):
+ def __init__(self, inp, oup, mid_channels, *, ksize, stride):
+ super(ShuffleV2BlockSK, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ self.mid_channels = mid_channels
+ self.ksize = ksize
+ pad = ksize // 2
+ self.pad = pad
+ self.inp = inp
+
+ outputs = oup - inp
+
+ branch_main = [
+ # pw
+ nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(mid_channels),
+ nn.ReLU(inplace=True),
+ # dw
+ # nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),
+ SKConv(mid_channels, 2, mid_channels, stride=stride, use_relu=False),
+ # SKConv2(mid_channels, 2, mid_channels, 4, stride=stride),
+ nn.BatchNorm2d(mid_channels),
+ # pw-linear
+ nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(outputs),
+ nn.ReLU(inplace=True),
+ ]
+ self.branch_main = nn.Sequential(*branch_main)
+
+ if stride == 2:
+ branch_proj = [
+ # dw
+ nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
+ # SKConv(inp, 2, inp, stride=stride, use_relu=False),
+ # SKConv2(inp, 2, inp, 4, stride=stride),
+ nn.BatchNorm2d(inp),
+ # pw-linear
+ nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(inp),
+ nn.ReLU(inplace=True),
+ ]
+ self.branch_proj = nn.Sequential(*branch_proj)
+ else:
+ self.branch_proj = None
+
+ def forward(self, old_x):
+ if self.stride == 1:
+ x_proj, x = self.channel_shuffle(old_x)
+ return torch.cat((x_proj, self.branch_main(x)), 1)
+ elif self.stride == 2:
+ x_proj = old_x
+ x = old_x
+ return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1)
+
+ def channel_shuffle(self, x):
+ batchsize, num_channels, height, width = x.data.size()
+ assert (num_channels % 4 == 0)
+ x = x.reshape(batchsize * num_channels // 2, 2, height * width)
+ x = x.permute(1, 0, 2)
+ x = x.reshape(2, -1, num_channels // 2, height, width)
+ return x[0], x[1]
+
+
+class SnetBlock(ShuffleV2Block):
+ """
+    Custom channel_shuffle implementation; everything else is the same as ShuffleV2Block
+ """
+
+ def channel_shuffle(self, x):
+ g = 2
+ x = x.reshape(x.shape[0], g, x.shape[1] // g, x.shape[2], x.shape[3])
+ x = x.permute(0, 2, 1, 3, 4)
+ x = x.reshape(x.shape[0], -1, x.shape[3], x.shape[4])
+ x_proj = x[:, :(x.shape[1] // 2), :, :]
+ x = x[:, (x.shape[1] // 2):, :, :]
+ return x_proj, x
+
+
+def conv_bn_relu(inp, oup, kernel_size, stride, pad):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, kernel_size, stride, pad, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.ReLU(inplace=True)
+ )
+
+
+class CEM(nn.Module):
+ """
+ Context Enhancement Module
+    Modified FPN-style structure: c4, c5 and the global feature each go through a 1x1 conv
+    (c5 is upsampled, glb is propagated), and the results are summed
+    Supports feat_stride 8 and 16
+    TODO: apply one more conv at the end
+ """
+
+ def __init__(self, in_channels1, in_channels2, in_channels3,
+ feat_stride=16, squeeze_channels=245, use_relu=False):
+ super(CEM, self).__init__()
+ self.feat_stride = feat_stride
+ assert feat_stride in [8, 16], f"{feat_stride} not support, select in [8, 16]"
+ if feat_stride == 8:
+ self.conv3 = nn.Conv2d(in_channels1 // 2, squeeze_channels, 1, bias=True)
+ self.conv4 = nn.Conv2d(in_channels1, squeeze_channels, 1, bias=True)
+ self.conv5 = nn.Conv2d(in_channels2, squeeze_channels, 1, bias=True)
+ self.conv_last = nn.Conv2d(in_channels3, squeeze_channels, 1, bias=True)
+ self.use_relu = use_relu
+ if use_relu:
+ self.relu7 = nn.ReLU(inplace=True)
+
+ def forward(self, inputs):
+ if self.feat_stride == 8:
+ c3_lat = self.conv3(inputs[0])
+ c4_lat = self.conv4(inputs[1])
+ c4_lat = F.interpolate(c4_lat, size=[c3_lat.size(2), c3_lat.size(3)], mode="nearest")
+ c5_lat = self.conv5(inputs[2])
+ c5_lat = F.interpolate(c5_lat, size=[c3_lat.size(2), c3_lat.size(3)], mode="nearest")
+ glb_lat = self.conv_last(inputs[3])
+ out = c3_lat + c4_lat + c5_lat + glb_lat
+ else:
+ c4_lat = self.conv4(inputs[0])
+ c5_lat = self.conv5(inputs[1])
+            c5_lat = F.interpolate(c5_lat, size=[c4_lat.size(2), c4_lat.size(3)], mode="nearest")  # upsample
+ glb_lat = self.conv_last(inputs[2])
+ out = c4_lat + c5_lat + glb_lat
+
+ if self.use_relu:
+ out = self.relu7(out)
+ return out
+
+
+class CEM_a(nn.Module):
+ """
+ Context Enhancement Module
+    Modified FPN-style structure: c4, c5 and the global feature each go through a 1x1 conv
+    (c5 is upsampled, glb is propagated), and the results are summed
+    Supports feat_stride 8 and 16
+    TODO: apply one more conv at the end
+ """
+
+ def __init__(self, in_channels1, in_channels2, in_channels3,
+ feat_stride=16, squeeze_channels=245, use_relu=False):
+ super(CEM_a, self).__init__()
+ self.feat_stride = feat_stride
+ assert feat_stride in [8, 16], f"{feat_stride} not support, select in [8, 16]"
+ if feat_stride == 8:
+ self.conv3 = nn.Conv2d(in_channels1 // 2, squeeze_channels, 1, bias=True)
+ self.conv4 = nn.Conv2d(in_channels1, squeeze_channels, 1, bias=True)
+ self.conv5 = nn.Conv2d(in_channels2, squeeze_channels, 1, bias=True)
+ self.conv_last = nn.Conv2d(in_channels3, squeeze_channels, 1, bias=True)
+ self.use_relu = use_relu
+ if use_relu:
+ self.relu7 = nn.ReLU(inplace=True)
+
+ def forward(self, inputs):
+ if self.feat_stride == 8:
+ c3_lat = self.conv3(inputs[0])
+ c4_lat = self.conv4(inputs[1])
+ c4_lat = F.interpolate(c4_lat, size=[c3_lat.size(2), c3_lat.size(3)], mode="nearest")
+ c5_lat = self.conv5(inputs[2])
+ c5_lat = F.interpolate(c5_lat, size=[c3_lat.size(2), c3_lat.size(3)], mode="nearest")
+ glb_lat = self.conv_last(inputs[3])
+ out = c3_lat + c4_lat + c5_lat + glb_lat
+ else:
+ c4_lat = self.conv4(inputs[0])
+ c5_lat = self.conv5(inputs[1])
+            c5_lat = F.interpolate(c5_lat, size=[c4_lat.size(2), c4_lat.size(3)], mode="nearest")  # upsample
+ glb_lat = self.conv_last(inputs[2])
+ out = c4_lat + c5_lat + glb_lat
+
+ if self.use_relu:
+ out = self.relu7(out)
+ return out
diff --git a/models/module/conv.py b/models/module/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c550febb9f06e125c2daa7e6a14ff9a1c59134d
--- /dev/null
+++ b/models/module/conv.py
@@ -0,0 +1,340 @@
+"""
+ConvModule refers from MMDetection
+RepVGGConvModule refers from RepVGG: Making VGG-style ConvNets Great Again
+"""
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from models.module.activation import act_layers
+from models.module.init_weights import kaiming_init, constant_init
+from models.module.norm import build_norm_layer
+
+
+class ConvModule(nn.Module):
+ """A conv block that contains conv/norm/activation layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ conv_cfg (dict): Config dict for convolution layer.
+ norm_cfg (dict): Config dict for normalization layer.
+ activation (str): activation layer, "ReLU" by default.
+ inplace (bool): Whether to use inplace mode for activation.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias='auto',
+ conv_cfg=None,
+ norm_cfg=None,
+ activation='ReLU',
+ inplace=True,
+ order=('conv', 'norm', 'act')):
+ super(ConvModule, self).__init__()
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
+ assert activation is None or isinstance(activation, str)
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.activation = activation
+ self.inplace = inplace
+ self.order = order
+ assert isinstance(self.order, tuple) and len(self.order) == 3
+ assert set(order) == set(['conv', 'norm', 'act'])
+
+ self.with_norm = norm_cfg is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == 'auto':
+ bias = False if self.with_norm else True
+ self.with_bias = bias
+
+ if self.with_norm and self.with_bias:
+ warnings.warn('ConvModule has norm and bias at the same time')
+
+ # build convolution layer
+ self.conv = nn.Conv2d( #
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ self.padding = self.conv.padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ if order.index('norm') > order.index('conv'):
+ norm_channels = out_channels
+ else:
+ norm_channels = in_channels
+ self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+ self.add_module(self.norm_name, norm)
+
+ # build activation layer
+ if self.activation:
+ self.act = act_layers(self.activation)
+
+ # Use msra init by default
+ self.init_weights()
+
+ @property
+ def norm(self):
+ return getattr(self, self.norm_name)
+
+ def init_weights(self):
+ if self.activation == 'LeakyReLU':
+ nonlinearity = 'leaky_relu'
+ else:
+ nonlinearity = 'relu'
+ kaiming_init(self.conv, nonlinearity=nonlinearity)
+ if self.with_norm:
+ constant_init(self.norm, 1, bias=0)
+
+ def forward(self, x, norm=True):
+ for layer in self.order:
+ if layer == 'conv':
+ x = self.conv(x)
+ elif layer == 'norm' and norm and self.with_norm:
+ x = self.norm(x)
+ elif layer == 'act' and self.activation:
+ x = self.act(x)
+ return x
+
+
+class DepthwiseConvModule(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias='auto',
+ norm_cfg=dict(type='BN'),
+ activation='ReLU',
+ inplace=True,
+ order=('depthwise', 'dwnorm', 'act', 'pointwise', 'pwnorm', 'act')):
+ super(DepthwiseConvModule, self).__init__()
+ assert activation is None or isinstance(activation, str)
+ self.activation = activation
+ self.inplace = inplace
+ self.order = order
+ assert isinstance(self.order, tuple) and len(self.order) == 6
+ assert set(order) == set(['depthwise', 'dwnorm', 'act', 'pointwise', 'pwnorm', 'act'])
+
+ self.with_norm = norm_cfg is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == 'auto':
+ bias = False if self.with_norm else True
+ self.with_bias = bias
+
+ if self.with_norm and self.with_bias:
+ warnings.warn('ConvModule has norm and bias at the same time')
+
+ # build convolution layer
+ self.depthwise = nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=in_channels,
+ bias=bias)
+ self.pointwise = nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias)
+
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.depthwise.in_channels
+ self.out_channels = self.pointwise.out_channels
+ self.kernel_size = self.depthwise.kernel_size
+ self.stride = self.depthwise.stride
+ self.padding = self.depthwise.padding
+ self.dilation = self.depthwise.dilation
+ self.transposed = self.depthwise.transposed
+ self.output_padding = self.depthwise.output_padding
+
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ _, self.dwnorm = build_norm_layer(norm_cfg, in_channels)
+ _, self.pwnorm = build_norm_layer(norm_cfg, out_channels)
+
+ # build activation layer
+ if self.activation:
+ self.act = act_layers(self.activation)
+
+ # Use msra init by default
+ self.init_weights()
+
+ def init_weights(self):
+ if self.activation == 'LeakyReLU':
+ nonlinearity = 'leaky_relu'
+ else:
+ nonlinearity = 'relu'
+ kaiming_init(self.depthwise, nonlinearity=nonlinearity)
+ kaiming_init(self.pointwise, nonlinearity=nonlinearity)
+ if self.with_norm:
+ constant_init(self.dwnorm, 1, bias=0)
+ constant_init(self.pwnorm, 1, bias=0)
+
+ def forward(self, x, norm=True):
+ for layer_name in self.order:
+ if layer_name != 'act':
+ layer = self.__getattr__(layer_name)
+ x = layer(x)
+ elif layer_name == 'act' and self.activation:
+ x = self.act(x)
+ return x
+
+
+class RepVGGConvModule(nn.Module):
+ """
+ RepVGG Conv Block from paper RepVGG: Making VGG-style ConvNets Great Again
+ https://arxiv.org/abs/2101.03697
+ https://github.com/DingXiaoH/RepVGG
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ activation='ReLU',
+ padding_mode='zeros',
+ deploy=False):
+ super(RepVGGConvModule, self).__init__()
+ assert activation is None or isinstance(activation, str)
+ self.activation = activation
+
+ self.deploy = deploy
+ self.groups = groups
+ self.in_channels = in_channels
+
+ assert kernel_size == 3
+ assert padding == 1
+
+ padding_11 = padding - kernel_size // 2
+
+ # build activation layer
+ if self.activation:
+ self.act = act_layers(self.activation)
+
+ if deploy:
+ self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation, groups=groups, bias=True,
+ padding_mode=padding_mode)
+
+ else:
+ self.rbr_identity = nn.BatchNorm2d(
+ num_features=in_channels) if out_channels == in_channels and stride == 1 else None
+
+ self.rbr_dense = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding,
+ groups=groups, bias=False),
+ nn.BatchNorm2d(num_features=out_channels))
+
+ self.rbr_1x1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=1, stride=stride, padding=padding_11,
+ groups=groups, bias=False),
+ nn.BatchNorm2d(num_features=out_channels))
+ print('RepVGG Block, identity = ', self.rbr_identity)
+
+ def forward(self, inputs):
+ if hasattr(self, 'rbr_reparam'):
+ return self.act(self.rbr_reparam(inputs))
+
+ if self.rbr_identity is None:
+ id_out = 0
+ else:
+ id_out = self.rbr_identity(inputs)
+
+ return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
+
+ # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
+ # You can get the equivalent kernel and bias at any time and do whatever you want,
+ # for example, apply some penalties or constraints during training, just like you do to the other models.
+ # May be useful for quantization or pruning.
+ def get_equivalent_kernel_bias(self):
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
+ kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
+
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
+ if kernel1x1 is None:
+ return 0
+ else:
+ return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
+
+ def _fuse_bn_tensor(self, branch):
+ if branch is None:
+ return 0, 0
+ if isinstance(branch, nn.Sequential):
+ kernel = branch[0].weight
+ running_mean = branch[1].running_mean
+ running_var = branch[1].running_var
+ gamma = branch[1].weight
+ beta = branch[1].bias
+ eps = branch[1].eps
+ else:
+ assert isinstance(branch, nn.BatchNorm2d)
+ if not hasattr(self, 'id_tensor'):
+ input_dim = self.in_channels // self.groups
+ kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
+ for i in range(self.in_channels):
+ kernel_value[i, i % input_dim, 1, 1] = 1
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
+ kernel = self.id_tensor
+ running_mean = branch.running_mean
+ running_var = branch.running_var
+ gamma = branch.weight
+ beta = branch.bias
+ eps = branch.eps
+ std = (running_var + eps).sqrt()
+ t = (gamma / std).reshape(-1, 1, 1, 1)
+ return kernel * t, beta - running_mean * gamma / std
+
+ def repvgg_convert(self):
+ kernel, bias = self.get_equivalent_kernel_bias()
+ return kernel.detach().cpu().numpy(), bias.detach().cpu().numpy(),
diff --git a/models/module/fpn.py b/models/module/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..73a48964a7912acf188af367d7da0ff1d9e617b4
--- /dev/null
+++ b/models/module/fpn.py
@@ -0,0 +1,165 @@
+# -*- coding:utf-8 -*-
+"""
+from MMDetection
+"""
+
+import torch.nn as nn
+import torch.nn.functional as F
+from models.module.conv import ConvModule
+from models.module.init_weights import xavier_init
+
+
+class FPN(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ conv_cfg=None,
+ norm_cfg=None,
+ activation=None
+ ):
+ super(FPN, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.fp16_enabled = False
+
+ if end_level == -1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level < inputs, no extra level is allowed
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+ self.lateral_convs = nn.ModuleList()
+
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ activation=activation,
+ inplace=False)
+
+ self.lateral_convs.append(l_conv)
+ self.init_weights()
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, inputs):
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, mode='bilinear', align_corners=False)
+
+ # build outputs
+ outs = [
+ # self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ laterals[i] for i in range(used_backbone_levels)
+ ]
+ return tuple(outs)
+
+
+class PAN(FPN):
+ """Path Aggregation Network for Instance Segmentation.
+
+    This is an implementation of the `PAN in Path Aggregation Network
+    <https://arxiv.org/abs/1803.01534>`_.
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool): Whether to add conv layers on top of the
+ original feature maps. Default: False.
+ extra_convs_on_inputs (bool): Whether to apply extra conv on
+ the original feature from the backbone. Default: False.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (str): Config dict for activation layer in ConvModule.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ conv_cfg=None,
+ norm_cfg=None,
+ activation=None):
+ super(PAN,
+ self).__init__(in_channels, out_channels, num_outs, start_level,
+ end_level, conv_cfg, norm_cfg, activation)
+ self.init_weights()
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, mode='bilinear', align_corners=False)
+
+ # build outputs
+ # part 1: from original levels
+ inter_outs = [
+ laterals[i] for i in range(used_backbone_levels)
+ ]
+
+ # part 2: add bottom-up path
+ for i in range(0, used_backbone_levels - 1):
+ prev_shape = inter_outs[i + 1].shape[2:]
+ inter_outs[i + 1] += F.interpolate(inter_outs[i], size=prev_shape, mode='bilinear', align_corners=False)
+
+ outs = []
+ outs.append(inter_outs[0])
+ outs.extend([
+ inter_outs[i] for i in range(1, used_backbone_levels)
+ ])
+ return tuple(outs)
+
+
+# if __name__ == '__main__':
diff --git a/models/module/init_weights.py b/models/module/init_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7c2ea1c8f18da0d98579ee1adbd32713d441111
--- /dev/null
+++ b/models/module/init_weights.py
@@ -0,0 +1,44 @@
+"""
+from MMDetection
+"""
+import torch.nn as nn
+
+
+def kaiming_init(module,
+ a=0,
+ mode='fan_out',
+ nonlinearity='relu',
+ bias=0,
+ distribution='normal'):
+ assert distribution in ['uniform', 'normal']
+ if distribution == 'uniform':
+ nn.init.kaiming_uniform_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ else:
+ nn.init.kaiming_normal_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(module, gain=1, bias=0, distribution='normal'):
+ assert distribution in ['uniform', 'normal']
+ if distribution == 'uniform':
+ nn.init.xavier_uniform_(module.weight, gain=gain)
+ else:
+ nn.init.xavier_normal_(module.weight, gain=gain)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module, mean=0, std=1, bias=0):
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def constant_init(module, val, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
\ No newline at end of file
diff --git a/models/module/norm.py b/models/module/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5687cbd9a31be259ee0b202800af53fbba8ca1a
--- /dev/null
+++ b/models/module/norm.py
@@ -0,0 +1,55 @@
+import torch.nn as nn
+
+norm_cfg = {
+ # format: layer_type: (abbreviation, module)
+ 'BN': ('bn', nn.BatchNorm2d),
+ 'SyncBN': ('bn', nn.SyncBatchNorm),
+ 'GN': ('gn', nn.GroupNorm),
+ # and potentially 'SN'
+}
+
+
+def build_norm_layer(cfg, num_features, postfix=''):
+ """ Build normalization layer
+
+ Args:
+ cfg (dict): cfg should contain:
+ type (str): identify norm layer type.
+ layer args: args needed to instantiate a norm layer.
+            requires_grad (bool): [optional] whether to stop gradient updates
+ num_features (int): number of channels from input.
+ postfix (int, str): appended into norm abbreviation to
+ create named layer.
+
+ Returns:
+ name (str): abbreviation + postfix
+ layer (nn.Module): created norm layer
+ """
+ assert isinstance(cfg, dict) and 'type' in cfg
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in norm_cfg:
+ raise KeyError('Unrecognized norm type {}'.format(layer_type))
+ else:
+ abbr, norm_layer = norm_cfg[layer_type]
+ if norm_layer is None:
+ raise NotImplementedError
+
+ assert isinstance(postfix, (int, str))
+ name = abbr + str(postfix)
+
+ requires_grad = cfg_.pop('requires_grad', True)
+ cfg_.setdefault('eps', 1e-5)
+ if layer_type != 'GN':
+ layer = norm_layer(num_features, **cfg_)
+ if layer_type == 'SyncBN':
+ layer._specify_ddp_gpu_num(1)
+ else:
+ assert 'num_groups' in cfg_
+ layer = norm_layer(num_channels=num_features, **cfg_)
+
+ for param in layer.parameters():
+ param.requires_grad = requires_grad
+
+ return name, layer
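+
+
+# Usage sketch (feature size is illustrative); build_norm_layer returns a (name, module) pair:
+#     name, bn = build_norm_layer(dict(type='BN', requires_grad=True), 64, postfix=1)  # -> ('bn1', BatchNorm2d(64))
+#     name, gn = build_norm_layer(dict(type='GN', num_groups=8), 64)                   # -> ('gn', GroupNorm(8, 64))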
diff --git a/models/shufflenet2_att_m.py b/models/shufflenet2_att_m.py
new file mode 100644
index 0000000000000000000000000000000000000000..af5beba2466aa14ed1f6e561f6859bdbf92a96ea
--- /dev/null
+++ b/models/shufflenet2_att_m.py
@@ -0,0 +1,265 @@
+# -*- coding:utf-8 -*-
+import torch
+import torch.nn as nn
+from models.module.conv import DepthwiseConvModule
+
+
+class ShuffleV2Block(nn.Module):
+ def __init__(self, inp, oup, mid_channels, ksize, stride, attention='', ratio=16, loc='side', onnx=False):
+ super(ShuffleV2Block, self).__init__()
+ self.onnx = onnx
+ self.stride = stride
+ assert stride in [1, 2]
+
+ self.mid_channels = mid_channels
+ self.ksize = ksize
+ pad = ksize // 2
+ self.pad = pad
+ self.inp = inp
+
+ outputs = oup - inp
+
+ branch_main = [
+ # pw
+ nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(mid_channels),
+ nn.ReLU(inplace=True),
+ # dw
+ nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),
+ nn.BatchNorm2d(mid_channels),
+ # pw-linear
+ nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(outputs),
+ nn.ReLU(inplace=True),
+ ]
+ self.branch_main = nn.Sequential(*branch_main)
+
+ if stride == 2:
+ branch_proj = [
+ # dw
+ nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
+ nn.BatchNorm2d(inp),
+ # pw-linear
+ nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(inp),
+ nn.ReLU(inplace=True),
+ ]
+ self.branch_proj = nn.Sequential(*branch_proj)
+ else:
+ self.branch_proj = None
+
+ def forward(self, old_x):
+ if self.stride == 1:
+ x_proj, x = self.channel_shuffle(old_x)
+ else:
+ x_proj = old_x
+ x_proj = self.branch_proj(x_proj)
+ x = old_x
+ x = self.branch_main(x)
+ x = torch.cat((x_proj, x), 1)
+ return x
+
+ def channel_shuffle(self, x):
+ batchsize, num_channels, height, width = x.data.size()
+ if self.onnx:
+            # The ONNX model is later converted to an IFX model; the IFX engine stores data in
+            # NCHW (n=1) layout, so keep every reshape in NCHW (n=1) form as far as possible.
+ x = x.reshape(1, batchsize * num_channels // 2, 2, height * width)
+ x = x.permute(0, 2, 1, 3)
+ z = num_channels // 2
+ x = x.reshape(1, -1, height, width)
+            # When splitting, avoid indexing like x[0] / x[1]; prefer torch operators instead.
+ x1, x2 = torch.split(x, split_size_or_sections=[z, z], dim=1)
+ return x1, x2
+ else:
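+            # Standard channel shuffle: regroup consecutive channel pairs and split them into
+            # two halves, one passed through and one fed to branch_main.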
+ x = x.reshape(batchsize * num_channels // 2, 2, height * width)
+ x = x.permute(1, 0, 2)
+ x = x.reshape(2, -1, num_channels // 2, height, width)
+ return x[0], x[1]
+
+
+class ShuffleNetV2(nn.Module):
+ def __init__(self, num_tasks=0, task_types=0, num_classes=[], out_channel=1024, model_size='0.5x', with_last_conv=True,
+ attention='', loc='side', onnx=False, shuffle_block=None,
+ stack_lite_head=0, lite_head_channels=-1):
+ super(ShuffleNetV2, self).__init__()
+ # print('model size is ', model_size)
+        assert len(num_classes) == num_tasks, "num_tasks must equal the length of the num_classes list"
+
+ self.num_tasks = num_tasks
+ self.use_last_conv = with_last_conv
+ if isinstance(task_types, int):
+ task_types = [task_types] * num_tasks
+ if isinstance(stack_lite_head, int):
+ stack_lite_head = [stack_lite_head] * num_tasks
+ if isinstance(lite_head_channels, int):
+ lite_head_channels = [lite_head_channels] * num_tasks
+ self.task_types = task_types
+ self.stack_lite_head = stack_lite_head
+ self.lite_head_channels = lite_head_channels
+ self.onnx = onnx
+ self.stage_repeats = [4, 8, 4]
+ self.model_size = model_size
+ if model_size == '0.5x':
+ self.stage_out_channels = [-1, 24, 48, 96, 192] + [out_channel]
+ elif model_size == '1.0x':
+ self.stage_out_channels = [-1, 24, 116, 232, 464] + [out_channel]
+ elif model_size == '1.5x':
+ self.stage_out_channels = [-1, 24, 176, 352, 704] + [out_channel]
+ elif model_size == '2.0x':
+ self.stage_out_channels = [-1, 24, 244, 488, 976] + [out_channel]
+ else:
+ raise NotImplementedError
+
+ shuffle_block = ShuffleV2Block if shuffle_block is None else shuffle_block
+ # building first layer
+ input_channel = self.stage_out_channels[1]
+ self.first_conv = nn.Sequential(
+ nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),
+ nn.BatchNorm2d(input_channel),
+ nn.ReLU(inplace=True),
+ )
+
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.features = []
+ for idxstage in range(len(self.stage_repeats)):
+ numrepeat = self.stage_repeats[idxstage]
+ output_channel = self.stage_out_channels[idxstage + 2]
+
+ for i in range(numrepeat):
+ if i == 0:
+ self.features.append(
+ shuffle_block(input_channel, output_channel, mid_channels=output_channel // 2, ksize=3,
+ stride=2, attention=attention, ratio=16, loc=loc, onnx=onnx))
+ else:
+ self.features.append(
+ shuffle_block(input_channel//2, output_channel, mid_channels=output_channel//2, ksize=3,
+ stride=1, attention=attention, ratio=16, loc=loc, onnx=onnx))
+
+ input_channel = output_channel
+
+ self.features = nn.Sequential(*self.features)
+
+ self.classifier_inchannels = self.stage_out_channels[-2]
+ if with_last_conv:
+ self.classifier_inchannels = self.stage_out_channels[-1]
+ self.conv_last = nn.Sequential(
+ nn.Conv2d(input_channel, self.classifier_inchannels, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(self.stage_out_channels[-1]),
+ nn.ReLU(inplace=True))
+
+ self.globalpool = nn.AdaptiveAvgPool2d(output_size=1)
+
+ self.lite_head_channels = [self.classifier_inchannels if v == -1 else v for v in self.lite_head_channels]
+ for ti in range(self.num_tasks):
+ if self.stack_lite_head[ti]:
+ lite_head = []
+ for j in range(self.stack_lite_head[ti]):
+ ins = self.classifier_inchannels if j == 0 else self.lite_head_channels[ti]
+ outs = self.classifier_inchannels if j == self.stack_lite_head[ti]-1 else self.lite_head_channels[ti]
+ lite_head.append(DepthwiseConvModule(ins, outs, 3, 1, 1))
+ lite_head = nn.Sequential(*lite_head)
+ self.add_module(f"lite_head{ti}", lite_head)
+ classifier = nn.Sequential(nn.Linear(self.classifier_inchannels, num_classes[ti], bias=False))
+ self.add_module(f"classifier{ti}", classifier)
+
+ # self.loss_weights = nn.Parameter(torch.ones(num_tasks), requires_grad=True)
+
+ self._initialize_weights()
+
+ def _forward_impl(self, x):
+ x = self.first_conv(x)
+ x = self.maxpool(x)
+ x = self.features(x)
+
+ if self.use_last_conv:
+ x = self.conv_last(x)
+
+ output = []
+ for ti in range(self.num_tasks):
+ if self.stack_lite_head[ti]:
+ c_x = getattr(self, f"lite_head{ti}")(x)
+ c_x = self.globalpool(c_x)
+ c_x = c_x.contiguous().view(-1, self.classifier_inchannels)
+ c_x = getattr(self, f"classifier{ti}")(c_x)
+ else:
+ c_x = self.globalpool(x)
+ c_x = c_x.contiguous().view(-1, self.classifier_inchannels)
+ c_x = getattr(self, f"classifier{ti}")(c_x)
+
+ if self.onnx:
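+                # Export-only post-processing: softmax for classification heads (task_type 0) and
+                # fixed output scaling for the regression heads; the scale constants are presumably
+                # matched to the downstream score ranges.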
+ if self.task_types[ti] == 0:
+ c_x = torch.softmax(c_x, dim=1)
+ elif self.task_types[ti] == 1:
+ c_x *= (0.05/3)
+ elif self.task_types[ti] == 2:
+ c_x *= (0.7/3)
+ elif self.task_types[ti] == 3:
+ c_x *= (0.1/2)
+ else:
+ raise NotImplementedError(f"task_type only support [0, 1, 2, 3], current {self.task_types[ti]}")
+
+ output.append(c_x)
+ return output
+
+ def forward(self, x):
+ output = self._forward_impl(x)
+ return output
+
+ def _initialize_weights(self):
+ for name, m in self.named_modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_in')
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0.0001)
+ nn.init.constant_(m.running_mean, 0)
+ elif isinstance(m, nn.BatchNorm1d):
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0.0001)
+ nn.init.constant_(m.running_mean, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, mode='fan_in')
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def get_last_share_layer(self):
+        assert self.use_last_conv, "Current implementation requires 'with_last_conv=True'"
+ return self.conv_last[0]
+
+
+if __name__ == '__main__':
+ from models.create_model import speed, model_info
+ from torchstat import stat
+
+ # a = se_resnet101()
+ # b = shufflenet_v2_x0_5()
+ input_size = (160, 160)
+ shufflenet2 = ShuffleNetV2(out_channel=192,
+ num_tasks=10, task_types=1, num_classes=[5, 3, 3, 3, 3, 1, 1, 1, 1, 1], model_size='0.5x', with_last_conv=0,
+ stack_lite_head=1, lite_head_channels=-1, onnx=True)
+ #
+ model_info(shufflenet2, img_size=input_size, verbose=False)
+ speed(shufflenet2, 'shufflenet2', size=input_size, device_type='cpu')
+ stat(shufflenet2, input_size=(3, 160, 160))
+
+
+
+    # shufflenet2.eval()
+    #
+    # example = torch.randn(1, 3, input_size[1], input_size[0])
+    # torch.onnx.export(
+    #     shufflenet2,              # model being run
+    #     example,                  # model input (or a tuple for multiple inputs)
+    #     '1.onnx',
+    #     verbose=False,
+    #     training=False,           # store the trained parameter weights inside the model file
+    #     input_names=['input'],
+    #     output_names=['output'],
+    #     do_constant_folding=True
+    # )
diff --git a/utils/__pycache__/common.cpython-38.pyc b/utils/__pycache__/common.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b716008306a87a6b2696e2129983d476f91c5d31
Binary files /dev/null and b/utils/__pycache__/common.cpython-38.pyc differ
diff --git a/utils/__pycache__/images.cpython-38.pyc b/utils/__pycache__/images.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82b805c056b733ca8febb75eaf0933649f4f2d1d
Binary files /dev/null and b/utils/__pycache__/images.cpython-38.pyc differ
diff --git a/utils/__pycache__/labels.cpython-38.pyc b/utils/__pycache__/labels.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c8892ce16d4752b61fd4a6f8e8d5be0d940ecea
Binary files /dev/null and b/utils/__pycache__/labels.cpython-38.pyc differ
diff --git a/utils/__pycache__/multiprogress.cpython-38.pyc b/utils/__pycache__/multiprogress.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d42cb7b78c748df0be9abc6fe9067f2f43a7002
Binary files /dev/null and b/utils/__pycache__/multiprogress.cpython-38.pyc differ
diff --git a/utils/__pycache__/os_util.cpython-38.pyc b/utils/__pycache__/os_util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35d631c10da093412be6a9fa65ba9ff5ba3b05f9
Binary files /dev/null and b/utils/__pycache__/os_util.cpython-38.pyc differ
diff --git a/utils/__pycache__/plt_util.cpython-38.pyc b/utils/__pycache__/plt_util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ee00538ab9ee4631d36c970d8f3e92b06b65cd0
Binary files /dev/null and b/utils/__pycache__/plt_util.cpython-38.pyc differ
diff --git a/utils/__pycache__/time_util.cpython-38.pyc b/utils/__pycache__/time_util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c99d6b94f853e34e9b94b70a726a5511cc921259
Binary files /dev/null and b/utils/__pycache__/time_util.cpython-38.pyc differ
diff --git a/utils/common.py b/utils/common.py
new file mode 100755
index 0000000000000000000000000000000000000000..8b49262b7d79c3521dcf73cb800de5786f2e0cf6
--- /dev/null
+++ b/utils/common.py
@@ -0,0 +1,227 @@
+# -*-coding:utf-8-*-
+import colorsys
+import datetime
+import functools
+import signal
+import sys
+from contextlib import contextmanager
+
+import numpy as np
+
+import time
+
+
+# Generate colors
+def gen_color(n):
+ hsv_tuples = [(x / n, 1., 1.)
+ for x in range(n)]
+ colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
+ colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
+ # np.random.seed(10101) # Fixed seed for consistent colors across runs.
+ np.random.shuffle(colors) # Shuffle colors to decorrelate adjacent classes.
+ # np.random.seed(None) # Reset seed to default.
+ return colors
+
+
+# pylint: disable=W0232
+class Color:
+ GRAY = 30
+ RED = 31
+ GREEN = 32
+ YELLOW = 33
+ BLUE = 34
+ MAGENTA = 35
+ CYAN = 36
+ WHITE = 67
+ CRIMSON = 38
+
+
+# Return the string wrapped in ANSI escape codes, controlling foreground/background color, bold, etc.
+def colorize(num, string, bold=False, highlight=False):
+ assert isinstance(num, int)
+ attr = []
+ if bold:
+ attr.append('1')
+ if highlight and num == 67:
+ num += 30
+ if highlight and num != 67:
+ num += 60
+ attr.append(str(num))
+    # \x1b[<style>;<foreground>;<background>m + "text" + \x1b[0m
+    # the order of the ';'-separated codes may be changed
+ return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)
+
+
+def colorprint(colorcode, text, o=sys.stdout, bold=False, highlight=False, end='\n'):
+ o.write(colorize(colorcode, text, bold=bold, highlight=highlight) + end)
+
+
+def cprint(text, colorcode=67, bold=False, highlight=False, end='\n', prefix=None,
+ pre_color=34, pre_bold=True, pre_high=True, pre_end=': '):
+ prefix = str(prefix).rstrip() if prefix is not None else prefix
+ if prefix is not None:
+ prefix = prefix.rstrip(':') if ':' in pre_end else prefix
+ prefix += pre_end
+ colorprint(pre_color, prefix, sys.stdout, pre_bold, pre_high, end='')
+ colorprint(colorcode, text, bold=bold, highlight=highlight, end=end)
+
+
+def log_warn(msg):
+ cprint(msg, colorcode=33, prefix='Warning', pre_color=33, highlight=True)
+
+
+def log_error(msg):
+ cprint(msg, colorcode=31, prefix='Error', pre_color=31, highlight=True)
+
+
+def now_time():
+ return datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+
+
+# http://stackoverflow.com/questions/366682/how-to-limit-execution-time-of-a-function-call-in-python
+class TimeoutException(Exception):
+ def __init__(self, msg):
+ self.msg = msg
+
+
+# Context manager: the decorated function must contain a yield. Usage:
+#     with func() as f:
+#         do()
+# The with statement runs the code before the yield, then the body (f and do()), and finally the code after the yield.
+@contextmanager
+def time_limit(seconds):
+    # Limit the running time of the wrapped block; raise TimeoutException when the limit is exceeded.
+    def signal_handler(signum, frame):  # a signal handler must take these two arguments
+        raise TimeoutException(colorize(Color.RED, "Timed out! Retry again ...", highlight=True))
+    signal.signal(signal.SIGALRM, signal_handler)  # register signal_handler for the SIGALRM signal
+    signal.alarm(seconds)  # if seconds is non-zero, SIGALRM is delivered to this process after that many seconds
+    try:
+        yield
+    finally:
+        signal.alarm(0)  # cancel the pending alarm
+
+
+def clock(func):  # enhanced version
+    @functools.wraps(func)  # copy func's metadata to clocked so func.__name__ is not replaced by 'clocked'
+    def clocked(*args, **kwargs):  # the inner function accepts any positional and keyword arguments
+        t0 = time.perf_counter()
+        result = func(*args, **kwargs)  # the clocked closure holds func as a free variable
+        elapsed = time.perf_counter() - t0
+        name = func.__name__
+        arg_lst = []
+        if args:
+            arg_lst.append(', '.join(repr(arg) for arg in args))
+        if kwargs:
+            pairs = ['%s=%r' % (k, w) for k, w in sorted(kwargs.items())]
+            arg_lst.append(', '.join(pairs))
+        arg_str = ', '.join(arg_lst)
+        colorprint(Color.GREEN, '[%0.8fs] %s(%s) -> %r' % (elapsed, name, arg_str, result))
+        return result
+    return clocked  # return clocked in place of the decorated function
+
+
+DEFAULT_FMT = '[{elapsed:0.8f}s] {name}({args}) -> {result}'
+
+
+def clock_custom(fmt=DEFAULT_FMT, color=Color.GREEN):
+    def decorate(func):  # enhanced version
+        @functools.wraps(func)  # copy func's metadata to clocked so func.__name__ is not replaced by 'clocked'
+        def clocked(*_args, **_kwargs):  # the inner function accepts any positional and keyword arguments
+            t0 = time.perf_counter()
+            _result = func(*_args, **_kwargs)  # the clocked closure holds func as a free variable
+            elapsed = time.perf_counter() - t0
+            name = func.__name__
+            _arg_lst = []
+            if _args:
+                _arg_lst.append(', '.join(repr(arg) for arg in _args))
+            if _kwargs:
+                _pairs = ['%s=%r' % (k, w) for k, w in sorted(_kwargs.items())]
+                _arg_lst.append(', '.join(_pairs))
+            args = ', '.join(_arg_lst)
+            result = repr(_result)  # string form, referenced by the format template
+            colorprint(color, fmt.format(**locals()))  # format with clocked's local variables
+            return _result  # return the original function's result
+        return clocked  # return clocked in place of the decorated function
+    return decorate  # a decorator factory must return a decorator
+
+
+def colorstr(*input):
+ # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
+ *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
+ colors = {'black': '\033[30m', # basic colors
+ 'red': '\033[31m',
+ 'green': '\033[32m',
+ 'yellow': '\033[33m',
+ 'blue': '\033[34m',
+ 'magenta': '\033[35m',
+ 'cyan': '\033[36m',
+ 'white': '\033[37m',
+ 'bright_black': '\033[90m', # bright colors
+ 'bright_red': '\033[91m',
+ 'bright_green': '\033[92m',
+ 'bright_yellow': '\033[93m',
+ 'bright_blue': '\033[94m',
+ 'bright_magenta': '\033[95m',
+ 'bright_cyan': '\033[96m',
+ 'bright_white': '\033[97m',
+ 'end': '\033[0m', # misc
+ 'bold': '\033[1m',
+ 'underline': '\033[4m'}
+ return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
+
+
+class Logger:
+ def __init__(self, file_path):
+ self.log_file_path = file_path
+ self.log_file = None
+ self.color = {
+ 'R': Color.RED,
+ 'B': Color.BLUE,
+ 'G': Color.GREEN,
+ 'Y': Color.YELLOW
+ }
+
+ def start(self):
+ self.log_file = open(self.log_file_path, 'w', encoding='utf-8')
+
+ def close(self):
+ if self.log_file is not None:
+ self.log_file.close()
+ return
+
+ def info(self, text, color='W', prefix=None, pre_color='B', pre_end=': ', prints=True):
+        assert self.log_file is not None, "Call 'logger.start()' before logging"
+ color_code = self.color.get(color, Color.WHITE)
+ pre_color = self.color.get(pre_color, Color.BLUE)
+ prefix = str(prefix).rstrip() if prefix is not None else prefix
+ if prefix is not None:
+ prefix = prefix.rstrip(':') if ':' in pre_end else prefix
+ prefix += pre_end
+ self.log_file.write(f"{prefix if prefix is not None else ''}{text}\n")
+ if prints:
+ cprint(text, color_code, prefix=prefix, pre_color=pre_color)
+
+ def error(self, text):
+        assert self.log_file is not None, "Call 'logger.start()' before logging"
+ # self.log_file.write(text + '\n')
+ log_error(text)
+
+ def warn(self, text):
+        assert self.log_file is not None, "Call 'logger.start()' before logging"
+ # self.log_file.write(text + '\n')
+ log_warn(text)
+
+
+def r(val):
+ return int(np.random.random() * val)
+
+
+# if __name__ == '__main__':
+# #print(colorize(31, 'I am fine!', bold=True, highlight=True))
+# #colorprint(35, 'I am fine!')
+# #error('get out')
+# import time
+#
+# # ends after 5 seconds
+# with time_limit(2):
+# time.sleep(3)
diff --git a/utils/export_util.py b/utils/export_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..d56385289353d9ac1b1dcee606f9347ccad8da14
--- /dev/null
+++ b/utils/export_util.py
@@ -0,0 +1,118 @@
+# -*- coding:utf-8 -*-
+# tensorflow=2.5.0 onnx-tf=1.9.0
+import os
+import time
+import onnxruntime
+# import tensorflow as tf
+
+
+def onnx_export_pipeline(onnx_path, export_func, simplify=True, onnx_sim_path=None, pack=True, view_net=True):
+ if not os.path.exists(os.path.dirname(onnx_path)):
+ os.makedirs(os.path.dirname(onnx_path))
+
+    # Export
+    export_func(export_name=onnx_path)  # the ONNX export callback, closely tied to each model's own config
+ if not simplify:
+ print(f"Onnx model export in '{onnx_path}'")
+
+    if simplify:  # simplify operators with onnx-simplifier
+ time.sleep(1)
+ onnx_sim_path = onnx_sim_path if onnx_sim_path else onnx_path
+ os.system(f"python -m onnxsim {onnx_path} {onnx_sim_path}")
+ print(f"Simplify onnx model export in '{onnx_sim_path}'")
+
+    if pack:  # compress and package the result
+ time.sleep(1)
+ src_path = onnx_sim_path if simplify else onnx_path
+ src_dir, src_name = os.path.dirname(src_path), os.path.basename(src_path)
+ os.system(f'cd {src_dir} && tar -zcf {src_name}.tgz {src_name}')
+ print(f"TGZ file save in '{src_path}.tgz'")
+
+    if view_net:  # inspect the network structure with netron
+ import netron
+ src_path = onnx_sim_path if simplify else onnx_path
+ netron.start(src_path)
+
+
+class ONNXModel:
+ def __init__(self, onnx_path):
+ """
+ :param onnx_path:
+ """
+ self.onnx_session = onnxruntime.InferenceSession(onnx_path)
+ self.input_name = self.get_input_name(self.onnx_session)
+ self.output_name = self.get_output_name(self.onnx_session)
+ print(f"loading {onnx_path}")
+
+ @staticmethod
+ def get_output_name(onnx_session):
+ """
+ output_name = onnx_session.get_outputs()[0].name
+ :param onnx_session:
+ :return:
+ """
+ output_name = []
+ for node in onnx_session.get_outputs():
+ output_name.append(node.name)
+ return output_name
+
+ @staticmethod
+ def get_input_name(onnx_session):
+ """
+ input_name = onnx_session.get_inputs()[0].name
+ :param onnx_session:
+ :return:
+ """
+ input_name = []
+ for node in onnx_session.get_inputs():
+ input_name.append(node.name)
+ return input_name
+
+ @staticmethod
+ def get_input_feed(input_name, image_numpy):
+ """
+ input_feed={self.input_name: image_numpy}
+ :param input_name:
+ :param image_numpy:
+ :return:
+ """
+ input_feed = {}
+ for name in input_name:
+ input_feed[name] = image_numpy
+ return input_feed
+
+ def forward(self, image_numpy):
+ '''
+ # image_numpy = image.transpose(2, 0, 1)
+ # image_numpy = image_numpy[np.newaxis, :]
+ # onnx_session.run([output_name], {input_name: x})
+ # :param image_numpy:
+ # :return:
+ '''
+ input_feed = self.get_input_feed(self.input_name, image_numpy)
+ outputs = self.onnx_session.run(self.output_name, input_feed=input_feed)
+ return outputs
+
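+
+# Minimal usage sketch (path and input shape are illustrative; the array must match the
+# model's expected input, typically a float32 NCHW batch):
+#     model = ONNXModel('model.onnx')
+#     outputs = model.forward(np.zeros((1, 3, 160, 160), dtype=np.float32))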
+
+def to_numpy(tensor):
+    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+
+def export_tflite(onnx_path, tflite_path):
+    # TensorFlow is only needed for this export path; import it lazily so the rest of the
+    # module still works with onnxruntime alone (the top-level import is commented out above).
+    import tensorflow as tf
+
+    save_model_dir = os.path.splitext(tflite_path)[0]
+ onnx_to_pb_cmd = f'onnx-tf convert -i {onnx_path} -o {save_model_dir}'
+ os.system(onnx_to_pb_cmd)
+ print(f"Convert onnx model to pb in '{save_model_dir}'")
+
+ time.sleep(1)
+ # make a converter object from the saved tensorflow file
+ converter = tf.lite.TFLiteConverter.from_saved_model(save_model_dir)
+ # tell converter which type of optimization techniques to use
+ # to view the best option for optimization read documentation of tflite about optimization
+ # go to this link https://www.tensorflow.org/lite/guide/get_started#4_optimize_your_model_optional
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ # convert the model
+ tf_lite_model = converter.convert()
+ # save the converted model
+ open(tflite_path, 'wb').write(tf_lite_model)
+ print(f"TFlite model save in '{tflite_path}'")
diff --git a/utils/images.py b/utils/images.py
new file mode 100755
index 0000000000000000000000000000000000000000..39406246a78e18ca89ade86abd112b68d73a2118
--- /dev/null
+++ b/utils/images.py
@@ -0,0 +1,539 @@
+# -*- coding:utf-8 -*-
+import random
+from math import *
+
+import cv2
+from PIL import ImageDraw, Image, ImageFont
+
+from utils.common import *
+from utils.multiprogress import MultiThreading
+
+
+def show(image, name="Press 'q' exit"):
+ if not hasattr(image, 'shape'):
+ img_array = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+ else:
+ img_array = image.copy()
+ cv2.imshow(name, img_array)
+ cv2.moveWindow(name, 0, 0)
+ key = cv2.waitKey(0) & 0xEFFFFF
+ if key == ord('q'):
+ cv2.destroyWindow(name)
+
+
+def crop_face_square(img, box, pad=5):
+ img_h, img_w, img_c = img.shape
+ if box != [0, 0, 0, 0]:
+ x0, y0, x1, y1 = box
+ w, h = x1-x0, y1-y0
+        # expand the crop box to a square
+ long_side = min(max(h, w) + int(2*pad), img_h)
+ w_add, h_add = (long_side - w) // 2, (long_side - h) // 2
+ crop_x0, crop_y0 = max(x0 - w_add, 0), max(y0 - h_add, 0)
+ crop_x1, crop_y1 = min(crop_x0 + long_side, img_w), min(crop_y0 + long_side, img_h)
+ crop_x0, crop_y0 = crop_x1 - long_side, crop_y1 - long_side
+ return img[crop_y0:crop_y1, crop_x0:crop_x1], [crop_x0, crop_y0, crop_x1, crop_y1]
+ else:
+ # print('No detected box, crop right rect.')
+ if img_h == 960:
+ img = img[60:780, :]
+ img_h, img_w = img.shape[:2]
+ return img[:, img_w - img_h:img_w], [0, 0, 0, 0]
+
+
+def crop_face_square_rate(img, box, rate=0.1):
+ img_h, img_w, img_c = img.shape
+ if box != [0, 0, 0, 0]:
+ x0, y0, x1, y1 = box
+ w, h = x1-x0, y1-y0
+        # expand the crop box to a square
+ pad = max(w, h) * rate
+ long_side = min(max(h, w) + int(2*pad), img_h)
+ w_add, h_add = (long_side - w) // 2, (long_side - h) // 2
+ crop_x0, crop_y0 = max(x0 - w_add, 0), max(y0 - h_add, 0)
+ crop_x1, crop_y1 = min(crop_x0 + long_side, img_w), min(crop_y0 + long_side, img_h)
+ crop_x0, crop_y0 = crop_x1 - long_side, crop_y1 - long_side
+ return img[crop_y0:crop_y1, crop_x0:crop_x1], [crop_x0, crop_y0, crop_x1, crop_y1]
+ else:
+ if img_h == 960:
+ crop_x0, crop_x1 = img_w - 720, img_w
+ crop_y0, crop_y1 = 60, 780
+ else:
+ crop_x0, crop_x1 = img_w - img_h, img_w
+ crop_y0, crop_y1 = 0, img_h
+ return img[crop_y0:crop_y1, crop_x0:crop_x1], [crop_x0, crop_y0, crop_x1, crop_y1]
+
+
+def expand_box_rate(img, box, rate=0.1):
+ img_h, img_w, img_c = img.shape
+ if box != [0, 0, 0, 0]:
+ x0, y0, x1, y1 = box
+ w, h = x1-x0, y1-y0
+        # expand the crop box to a square
+ pad = max(w, h) * rate
+ long_side = min(max(h, w) + int(2*pad), img_h)
+ w_add, h_add = (long_side - w) // 2, (long_side - h) // 2
+ crop_x0, crop_y0 = x0 - w_add, y0 - h_add
+ crop_x1, crop_y1 = crop_x0 + long_side, crop_y0 + long_side
+ return [crop_x0, crop_y0, crop_x1, crop_y1]
+ else:
+ if img_h == 960:
+ crop_x0, crop_x1 = img_w - 720, img_w
+ crop_y0, crop_y1 = 60, 780
+ else:
+ crop_x0, crop_x1 = img_w - img_h, img_w
+ crop_y0, crop_y1 = 0, img_h
+ return [crop_x0, crop_y0, crop_x1, crop_y1]
+
+
+def crop_with_pad(img, crop_box):
+ img_h, img_w, img_c = img.shape
+ x0, y0, x1, y1 = crop_box
+ w, h = x1 - x0, y1 - y0
+    if tuple(crop_box) == (0, 0, 0, 0) or w <= 50 or h <= 50:  # background (no valid box): crop the right part of the image
+ if img_h == 960:
+ crop_x0, crop_x1 = img_w - 720, img_w
+ crop_y0, crop_y1 = 60, 780
+ else:
+ crop_x0, crop_x1 = img_w - img_h, img_w
+ crop_y0, crop_y1 = 0, img_h
+ crop_img = img[crop_y0:crop_y1, crop_x0:crop_x1]
+ return crop_img
+ else:
+ crop_x0, crop_y0 = max(x0, 0), max(y0, 0)
+ crop_x1, crop_y1 = min(x1, img_w), min(y1, img_h)
+ left, top, right, bottom = crop_x0 - x0, crop_y0 - y0, x1 - crop_x1, y1 - crop_y1
+ crop_img = img[crop_y0:crop_y1, crop_x0:crop_x1]
+ crop_img = cv2.copyMakeBorder(crop_img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128))
+ return crop_img
+
+
+def clip_paste_rect(bg_img, fg_img, loc, refine=True):
+ """
+    Select the paste region, clipped so it does not go beyond the background bounds
+    loc: center point (cx, cy)
+ """
+ bg_h, bg_w = bg_img.shape[:2]
+ fg_h, fg_w = fg_img.shape[:2]
+ fg_h = fg_h - 1 if fg_h % 2 else fg_h
+ fg_w = fg_w - 1 if fg_w % 2 else fg_w
+ cx, cy = loc
+ left, top, right, bottom = cx - fg_w // 2, cy - fg_h // 2, cx + fg_w // 2, cy + fg_h // 2
+
+ if refine:
+ right, bottom = min(right, bg_w), min(bottom, bg_h)
+ left, top = right - fg_w, bottom - fg_h
+ left, top = max(0, left), max(0, top)
+ right, bottom = left + fg_w, top + fg_h
+
+ plot_x1, plot_y1, plot_x2, plot_y2 = left, top, right, bottom
+ use_x1, use_y1, use_x2, use_y2 = 0, 0, fg_w, fg_h
+
+ if left < 0:
+ plot_x1, use_x1 = 0, -left
+ if top < 0:
+ plot_y1, use_y1 = 0, -top
+ if right > bg_w:
+ plot_x2, use_x2 = bg_w, fg_w - (right - bg_w)
+ if bottom > bg_h:
+ plot_y2, use_y2 = bg_h, fg_h - (bottom - bg_h)
+
+ use_bg = bg_img[plot_y1:plot_y2, plot_x1:plot_x2]
+ use_fg = fg_img[use_y1:use_y2, use_x1:use_x2]
+ window = (plot_x1, plot_y1, plot_x2, plot_y2)
+
+ return use_bg, use_fg, window
+
+
+# @clock_custom(fmt=DEFAULT_FMT)
+def paste(bg_img, fg_img, loc, trans_thresh=1, refine=True):
+ """
+    Paste the foreground image (BGRA, with alpha) onto the background
+ loc: center (cx, cy)
+ """
+ use_bg, use_fg, window = clip_paste_rect(bg_img, fg_img, loc, refine)
+ plot_x1, plot_y1, plot_x2, plot_y2 = window
+ b, g, r, a = cv2.split(use_fg)
+ a[a > 0] = 255
+ a = np.dstack([a, a, a]) * trans_thresh
+ use_bg = use_bg * (255.0 - a) / 255
+ use_bg += use_fg[:, :, :3] * (a / 255)
+ bg_img[plot_y1:plot_y2, plot_x1:plot_x2] = use_bg.astype('uint8')
+ return bg_img
+
+
+def put_chinese_text(img, text, position, font, text_color=(0, 255, 0)):
+    if isinstance(img, np.ndarray):  # check whether this is an OpenCV (numpy) image
+        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+    # create an object that can draw on the given image
+    draw = ImageDraw.Draw(img)
+    draw.text(position, text, text_color, font=font)
+    # convert back to OpenCV format
+    return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
+
+
+def compute_mean_std(img_list, preprocess=None, workers=12):
+ """
+    Formula:
+ S^2 = sum((x - x')^2) / N = sum(x^2 + x'^2 - 2xx') / N
+ = (sum(x^2) + sum(x'^2) - 2x'*sum(x)) / N
+ = (sum(x^2) + N*(x'^2) - 2x'*(N * x')) / N
+ = (sum(x^2) - N * (x'^2)) / N
+ = sum(x^2) / N - (sum(x) / N)^2
+ = mean(x^2) - mean(x)^2 = E(x^2) - E(x)^2
+ :param img_list:
+ :param workers:
+    :param preprocess: image preprocessing function; must include BGR->RGB conversion and normalization
+ :return: RGB means, stds
+ """
+
+ def cal_func(info):
+ i, img_path = info
+ img = cv2.imread(img_path)
+ if preprocess:
+ img = preprocess(img)
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = img.astype(np.float32) / 255
+ img = img.reshape((-1, 3))
+ channel_mean = np.mean(img, axis=0)
+ channel_var = np.mean(img**2, axis=0)
+ if i % 1000 == 0:
+ print(f"{i}/{len(img_list)}")
+ return channel_mean, channel_var
+
+ exe = MultiThreading(list(enumerate(img_list)), workers=min(workers, len(img_list)))
+ res = exe.run(cal_func)
+ all_means = np.array([r[0] for r in res])
+ all_vars = np.array([r[1] for r in res])
+ means = np.mean(all_means, axis=0)
+ vars = np.mean(all_vars, axis=0)
+ stds = np.sqrt(vars - means ** 2)
+
+ return means, stds
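+
+
+# Usage sketch (hypothetical image paths):
+#     means, stds = compute_mean_std(['/path/a.jpg', '/path/b.jpg'], workers=4)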
+
+
+def draw_province(f, val):
+ img = Image.new("RGB", (45, 70), (255, 255, 255))
+ draw = ImageDraw.Draw(img)
+ draw.text((0, 3), val, (0, 0, 0), font=f)
+ img = img.resize((23, 70))
+ char = np.array(img)
+ return char
+
+
+def draw_chars(f, val):
+ img = Image.new("RGB", (23, 70), (255, 255, 255))
+ draw = ImageDraw.Draw(img)
+ draw.text((0, 2), val, (0, 0, 0), font=f)
+ A = np.array(img)
+ return A
+
+
+# double-layer (two-row) license plates
+def draw_province_double(f, val):
+ img = Image.new("RGB", (60, 60), (255, 255, 255))
+ draw = ImageDraw.Draw(img)
+ draw.text((0, -12), val, (0, 0, 0), font=f)
+ img = img.resize((80, 60))
+ # img.show()
+ char = np.array(img)
+ return char
+
+
+def draw_chars_ceil(f, val):
+ img = Image.new("RGB", (30, 45), (255, 255, 255))
+ draw = ImageDraw.Draw(img)
+ draw.text((1, -12), val, (0, 0, 0), font=f)
+ img = img.resize((80, 60))
+ # img.show()
+ char = np.array(img)
+ return char
+
+
+def draw_chars_floor(f, val):
+ img = Image.new("RGB", (30, 45), (255, 255, 255))
+ draw = ImageDraw.Draw(img)
+ draw.text((1, -12), val, (0, 0, 0), font=f)
+ img = img.resize((65, 110))
+ # img.show()
+ char = np.array(img)
+ return char
+
+
+# make the image look worn (add smudges)
+def add_smudginess(img):
+ smu = cv2.imread('data/bgp/smu.jpg')
+ img_h, img_w = img.shape[:2]
+ rows = r(smu.shape[0] - img_h)
+ cols = r(smu.shape[1] - img_w)
+ adder = smu[rows:rows + img_h, cols:cols + img_w]
+ adder = cv2.resize(adder, (img_w, img_h))
+ adder = cv2.bitwise_not(adder)
+ val = random.random() * 0.5
+ img = cv2.addWeighted(img, 1 - val, adder, val, 0.0)
+ return img
+
+
+def rot(img, angel, shape, max_angel):
+ """ 使图像轻微的畸变
+ img 输入图像
+ factor 畸变的参数
+ size 为图片的目标尺寸
+ """
+ size_o = [shape[1], shape[0]]
+ size = (shape[1] + int(shape[0] * sin((float(max_angel)/180) * 3.14)), shape[0])
+ interval = abs(int(sin((float(angel) / 180) * 3.14) * shape[0]))
+ pts1 = np.float32([[0, 0], [0, size_o[1]], [size_o[0], 0], [size_o[0], size_o[1]]])
+ if angel > 0:
+ pts2 = np.float32([[interval, 0], [0, size[1]], [size[0], 0], [size[0]-interval, size_o[1]]])
+ else:
+ pts2 = np.float32([[0, 0], [interval, size[1]], [size[0]-interval, 0], [size[0], size_o[1]]])
+ M = cv2.getPerspectiveTransform(pts1, pts2)
+ dst = cv2.warpPerspective(img, M, size)
+ return dst
+
+
+def random_rot(img, factor, size):
+ shape = size
+ pts1 = np.float32([[0, 0], [0, shape[0]], [shape[1], 0], [shape[1], shape[0]]])
+ pts2 = np.float32([[r(factor), r(factor)],
+ [r(factor), shape[0] - r(factor)],
+ [shape[1] - r(factor), r(factor)],
+ [shape[1] - r(factor), shape[0] - r(factor)]])
+ M = cv2.getPerspectiveTransform(pts1, pts2)
+ dst = cv2.warpPerspective(img, M, size)
+ return dst
+
+
+def random_rot_expend(img, factor, size):
+ height, width = size
+ degree = factor
+    # canvas size after rotation
+    n_h = int(width * fabs(sin(radians(degree))) + height * fabs(cos(radians(degree))))
+    n_w = int(height * fabs(sin(radians(degree))) + width * fabs(cos(radians(degree))))
+    M = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
+    M[0, 2] += (n_w - width) / 2  # key step: shift so the rotated image is centered in the enlarged canvas
+    M[1, 2] += (n_h - height) / 2  # same shift along the vertical axis
+ rot_img = cv2.warpAffine(img, M, (n_w, n_h), borderValue=(0, 0, 0))
+ return rot_img
+
+
+def random_rot_keep_size_left_right(img, factor1, factor2, size):
+    # perspective transform: factor1 should be > 0, factor2 may be positive or negative
+ shape = size
+ width = shape[0]
+ height = shape[1]
+ pts1 = np.float32([[0, 0], [width - 1, 0], [0, height-1], [width - 1, height-1]])
+ point1_x = 0
+ point1_y = 0
+ point2_x = width - r(factor1)
+ point2_y = r(factor2) # shape[0] - r(factor)
+ point3_x = 0 # shape[1] - r(factor)
+ point3_y = height # r(factor)
+ point4_x = width - r(factor1) # shape[1]
+ point4_y = height - r(factor2) # shape[0]
+ max_x = max(point2_x, point4_x)
+ max_y = point3_y
+ if factor2 < 0:
+ point1_x = 0
+ point1_y = 0 - r(factor2)
+ point2_x = width - r(factor1)
+ point2_y = 0
+ point3_x = 0
+ point3_y = height + r(factor2)
+ point4_x = width - r(factor1)
+ point4_y = height
+ max_x = max(point2_x, point4_x)
+ max_y = point4_y
+ pts2 = np.float32([[point1_x, point1_y],
+ [point2_x, point2_y],
+ [point3_x, point3_y],
+ [point4_x, point4_y]])
+ M = cv2.getPerspectiveTransform(pts1, pts2)
+ size2 = (max_x, max_y)
+ dst = cv2.warpPerspective(img, M, size2) # cv2.warpPerspective(img, M, size)
+ return dst
+
+
+# affine transform
+def affine_transform(img, factor, size):
+ shape = size
+ pts1 = np.float32([[0, shape[0]], [shape[1], 0], [shape[1], shape[0]]])
+ pts2 = np.float32([[r(factor), shape[0] - r(factor)],
+ [shape[1] - r(factor), r(factor)],
+ [shape[1], shape[0]]]) # [shape[1] - r(factor), shape[0] - r(factor)]])
+ M = cv2.getAffineTransform(pts1, pts2)
+ dst = cv2.warpAffine(img, M, size)
+ return dst
+
+
+# erosion
+def cv_erode(img, factor):
+ value = r(factor)+1
+ kernel = np.ones((value, value), np.uint8) * factor
+ erosion = cv2.erode(img, kernel, iterations=1)
+ return erosion
+
+
+# dilation
+def cv_dilate(img, factor):
+ value = r(factor)+1
+ kernel = np.ones((value, value), np.uint8)
+ dilate = cv2.dilate(img, kernel, iterations=1)
+ return dilate
+
+
+def add_random_noise(img, factor):
+ value = r(factor)
+    for k in range(value):  # create 'value' randomly placed noisy pixels
+ i = random.randint(0, img.shape[0] - 1)
+ j = random.randint(0, img.shape[1] - 1)
+ color = (random.randrange(256), random.randrange(256), random.randrange(256))
+ img[i, j] = color
+ return img
+
+
+# generate the motion-blur convolution kernel and its anchor point
+def gen_kernel_anchor(length, angle):
+ half = length / 2
+ EPS = np.finfo(float).eps
+ alpha = (angle - floor(angle / 180) * 180) / 180 * pi
+ cosalpha = cos(alpha)
+ sinalpha = sin(alpha)
+ if cosalpha < 0:
+ xsign = -1
+ elif angle == 90:
+ xsign = 0
+ else:
+ xsign = 1
+ psfwdt = 1
+    # blur kernel size
+ sx = int(fabs(length * cosalpha + psfwdt * xsign - length * EPS))
+ sy = int(fabs(length * sinalpha + psfwdt - length * EPS))
+ psf1 = np.zeros((sy, sx))
+    # psf1 is a kernel whose weights are largest at the top-left and decay towards the bottom-right,
+    # which corresponds to motion from the bottom-right towards the top-left
+ for i in range(0, sy):
+ for j in range(0, sx):
+ psf1[i][j] = i * fabs(cosalpha) - j * sinalpha
+ rad = sqrt(i * i + j * j)
+ if rad >= half and fabs(psf1[i][j]) <= psfwdt:
+ temp = half - fabs((j + psf1[i][j] * sinalpha) / cosalpha)
+ psf1[i][j] = sqrt(psf1[i][j] * psf1[i][j] + temp * temp)
+ psf1[i][j] = psfwdt + EPS - fabs(psf1[i][j])
+ if psf1[i][j] < 0:
+ psf1[i][j] = 0
+    # motion towards the top-left: anchor at (0, 0)
+    anchor = (0, 0)
+    # motion towards the top-right: anchor at the top-right corner;
+    # also flip the kernel left-right so the weights grow towards the anchor
+ if 0 < angle < 90:
+ psf1 = np.fliplr(psf1)
+ anchor = (psf1.shape[1] - 1, 0)
+    elif -90 < angle < 0:  # likewise: motion towards the bottom-right
+ psf1 = np.flipud(psf1)
+ psf1 = np.fliplr(psf1)
+ anchor = (psf1.shape[1] - 1, psf1.shape[0] - 1)
+    elif angle < -90:  # likewise: motion towards the bottom-left
+ psf1 = np.flipud(psf1)
+ anchor = (0, psf1.shape[0] - 1)
+ psf1 = psf1 / psf1.sum()
+ return psf1, anchor
+
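+# A typical (assumed) use of the kernel: apply directional motion blur with cv2.filter2D, e.g.
+#     kernel, anchor = gen_kernel_anchor(15, 45)
+#     blurred = cv2.filter2D(img, -1, kernel, anchor=anchor)
+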
+
+def hsv_transform(img):
+ hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+ # hsv[:, :, 0] = hsv[:, :, 0] * (0.2 + np.random.random() * 0.8)
+ hsv[:, :, 1] = hsv[:, :, 1] * (0.5 + np.random.random()*0.5)
+ hsv[:, :, 2] = hsv[:, :, 2] * (0.5 + np.random.random()*0.5)
+ img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
+ return img
+
+
+def kth_diag_indices(array, k=0, n=-1):
+ """
+    Indices of the k-th diagonal
+    :param array: input 2-D matrix
+    :param k: k=0 main diagonal, k>0 shifted right, k<0 shifted left
+    :param n: n=-1 uses the whole diagonal, otherwise only the first n elements
+    :return: (row, col) index arrays
+ """
+ if n == -1:
+ rows, cols = np.diag_indices_from(array)
+ else:
+ rows, cols = np.diag_indices(n)
+ if k < 0:
+ return rows[-k:], cols[:k]
+ elif k > 0:
+ return rows[:-k], cols[k:]
+ else:
+ return rows, cols
+
+
+def slash_mask(mask, start_index, end_index, set_value=0.0, mod='l', length=-1):
+ """
+    Build a mask with diagonal stripes
+ :param mask:
+ :param start_index:
+ :param end_index:
+ :param set_value:
+    :param mod: 'l' for left-leaning stripes, 'r' for right-leaning stripes
+    :param length: stripe length (-1 for full length)
+ :return:
+ """
+ h, w = mask.shape[:2]
+ assert length <= min(h, w)
+ if mod == 'r':
+ mask = np.fliplr(mask)
+ for i in range(start_index, end_index+1):
+ mask[kth_diag_indices(mask, i, length)] = set_value
+ if mod == 'r':
+ mask = np.fliplr(mask)
+ return mask
+
+
+def line_mask(mask, start_index, end_index, set_value=0.0, mod='h', length=-1):
+ """
+    Build a mask with straight stripes
+ :param mask:
+ :param start_index:
+ :param end_index:
+ :param set_value:
+    :param mod: 'h' for horizontal stripes, 'v' for vertical stripes
+    :param length: stripe length (-1 for full length)
+ :return:
+ """
+ h, w = mask.shape[:2]
+ if mod == 'h':
+ assert length <= w
+ assert 0 <= start_index < end_index < h
+ if length == -1 or length == w:
+ mask[start_index: end_index+1, :] = set_value
+ else:
+ left = random.randint(0, w-length-1)
+ right = left + length
+ mask[start_index: end_index + 1, left:right] = set_value
+ else:
+ assert length <= h
+ assert 0 <= start_index < end_index < w
+ if length == -1 or length == h:
+ mask[:, start_index: end_index + 1] = set_value
+ else:
+ top = random.randint(0, h - length - 1)
+ bottom = top + length
+ mask[top:bottom + 1, start_index: end_index] = set_value
+ return mask
+
+
+def read_binary_images(file_path):
+ img = np.fromfile(file_path, dtype=np.uint8)
+ img = np.reshape(img, (720, 1280, 3))
+ cv2.imshow('img', img)
+ cv2.waitKey(0)
+ print()
+
+
+
+if __name__ == '__main__':
+ read_binary_images('/Users/didi/Desktop/座次/pic/bgr1641010923242.bgr')
\ No newline at end of file
diff --git a/utils/labels.py b/utils/labels.py
new file mode 100644
index 0000000000000000000000000000000000000000..a84bb725cdea5120f25075ade96624a38328b6ac
--- /dev/null
+++ b/utils/labels.py
@@ -0,0 +1,873 @@
+import copy
+import numbers
+import os
+import torch
+import numpy as np
+import random
+import pandas as pd
+from utils.common import cprint, Color, log_warn, log_error
+import matplotlib.pyplot as plt
+
+pd.set_option('display.max_columns', None)
+pd.set_option('display.max_rows', None)
+pd.set_option('max_colwidth', 100)
+pd.set_option('display.width', 5000)
+pd.options.mode.chained_assignment = None # default='warn'
+plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
+
+
+def labels_to_class_weights(labels, nc=80):
+ # Get class weights (inverse frequency) from training labels
+    # Per-class weights from the training labels: the rarer a class, the larger its weight; classes that never appear get 1
+ if labels[0] is None: # no labels loaded
+ return torch.Tensor()
+
+ labels = np.array([l[1] for l in labels if l[1] != -1], dtype=np.int)
+ # labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
+ # classes = labels[:, 0].astype(np.int) # labels = [class xywh]
+ classes = labels
+    weights = np.bincount(classes, minlength=nc)  # occurrences per class
+
+ # Prepend gridpoint count (for uCE training)
+ # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
+ # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
+
+ weights[weights == 0] = 1 # replace empty bins with 1
+ weights = 1 / weights # number of targets per class
+ weights /= weights.sum() # normalize
+ return torch.from_numpy(weights)
+
+
+def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
+ # Produces image weights based on class_weights and image contents
+ class_counts = np.array([np.bincount(np.array([x[1]], dtype=np.int), minlength=nc) for x in labels])
+ image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
+ # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
+ return image_weights
+
+
+def labels_to_class_weights_mtl(labels, num_classes=[3, 3]):
+ # Get class weights (inverse frequency) from training labels
+    # Per-class weights from the training labels: the rarer a class, the larger its weight; classes that never appear get 1
+ if labels[0] is None: # no labels loaded
+ return torch.Tensor()
+
+ labels = np.array([l[1] for l in labels], dtype=np.int)
+ # labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
+ # classes = labels[:, 0].astype(np.int) # labels = [class xywh]
+ all_weights = []
+ for i, nc in enumerate(num_classes):
+ cur_labels = np.array([l[1][i] for l in labels if l[1][i] != -1], dtype=np.int)
+        weights = np.bincount(cur_labels, minlength=nc)  # occurrences per class
+ weights[weights == 0] = 1 # replace empty bins with 1
+ weights = 1 / weights # number of targets per class
+ weights /= weights.sum() # normalize
+ all_weights.append(weights)
+ return all_weights
+
+
+def labels_to_image_weights_mtl(labels, num_classes=[3, 3], class_weights=[]):
+ # Produces image weights based on class_weights and image contents
+ labels = np.array([l[1] for l in labels], dtype=np.int)
+ class_weights = [weights / weights.sum() for weights in class_weights]
+ all_image_weights = []
+ for i, nc in enumerate(num_classes):
+ class_counts = np.array(
+ [np.bincount(np.array([x], dtype=np.int), minlength=nc) for x in np.squeeze(labels[:, i])])
+ image_weights = (class_weights[i].reshape(1, nc) * class_counts).sum(1)
+ all_image_weights.append(np.squeeze(image_weights))
+ # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
+ # return np.array(all_image_weights).mean(0)
+ return all_image_weights[0]
+
+
+def sort_mtl_labels(inputs):
+ if inputs[0] is None: # no labels loaded
+ return []
+ random.shuffle(inputs)
+ ori_labels = np.array([l[1] for l in inputs], dtype=np.float)
+ labels = ori_labels.copy()
+ labels[labels >= 0] = 1
+ labels[labels < 0] = 0
+ counts = (np.sum(labels, axis=1) - 1) * 100
+ n, c = labels.shape[:2]
+ for i in range(c):
+ labels[:, i][labels[:, i] > 0] = c - i
+ res = (np.sum(labels, axis=1) + counts).squeeze()
+ idx = np.argsort(-res).tolist()
+ ori_labels = ori_labels[idx]
+ return idx
+
+
+def load_labels(label_file, classes=None, prints=True):
+ label_dict = {}
+ if not os.path.exists(label_file):
+ return label_dict
+ with open(label_file) as f:
+ for line in f:
+ info = line.strip().split(' ')
+ name = info[0]
+ labels = list(map(eval, info[1:]))
+ if len(labels) == 1:
+ label_dict[name] = labels[0]
+ else:
+ label_dict[name] = labels
+
+ if prints:
+ print(f"load {len(label_dict)} labels from '{label_file}'")
+
+ if classes is not None:
+ summary_labels(label_dict, classes)
+
+ return label_dict
+
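+# Label file format parsed above (one sample per line, space separated; names below are illustrative):
+#     <image_name> <label> [<label> ...]
+# e.g. "img_0001.jpg 1" or, for multi-task labels, "img_0002.jpg 0 2 -1" (-1 marks a missing label).
+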
+
+def load_labels_batch(file_lst, classes=None):
+ if isinstance(file_lst, str):
+ file_lst = [file_lst]
+ label_dict = {}
+ for i, file in enumerate(file_lst):
+ dct = load_labels(file, classes)
+ label_dict.update(dct)
+ print(f'total get {len(label_dict.keys())} items')
+
+ if classes is not None:
+ summary_labels(label_dict, classes)
+
+ return label_dict
+
+
+def save_labels(label_file, label_dict, mod='w'):
+ with open(label_file, mod) as f:
+ for img_name, labels in label_dict.items():
+ if isinstance(labels, (int, float, str)):
+ f.write(f"{img_name} {labels}\n")
+ elif isinstance(labels, (tuple, list)):
+ labels = list(map(str, labels))
+ f.write(' '.join([img_name] + labels) + '\n')
+ else:
+ raise NotImplementedError
+ print(f"Save {len(label_dict)} annotation in '{label_file}'")
+
+
+def merge_labels(file_list, save_file):
+ label_dict = load_labels_batch(file_list)
+ save_labels(save_file, label_dict)
+
+
+def concat_labels(file_list, save_file, dim=0):
+ """
+ :param file_list:
+ :param save_file:
+    :param dim: 0 to merge vertically (stack rows), 1 to merge horizontally (join label columns on common image names)
+ :return:
+ """
+ assert isinstance(file_list, list) and len(file_list) >= 2
+ if dim == 0:
+ merge_labels(file_list, save_file)
+ else:
+ first_dict = load_labels(file_list[0])
+ common_keys = set(first_dict.keys())
+ dicts_list = [first_dict]
+ for label_file in file_list[1:]:
+ tmp_dict = load_labels(label_file)
+ common_keys = common_keys.intersection(set(tmp_dict.keys()))
+ dicts_list.append(tmp_dict)
+ print(f"{len(common_keys)} common keys")
+
+ label_dict = {}
+ for name in common_keys:
+ tmp_label = []
+ for tmp_dict in dicts_list:
+ label = tmp_dict[name]
+ if isinstance(label, int):
+ tmp_label.append(label)
+ elif isinstance(label, (list, tuple)):
+ tmp_label += list(label)
+ else:
+ raise NotImplementedError
+ label_dict[name] = tmp_label
+
+ save_labels(save_file, label_dict)
+
+
+def summary_labels(label_dict, classes=None):
+ labels = list(label_dict.values())
+ if not labels:
+ return
+ if isinstance(labels[0], int):
+ labels = [l for l in labels if l != -1]
+ print(f"all: {len(labels)}", end=' ')
+ if classes:
+ assert max(labels) < len(classes)
+ for i, category in enumerate(classes):
+ end = '\n' if i == len(classes) - 1 else ' '
+ print(f"{category}: {labels.count(i)}", end=end)
+ else:
+ all_cls = sorted(list(set(labels)))
+ for i, category in enumerate(all_cls):
+                end = '\n' if i == len(all_cls) - 1 else ' '
+ print(f"{category}: {labels.count(category)}", end=end)
+ elif isinstance(labels[0], list):
+ label_array = np.array(labels, dtype=int)
+ label_count = list(np.sum(label_array, axis=0))
+ if classes:
+ assert len(classes) == len(labels[0])
+ for i, (category, count) in enumerate(zip(classes, label_count)):
+ end = '\n' if i == len(classes) - 1 else ' '
+ print(f"{category}: {count}", end=end)
+ else:
+ for i, count in enumerate(label_count):
+ end = '\n' if i == len(label_count) - 1 else ' '
+ print(f"{i}: {count}", end=end)
+
+
+def summary_label_file(label_file, classes=None):
+ if isinstance(label_file, str):
+ label_dict = load_labels(label_file)
+ elif isinstance(label_file, (list, tuple)):
+        label_dict = load_labels_batch(label_file) if len(label_file) >= 2 else load_labels(label_file[0])
+ else:
+ raise NotImplementedError
+ summary_labels(label_dict, classes)
+
+
+class MTLabel:
+ _invalid = -1
+
+ def __init__(self, input_label=None, tasks=None, classes=None):
+
+ label_data = self._load(input_label)
+ tasks = self._check_tasks(tasks)
+ classes = self._check_classes(classes)
+
+ if not label_data.empty:
+ tasks = [f"task{i}" for i in range(label_data.shape[1])] if tasks is None else tasks
+ assert len(tasks) == label_data.shape[1], \
+ f"Tasks length {len(tasks)} not match to label length {label_data.shape[1]}"
+ label_data.columns = tasks
+
+ if tasks is not None:
+ classes = [None] * len(tasks) if classes is None else classes
+ assert len(tasks) == len(classes), f"Tasks {tasks} not match to {classes}"
+
+ self.label_data = label_data
+ self.tasks = copy.deepcopy(tasks)
+ self.classes = copy.deepcopy(classes)
+
+ @staticmethod
+ def _load(input_label):
+ if input_label is None:
+ return pd.DataFrame()
+ if isinstance(input_label, str):
+ if not os.path.isfile(input_label):
+ raise FileNotFoundError(f'Check label file path: {input_label}')
+ label_data = pd.read_csv(input_label, sep=' ', header=None, index_col=0)
+ cprint(f"{label_data.shape if label_data.shape[1] > 1 else label_data.shape[0]} labels from '{input_label}'", prefix='Load')
+ elif isinstance(input_label, pd.core.frame.DataFrame):
+ label_data = copy.deepcopy(input_label)
+ elif isinstance(input_label, dict):
+ label_data = pd.DataFrame(input_label).transpose()
+ elif isinstance(input_label, list):
+ label_data = pd.DataFrame(dict(input_label)).transpose()
+ else:
+ raise TypeError(f'{type(input_label)} is not support')
+ label_data.index.rename("name", inplace=True)
+ return label_data
+
+ @staticmethod
+ def _save(label_data, save_file, filter_invalid=True):
+ save_dir = os.path.dirname(os.path.abspath(save_file))
+ os.makedirs(save_dir, exist_ok=True)
+ label_data.fillna(MTLabel._invalid, inplace=True)
+ if filter_invalid:
+ label_data = label_data[(label_data != MTLabel._invalid).any(axis=1)]
+ label_data = label_data.astype(object)
+ label_data.to_csv(save_file, sep=' ', index=True, header=False)
+ cprint(f"{label_data.shape if label_data.shape[1] > 1 else label_data.shape[0]} labels in '{save_file}'", prefix='Save')
+
+ @classmethod
+ def _new(cls, input_label=None, tasks=None, classes=None):
+ return cls(input_label, tasks, classes)
+
+ @staticmethod
+ def _check_tasks(tasks, int_ok=False):
+ if tasks is None or not tasks:
+ return
+ if isinstance(tasks, (str, numbers.Integral)):
+ tasks = [tasks]
+ elif isinstance(tasks, tuple):
+ tasks = list(tasks)
+ assert tasks and isinstance(tasks, list), f"arg 'task' should be type (int, str, tuple, list)"
+ for i, t in enumerate(tasks):
+ if isinstance(t, numbers.Integral):
+ if not int_ok:
+ raise TypeError(f"'int' type task {t} at {i} not support, set 'int_ok=True' ?")
+ elif not isinstance(t, str):
+ raise TypeError(f"{type(t)} task {t} at {i} not support")
+ return tasks
+
+ @staticmethod
+ def _check_classes(classes):
+ if classes is None:
+ return
+ assert classes and isinstance(classes, list), f"classes {classes} should be a list"
+ if isinstance(classes[0], str):
+ classes = [classes]
+ return classes
+
+ @staticmethod
+ def _check_names(names, int_ok=False):
+ if names is None or not names:
+ return
+ if isinstance(names, (str, numbers.Integral)):
+ names = [names]
+ elif isinstance(names, (tuple, set)):
+ names = list(names)
+ assert names and isinstance(names, list), f"arg 'names' should be type (int, str, tuple, set, list)"
+ for i, t in enumerate(names):
+ if isinstance(t, numbers.Integral):
+ if not int_ok:
+ raise TypeError(f"'int' type name {t} at {i} not support, set 'int_ok=True' ?")
+ elif not isinstance(t, str):
+ raise TypeError(f"{type(t)} name {t} at {i} not support")
+ return names
+
+ @staticmethod
+ def _type_data(values, types):
+ return [t(v) for t, v in zip(types, values)]
+
+ def _check_value(self, value, max_len):
+ if isinstance(value, list):
+ if len(value) < max_len:
+ value = value + [self._invalid] * (max_len - len(value))
+ elif len(value) > max_len:
+ value = value[:max_len]
+ return value
+
+ def _convert_tasks(self, tasks):
+ tasks = self._check_tasks(tasks, int_ok=True)
+ if tasks is None:
+ return
+ for i, t in enumerate(tasks):
+ if isinstance(t, numbers.Integral):
+ assert t < len(self.tasks), f"index {t} out of range({len(self.tasks)})"
+ tasks[i] = self.tasks[t]
+ elif isinstance(t, str):
+ assert t in self.tasks, f"'{t}' not in {self.tasks}"
+ return tasks
+
+ def _convert_names(self, names):
+ names = self._check_names(names, int_ok=True)
+ if names is None:
+ return
+ for i, t in enumerate(names):
+ if isinstance(t, numbers.Integral):
+ assert t < len(self), f"index {t} out of range({len(self)})"
+ names[i] = self.names[t]
+ # elif isinstance(t, str):
+ # assert t in set(self.names), f"'{t}' not in names"
+ return names
+
+ @staticmethod
+ def _map_task_classes(tasks, classes):
+ assert isinstance(tasks, list) and isinstance(classes, list) and len(tasks) == len(classes)
+ map_dict = {t: c for t, c in zip(tasks, classes)}
+ return map_dict
+
+ @property
+ def shape(self):
+ return self.label_data.shape
+
+ @property
+ def empty(self):
+ return self.label_data.empty
+
+ @property
+ def index(self):
+ return self.label_data.index
+
+ @property
+ def values(self):
+ return self.label_data.values
+
+ @property
+ def names(self):
+ return list(self.index)
+
+ @property
+ def columns(self):
+ return self.tasks
+
+ @property
+ def dtypes(self):
+ return self.label_data.dtypes
+
+ def head(self, n=5):
+ return self.label_data.astype(object).head(n)
+
+ def tail(self, n=5):
+ return self.label_data.astype(object).tail(n)
+
+ def astype(self, dtype):
+ self.label_data = self.label_data.astype(dtype)
+ return self
+
+ def summary(self, tasks=None, extra_info=''):
+ print("-" * 120)
+ tasks = self._convert_tasks(tasks)
+ if tasks is None and self.tasks is None:
+ log_warn(f"Label data is empty or 'tasks' is None")
+ return
+ tasks = self.tasks if tasks is None else tasks
+ tc_dict = self._map_task_classes(self.tasks, self.classes)
+ cprint(extra_info, prefix="Summary")
+ for t in tasks:
+ classes = tc_dict[t]
+ cur_label = self.label_data[t]
+ cur_label = cur_label[cur_label != self._invalid]
+ cprint(f"Task {t}\t==> all: {len(cur_label)}", end=' ')
+
+ if cur_label.empty:
+ print()
+ continue
+
+ if classes is None:
+ classes = ['mean', 'std', 'min', 'max', '1%', '50%', '99%']
+ res = cur_label.describe(percentiles=[0.01, 0.99], include=[np.number])
+ for i, category in enumerate(classes):
+ end = '\n' if i == len(classes) - 1 else ' '
+ cprint(f"{category}: {res[category]:.3f}", end=end)
+ else:
+ res = cur_label.value_counts()
+ for i, category in enumerate(classes):
+ end = '\n' if i == len(classes) - 1 else '\t'
+ if i in res:
+ cprint(f"{category}: {res[i]}", end=end)
+ else:
+ cprint(f"{category}: 0", end=end)
+ print("-"*120)
+
+ def plot(self, tasks=None):
+ tasks = self._convert_tasks(tasks)
+ if tasks is None:
+ tasks = self.tasks
+ tc_dict = self._map_task_classes(self.tasks, self.classes)
+ cls_part = []
+ reg_part = []
+ for t in tasks:
+ classes = tc_dict[t]
+ if classes is None:
+ reg_part.append(t)
+ else:
+ cls_part.append(t)
+ if cls_part:
+ self._plt_bar(cls_part, in_one=False)
+ if reg_part:
+ self._hist(reg_part)
+
+ def _hist(self, tasks=None, cols=3, in_one=False):
+ tasks = self._convert_tasks(tasks)
+ if tasks is None:
+ tasks = self.tasks
+ cols = cols if len(tasks) >= cols else len(tasks)
+ rows, mod = divmod(len(tasks), cols)
+ rows += mod != 0
+ for i, t in enumerate(tasks):
+ plt.subplot(rows, cols, i + 1) if in_one else plt.figure()
+ cur_label = self.label_data[t]
+ cur_label = cur_label[cur_label != self._invalid]
+ bin_size = (cur_label.max() - cur_label.min()) / 11
+ cur_label.hist(bins=np.arange(cur_label.min(), 1.01 * cur_label.max(), bin_size))
+ plt.title(t)
+ plt.tight_layout()
+ plt.grid(False)
+ plt.show()
+
+ def _plt_bar(self, tasks=None, cols=3, in_one=False):
+ tasks = self._convert_tasks(tasks)
+ if tasks is None:
+ tasks = self.tasks
+ tc_dict = self._map_task_classes(self.tasks, self.classes)
+ cols = cols if len(tasks) >= cols else len(tasks)
+ rows, mod = divmod(len(tasks), cols)
+ rows += mod != 0
+
+ if in_one:
+ fig, axes = plt.subplots(rows, cols)
+
+ for i, t in enumerate(tasks):
+ # plt.figure()
+ # plt.subplot(rows, cols, i + 1)
+ classes = tc_dict[t]
+ cur_label = self.label_data[t]
+ cur_label = cur_label[cur_label != self._invalid]
+ res = cur_label.value_counts()
+            values = [res.get(i, 0) for i in range(len(classes))]
+ df = pd.DataFrame({"category": classes, "count": values})
+ if in_one:
+ r, c = divmod(i, cols)
+ ax = axes[r, c] if rows > 1 else axes[i]
+ else:
+ ax = None
+ df.plot(kind='bar', x="category", y="count", title=t, grid=False, rot=30 if in_one else 0,
+ ax=ax, legend=False)
+ plt.tight_layout()
+ plt.show()
+
+ def set_tasks(self, tasks):
+ tasks = self._check_tasks(tasks)
+ if self.tasks is not None and len(tasks) != len(self.tasks):
+ log_error(f"new tasks length {len(tasks)} not equal to ori {len(self.tasks)}")
+ return
+ self.tasks = copy.deepcopy(tasks)
+ self.label_data.columns = self.tasks
+
+ def set_classes(self, classes):
+ classes = self._check_classes(classes)
+ if self.classes is not None and len(classes) != len(self.classes):
+ log_error(f"new tasks length {len(classes)} not equal to ori {len(self.classes)}")
+ return
+ self.classes = copy.deepcopy(classes)
+
+ def insert(self, task, value=None, loc=None, category=None, dtype=None):
+ if task in self.tasks:
+ raise KeyError(f"{task} already exits")
+ if value is None:
+ value = self._invalid
+ if isinstance(category, str):
+ category = [category]
+ assert category is None or isinstance(category, list)
+
+ self.label_data.insert(len(self.tasks) if loc is None else loc, task, value)
+ if dtype is not None:
+ self.label_data[task] = self.label_data[task].astype(dtype)
+
+ if loc is None:
+ self.tasks.append(task)
+ self.classes.append(category)
+ else:
+ self.tasks.insert(loc, task)
+ self.classes.insert(loc, category)
+
+ def remove(self, task):
+ self.__delitem__(task)
+
+ def add(self, key, value, keep_dtypes=False):
+ ori_dtypes = self.label_data.dtypes
+ code = 1 if key not in self.label_data.index else 0
+ self.label_data.loc[key] = value
+ if (ori_dtypes != self.label_data.dtypes).any() and keep_dtypes:
+ self.label_data = self.label_data.astype(ori_dtypes)
+ return code
+
+ def update(self, other_label, tasks=None, classes=None, inplace=False):
+ if not isinstance(other_label, type(self)):
+ other_label = self._new(other_label, tasks, classes)
+ if other_label.empty:
+ log_warn(f"Empty label data, check path or data")
+ return
+        assert (other_label.columns == self.label_data.columns).all(), \
+            f"current label({self.label_data.shape}) does not match input({other_label.shape}) on columns"
+
+ ori_dtypes = self.label_data.dtypes
+ other_label = other_label.label_data
+ other_label.columns = self.label_data.columns
+
+ common_index = self.label_data.index.intersection(other_label.index)
+ if common_index.empty:
+ label_data = pd.concat([self.label_data, other_label])
+ cprint(f"{len(other_label)} add.", prefix='Update')
+ else:
+ common1 = self.label_data.loc[common_index]
+ common2 = other_label.loc[common_index]
+ no_equal = common1[common1.ne(common2).any(axis=1)]
+ add_label = other_label[~other_label.index.isin(common_index)]
+ label_data = self.label_data[:]
+ label_data.update(common2)
+ label_data = pd.concat([label_data, add_label])
+ label_data = label_data.astype(ori_dtypes)
+ cprint(f"{len(common1)} common ({len(no_equal)} update), {len(add_label)} add.", prefix='Update')
+
+ if inplace:
+ self.label_data = label_data
+ else:
+ return self._new(label_data, self.tasks, self.classes)
+
+ def join(self, other_label, tasks=None, classes=None, inplace=False):
+ if not isinstance(other_label, type(self)):
+ other_label = self._new(other_label, tasks, classes)
+ if other_label.empty:
+ log_warn(f"Empty label data, check path or data")
+ return
+        assert other_label.shape[0] == self.label_data.shape[0], \
+            f"current label({self.label_data.shape}) does not match input({other_label.shape}) on index"
+
+ ori_tasks = self.tasks
+ ori_classes = self.classes
+ other_tasks = other_label.tasks
+ other_classes = other_label.classes
+ other_dict = self._map_task_classes(other_tasks, other_classes)
+ use_tasks = [t for t in other_tasks if t not in ori_tasks]
+ use_classes = [other_dict[t] for t in use_tasks]
+ new_tasks = ori_tasks + use_tasks
+ new_classes = ori_classes + use_classes
+
+ label_data = pd.concat([self.label_data, other_label.label_data], axis=1)
+ label_data = label_data.T
+ label_data = label_data[~label_data.index.duplicated(keep='last')].T
+ label_data = label_data[new_tasks]
+
+ if inplace:
+ self.label_data = label_data
+ self.tasks = new_tasks
+ self.classes = new_classes
+ else:
+ return self._new(label_data, new_tasks, new_classes)
+
+ def concat(self, other_label, tasks=None, classes=None, inplace=False, fill=True):
+ if not isinstance(other_label, type(self)):
+ other_label = self._new(other_label, tasks, classes)
+ if other_label.empty:
+ log_warn(f"Empty label data, check path or data")
+ return
+
+ ori_tasks = self.tasks
+ ori_classes = self.classes
+ other_tasks = other_label.tasks
+ other_classes = other_label.classes
+ other_dict = self._map_task_classes(other_tasks, other_classes)
+ use_tasks = [t for t in other_tasks if t not in ori_tasks]
+ use_classes = [other_dict[t] for t in use_tasks]
+ new_tasks = ori_tasks + use_tasks
+ new_classes = ori_classes + use_classes
+
+ label_data = self.label_data[:]
+ other_label = other_label.label_data
+ common_index = label_data.index.intersection(other_label.index)
+ common_columns = label_data.columns.intersection(other_label.columns)
+ other_index = other_label.index[~other_label.index.isin(common_index)]
+ other_columns = other_label.columns[~other_label.columns.isin(common_columns)]
+
+ if not common_index.empty and not common_columns.empty:
+            label_data.update(other_label.loc[common_index, common_columns])  # overlapping rows and columns
+        if not common_columns.empty:
+            label_data = pd.concat([label_data, other_label.loc[other_index, common_columns]])  # new rows, shared columns
+        if not common_index.empty:
+            label_data = pd.concat([label_data, other_label.loc[common_index, other_columns]], axis=1)  # new columns, shared rows
+        # no overlap at all
+ if common_index.empty and common_columns.empty:
+ label_data = pd.concat([label_data, other_label])
+ else:
+ label_data.update(other_label)
+
+ if fill:
+ label_data.fillna(MTLabel._invalid, inplace=True)
+
+ cprint(f"update common ({len(common_index)}, {len(common_columns)}), "
+ f"add ({len(other_index)}, {len(other_columns)})", prefix="Concat")
+
+ if inplace:
+ self.label_data = label_data
+ self.tasks = new_tasks
+ self.classes = new_classes
+ else:
+ return self._new(label_data, new_tasks, new_classes)
+
+ def sample(self, num):
+ obj = self._new(self.label_data.sample(num), self.tasks, self.classes)
+ return obj
+
+ def pick_tasks(self, tasks, inplace=False):
+ tasks = self._convert_tasks(tasks)
+ if tasks is None:
+ log_warn("'task' is None, use ori label")
+ return self
+ classes = [self.classes[self.tasks.index(t)] for t in tasks]
+ if inplace:
+ self.label_data = self.label_data[tasks]
+ self.tasks = copy.deepcopy(tasks)
+ self.classes = copy.deepcopy(classes)
+ return self
+ else:
+ obj = self._new(self.label_data[tasks], tasks=tasks, classes=classes)
+ return obj
+
+ def pick_names(self, name_list, inplace=False):
+ names_list = self._check_names(name_list)
+ if names_list is None:
+ log_warn("'name_list' is None, use ori label")
+ return self
+ if inplace:
+ self.label_data = self.label_data.loc[names_list]
+ return self
+ else:
+ obj = self._new(self.label_data.loc[names_list], tasks=self.tasks, classes=self.classes)
+ return obj
+
+ def tolist(self):
+ names = self.index.tolist()
+ values = self.values.tolist()
+ dtypes = [t.type for t in self.dtypes.tolist()]
+ values = [self._type_data(v, dtypes) for v in values]
+ return list(zip(names, values))
+
+ def todict(self):
+ names = self.index.tolist()
+ values = self.values.tolist()
+ dtypes = [t.type for t in self.dtypes.tolist()]
+ values = [self._type_data(v, dtypes) for v in values]
+ return dict(zip(names, values))
+
+ def export(self, save_file, filter_invalid=True):
+ self._save(self.label_data, save_file, filter_invalid)
+
+ def export_tasks(self, task, save_file, filter_invalid=True):
+ obj = self.pick_tasks(task)
+ self._save(obj.label_data, save_file, filter_invalid)
+
+ def __getitem__(self, item):
+        # plain integer index or slice
+ item = slice(item, item + 1) if isinstance(item, numbers.Integral) else item
+ if isinstance(item, slice):
+ obj = self._new(self.label_data[item], tasks=self.tasks, classes=self.classes)
+ return obj
+
+        # task (column) name or image (index) name
+ elif isinstance(item, str):
+ if item in self.tasks:
+ obj = self._new(self.label_data[[item]], tasks=[item], classes=[self.classes[self.tasks.index(item)]])
+ elif item in self.label_data.index:
+ obj = self._new(self.label_data.loc[[item]], tasks=self.tasks, classes=self.classes)
+ else:
+ raise KeyError(f"'{item}' not in tasks or index")
+ return obj
+
+        # currently only a list of tasks is supported
+ elif isinstance(item, list):
+ return self.pick_tasks(item)
+
+        # 2-D slicing
+ elif isinstance(item, tuple):
+ assert len(item) == 2, f"tuple length {len(item)} != 2"
+ row_slice, col_slice = item
+ if isinstance(row_slice, (slice, numbers.Integral)) or \
+ (isinstance(row_slice, str) and row_slice not in self.tasks):
+ obj = self.__getitem__(row_slice)
+ if isinstance(col_slice, (slice, numbers.Integral)):
+ obj = obj.__getitem__(obj.tasks[col_slice])
+ return obj
+ else:
+ return obj.__getitem__(col_slice)
+ else:
+ raise TypeError(f"{item} first item type {type(row_slice)} is an invalid key")
+ else:
+ raise TypeError(f"{item} is an invalid key")
+
+ def __setitem__(self, key, value):
+ """
+        Only keys already present in label_data are accepted.
+ loc: only work on index, can assign a new index or column value
+ iloc: work on position, can not assign a new index or column value
+ at: get scalar values. It's a very fast loc
+ iat: Get scalar values. It's a very fast iloc
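+
+        Accepted key forms (illustrative sketch; 'img_001.jpg' is a hypothetical index name):
+            label['ems'] = 1             # whole task column
+            label['img_001.jpg'] = row   # whole row by image name
+            label[0, 'ems'] = 1          # single cell (positional row, task column)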
+ """
+ if isinstance(key, (numbers.Integral, slice)):
+ self.label_data.iloc[key] = value
+
+        # task (column) name or index name
+ elif isinstance(key, str):
+ if key in self.tasks:
+ self.label_data.loc[:, key] = value
+ elif key in self.label_data.index:
+ self.label_data.loc[key] = value
+ else:
+ raise KeyError(f"'{key}' not in tasks or index")
+
+        # list of tasks
+ elif isinstance(key, list):
+ key = self._convert_tasks(key)
+ value = self._check_value(value, len(key))
+ self.label_data.loc[:, key] = value
+
+        # 2-D slicing
+ elif isinstance(key, tuple):
+ assert len(key) == 2, f"tuple length {len(key)} != 2"
+ row_slice, col_slice = key
+
+ if isinstance(row_slice, numbers.Integral):
+ row_slice = self.label_data.index[row_slice]
+ elif isinstance(row_slice, str):
+ assert row_slice in self.label_data.index, f"{row_slice} not in index"
+ elif not isinstance(row_slice, slice):
+ raise TypeError(f"{key} first item type {type(row_slice)} is an invalid key")
+
+ if isinstance(col_slice, numbers.Integral):
+ col_slice = self.tasks[col_slice]
+ elif isinstance(col_slice, str):
+ assert col_slice in self.tasks, f"{col_slice} not in {self.tasks}"
+ elif isinstance(col_slice, list):
+ col_slice = self._convert_tasks(col_slice)
+ elif not isinstance(col_slice, slice):
+ raise TypeError(f"{key} second item type {type(col_slice)} is an invalid key")
+
+ if isinstance(row_slice, str) and isinstance(col_slice, str):
+ self.label_data.at[row_slice, col_slice] = value
+ else:
+ self.label_data.loc[row_slice, col_slice] = value
+
+ def __delitem__(self, key):
+ if not isinstance(key, (str, list)):
+ raise KeyError(f"'{key}' is an invalid key")
+ key = self._convert_tasks(key)
+ self.label_data.drop(columns=key, axis=1, inplace=True)
+ tc_dict = self._map_task_classes(self.tasks, self.classes)
+ self.tasks = [t for t in self.tasks if t not in set(key)]
+ self.classes = [tc_dict[t] for t in self.tasks]
+
+    def __getattr__(self, item):
+        if item in self.tasks:
+            return self.__getitem__(item)
+        raise AttributeError(item)
+
+ # def __setattr__(self, key, value):
+ # if key in self.tasks:
+ # self.label_data[key] = value
+
+ def __len__(self):
+ return len(self.label_data)
+
+ def __repr__(self):
+ return self.label_data.astype(object).__repr__()
+
+ def __copy__(self):
+ return MTLabel(self.label_data, self.tasks, self.classes)
+
+ def __deepcopy__(self, memodict={}):
+ return MTLabel(copy.deepcopy(self.label_data, memodict),
+ copy.deepcopy(self.tasks, memodict),
+ copy.deepcopy(self.classes, memodict))
+
+
+if __name__ == '__main__':
+ dms_tasks = ["ems", "eye", 'mouth', 'glass', 'mask', 'smoke', 'phone', "eyelid_r", "eyelid_l"]
+ dms_classes = [['normal', 'look_left', 'look_down', 'look_right', 'invalid'],
+ ['normal', 'close_eye', 'invalid'],
+ ['normal', 'yawn', 'invalid'],
+ ['normal', 'glass', 'invalid'],
+ ['normal', 'mask', 'invalid'],
+ ['normal', 'smoke'],
+ ['normal', 'phone'], None, None]
+ test_label = MTLabel('../test_dms3_labels_v4_rec.txt', dms_tasks, dms_classes)
+ test_label.summary()
+ # test_label.plot(["eyelid_r", "eyelid_l"])
+ # print(test_label.head())
+ # a = test_label[:3, :5]
+ # b = test_label[2:5, 4:]
+ # b[0] = 8
+ # print(a)
+ # print(b)
+ #
+ # c = a.concat(b, fill=True)
+ # print(c)
+ test_label['ems'].summary()
+ test_label['ems'].export("ems_rec.txt")
diff --git a/utils/multiprogress.py b/utils/multiprogress.py
new file mode 100644
index 0000000000000000000000000000000000000000..798455e761cf3d498f42e346f4c27db1922b7428
--- /dev/null
+++ b/utils/multiprogress.py
@@ -0,0 +1,61 @@
+# -*- coding:utf-8 -*-
+import os
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
+from multiprocessing import Pool
+
+from tqdm import tqdm
+
+
+class MultiProgress(object):
+ def __init__(self, info_lst, workers=6):
+ self.info_lst = info_lst
+ self.num_items = len(info_lst)
+ self.workers = workers
+ self.pool = Pool(workers)
+
+ def run(self, func):
+ with tqdm(total=self.num_items) as bar:
+ for _ in self.pool.imap(func, self.info_lst):
+ bar.update(1)
+
+ def __repr__(self):
+ return f'MultiProgress(info_lst={type(self.info_lst)}, workers={self.workers})'
+
+
+class MultiThreading(object):
+ def __init__(self, info_lst, workers=6, pbar=True, deception='Running...'):
+ self.info_lst = list(info_lst)
+        self.num_items = len(self.info_lst)
+ self.process_bar = pbar
+ self.deception = deception
+ self.workers = workers
+ self.exe = ThreadPoolExecutor(workers)
+
+ def run(self, func):
+ if self.process_bar:
+ res = list(tqdm(self.exe.map(func, self.info_lst), total=self.num_items, desc=self.deception))
+ else:
+ res = list(self.exe.map(func, self.info_lst))
+ return list(res)
+
+ def __repr__(self):
+ return f'MultiThreading(info_lst={type(self.info_lst)}, workers={self.workers})'
+
+
+class MultiProcess(object):
+ def __init__(self, info_lst, pbar=True):
+ self.info_lst = info_lst
+ self.num_items = len(info_lst)
+ self.workers = os.cpu_count()
+ self.process_bar = pbar
+
+ def run(self, func):
+        with ProcessPoolExecutor(max_workers=self.workers) as exe:
+ if self.process_bar:
+ res = list(tqdm(exe.map(func, self.info_lst), total=self.num_items))
+ else:
+ res = list(exe.map(func, self.info_lst))
+ return res
+
+ def __repr__(self):
+ return f'MultiProcess(info_lst={self.num_items}, workers={self.workers})'
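+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only): square 0..99 with 8 worker threads.
+    # MultiThreading.run preserves input order, so results line up with the inputs.
+    runner = MultiThreading(range(100), workers=8)
+    print(sum(runner.run(lambda x: x * x)))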
diff --git a/utils/os_util.py b/utils/os_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..69083c7fe902107743e5a8fd2d70598e6f64a4ef
--- /dev/null
+++ b/utils/os_util.py
@@ -0,0 +1,429 @@
+# -*- coding:utf-8 -*-
+import os
+import re
+import random
+import urllib.request as request
+from glob import glob
+
+import cv2
+import numpy as np
+import requests
+
+from utils.common import clock_custom, Color, colorprint, log_error
+from utils.multiprogress import MultiThreading
+from utils.labels import load_labels
+
+
+def is_day(img_path):
+ pattern = re.compile('.*20[1-2][0-9][0-1][0-9][0-3][0-9]([0-2][0-9][0-5][0-9]).*')
+ res = pattern.match(img_path)
+ if res is None:
+ return -1
+ cur_time = int(res.group(1))
+ if 630 <= cur_time <= 1830:
+ return 1
+ else:
+ return 0
+
+
+# Collect the full paths of all files of the given type under a directory
+def get_file_paths(folder, upper=False, sort=True, abs_path=True, mod='img'):
+ if mod == 'img':
+ extensions = ['jpg', 'jpeg', 'png', 'bmp']
+ elif mod == 'vdo':
+ extensions = ['mp4', 'avi', 'mov']
+ else:
+ extensions = [mod]
+ img_files = []
+ for ext in extensions:
+ ext = ext.upper() if upper else ext
+        files = glob('%s/*.%s' % (folder, ext))  # glob every file with this extension
+ if not abs_path:
+ files = [os.path.basename(path) for path in files]
+ img_files += files
+ return sorted(img_files) if sort else img_files
+
+
+def get_files_paths_batch(data_dir, upper=False, sort=True, abs_path=True, mod='img'):
+ subs = [os.path.join(data_dir, sub) for sub in os.listdir(data_dir)
+ if not sub.startswith('.') and os.path.isdir(os.path.join(data_dir, sub))]
+ print(subs)
+ all_paths = []
+ for sub in subs:
+ all_paths += get_file_paths(sub, upper, sort, abs_path, mod)
+ return all_paths
+
+
+# Read an image whose path may contain non-ASCII (e.g. Chinese) characters
+def read_img(img_path):
+ img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1)
+ return img
+
+
+# Write an image to a path that may contain non-ASCII characters
+def save_img(save_path, img):
+    cv2.imencode('.jpg', img)[1].tofile(save_path)
+
+
+def read_online_image(img_url):
+ try:
+ response = request.urlopen(img_url)
+ img_array = np.array(bytearray(response.read()), dtype=np.uint8)
+ img = cv2.imdecode(img_array, -1)
+ return img
+ except Exception:
+ print('{} read failed!'.format(img_url))
+ return
+
+
+# @clock_custom('[{elapsed:0.8f}s] {name}()')
+# Faster than download_image: fetch the raw bytes with requests
+def download_file(file_url, save_path):
+ if os.path.exists(save_path):
+ colorprint(Color.YELLOW, 'Image %s already exists. Skipping download.' % save_path)
+ return 0
+ try:
+ resp = requests.get(file_url)
+ file = resp.content
+ with open(save_path, 'wb') as fp:
+ fp.write(file)
+ return 1
+    except Exception:
+ colorprint(Color.RED, 'Warning: can not download from {}'.format(file_url))
+ return 0
+
+
+def download_files(url_list, save_dir, workers=8):
+ os.makedirs(save_dir, exist_ok=True)
+
+ def kernel(url):
+ name = os.path.basename(url)
+ save_path = os.path.join(save_dir, name)
+ res = download_file(url, save_path)
+ return res
+
+ exe = MultiThreading(url_list, workers)
+ result = exe.run(kernel)
+ print(f"Download {result.count(1)} files in '{save_dir}', {result.count(0)} failed.")
+
+
+# @clock_custom('[{elapsed:0.8f}s] {name}()')
+def download_image(img_url, save_path):
+ if os.path.exists(save_path):
+ print('Image %s already exists. Skipping download.' % save_path)
+ return
+ img = read_online_image(img_url)
+ if img is None:
+ return
+ save_img(save_path, img)
+
+
+def download_video(vdo_url, save_path=None):
+ if save_path is None:
+ save_path = vdo_url.split('/')[-1]
+ # print("开始下载:%s" % os.path.basename(save_path))
+ if os.path.exists(save_path):
+ print('Video %s already exists. Skipping download.' % save_path)
+ return
+ r = requests.get(vdo_url, stream=True).content
+ with open(save_path, 'wb') as f:
+ f.write(r)
+ f.flush()
+ print("%s 下载完成!\n" % os.path.basename(save_path))
+ return
+
+
+def check_images(img_dir, tmp_dir=None, batch=False):
+ if tmp_dir is None:
+ tmp_dir = os.path.join(img_dir, 'tmp')
+ if not os.path.exists(tmp_dir):
+ os.makedirs(tmp_dir)
+
+ if batch:
+ img_paths = get_files_paths_batch(img_dir, sort=False)
+ else:
+ img_paths = get_file_paths(img_dir, sort=False)
+
+ def check(img_path):
+ try:
+ img = cv2.imread(img_path)
+ size = img.shape
+ return 0
+ except Exception:
+ mv_cmd = f'mv {img_path} {tmp_dir}'
+ print(os.path.basename(img_path))
+ os.system(mv_cmd)
+ return 1
+
+ exe = MultiThreading(img_paths, 6)
+ res = exe.run(check)
+ print(f"total {len(img_paths)} images, {sum(res)} wrong.")
+
+
+def divide_by_shape(img_dir, batch=False, b100=(1280, 720), b200=(1280, 960)):
+ tmp_dir = os.path.join(img_dir, 'tmp')
+ b100_dir = os.path.join(img_dir, 'b100')
+ b200_dir = os.path.join(img_dir, 'b200')
+ for sub in [tmp_dir, b200_dir, b100_dir]:
+ if not os.path.exists(sub):
+ os.makedirs(sub)
+
+ if batch:
+ img_paths = get_files_paths_batch(img_dir, sort=False)
+ else:
+ img_paths = get_file_paths(img_dir, sort=False)
+
+ def divide(img_path):
+ try:
+ img = cv2.imread(img_path)
+ h, w = img.shape[:2]
+ if (w, h) == b100:
+ mv_cmd = f'mv {img_path} {b100_dir}'
+ print(mv_cmd)
+ os.system(mv_cmd)
+ return 1
+ elif (w, h) == b200:
+ mv_cmd = f'mv {img_path} {b200_dir}'
+ print(mv_cmd)
+ os.system(mv_cmd)
+ return 2
+ else:
+ return 3
+ except Exception:
+ mv_cmd = f'mv {img_path} {tmp_dir}'
+ print(mv_cmd)
+ os.system(mv_cmd)
+ return 0
+
+ exe = MultiThreading(img_paths, 6)
+ res = list(exe.run(divide))
+ print(f"total {len(img_paths)} images, {res.count(1)} b100 {res.count(2)} b200 "
+ f"{res.count(3)} other {res.count(0)} wrong.")
+
+
+def copy_files(ori_dir, dst_dir, file_type='img'):
+ if not os.path.exists(dst_dir):
+ os.makedirs(dst_dir)
+
+ if isinstance(ori_dir, str) and os.path.isdir(ori_dir):
+ print("load images, please wait ...")
+ img_paths = get_file_paths(ori_dir, abs_path=True, mod=file_type)
+ elif isinstance(ori_dir, list):
+ img_paths = ori_dir
+ else:
+ raise NotImplementedError(f"check input, '{ori_dir}' should be a dir or list of paths")
+
+ print(f"total {len(img_paths)} images")
+
+ def copy(img_path):
+ new_path = os.path.join(dst_dir, os.path.basename(img_path))
+ if os.path.exists(new_path):
+ return 0
+ cp_cmd = f"cp {img_path} {new_path}"
+ os.system(cp_cmd)
+ return 1
+
+ exe = MultiThreading(img_paths, 16)
+ res = exe.run(copy)
+ print(f"total {len(img_paths)} images, copy {res.count(1)} files, skip {res.count(0)} files")
+
+
+def copy_minute_images(data_dir, save_dir, width=1280):
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ img_paths = get_file_paths(data_dir)
+
+ def copy(info):
+ i, img_path = info
+ img = cv2.imread(img_path)
+ w = img.shape[1]
+ if w != width:
+ return 0
+ cp_cmd = f"cp {img_path} {save_dir}"
+ print(i, cp_cmd)
+ # os.system(cp_cmd)
+ return 1
+
+ exe = MultiThreading(list(enumerate(img_paths)), 16)
+ res = exe.run(copy)
+ print(f"total {len(img_paths)} images, {res.count(1)} minute images")
+
+
+def day_or_night(img_path, day=(630, 1830)):
+ pat_str = r'.+/202[0-1][0-1][0-9][0-3][0-9]([0-2][0-9][0-5][0-9])[0-9]+_.+'
+ pattern = re.compile(pat_str)
+ res = pattern.match(img_path)
+ if res is None:
+ return -1
+ cur_time = int(res.group(1))
+ if day[0] <= cur_time <= day[1]:
+ return 1
+ else:
+ return 0
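+
+
+# day_or_night examples (sketch; the paths are hypothetical but follow the assumed naming scheme):
+#   day_or_night('/data/20210705123045_device_.jpg') -> 1   (12:30 falls inside 06:30-18:30)
+#   day_or_night('/data/20210705230045_device_.jpg') -> 0   (23:00 is night)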
+
+
+def divide_by_time(img_dir, batch=False, day=(630, 1830)):
+ day_dir = os.path.join(img_dir, 'day')
+ night_dir = os.path.join(img_dir, 'night')
+ for sub in [day_dir, night_dir]:
+ if not os.path.exists(sub):
+ os.makedirs(sub)
+
+ if batch:
+ img_paths = get_files_paths_batch(img_dir, sort=False)
+ else:
+ img_paths = get_file_paths(img_dir, sort=False)
+
+ def divide(img_path):
+ r = day_or_night(img_path, day)
+ if r:
+ mv_cmd = f'mv {img_path} {day_dir}'
+ print(mv_cmd)
+ os.system(mv_cmd)
+ else:
+ mv_cmd = f'mv {img_path} {night_dir}'
+ print(mv_cmd)
+ os.system(mv_cmd)
+ return r
+
+ exe = MultiThreading(img_paths, 6)
+ res = list(exe.run(divide))
+ print(f"total {len(img_paths)} images, {res.count(1)} day {res.count(0)} night.")
+
+
+def sample_images(img_dir, sample, mod='mv', save_dir=None):
+ img_paths = get_file_paths(img_dir, sort=False)
+ random.shuffle(img_paths)
+ sampled = random.sample(img_paths, sample)
+ sampled = [(idx, img_path) for idx, img_path in enumerate(sampled)]
+
+ if not save_dir:
+ save_dir = img_dir + '_sample'
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ def mv_img(info):
+ idx, img_path = info
+ img_name = os.path.basename(img_path)
+ new_path = os.path.join(save_dir, img_name)
+ cmd = f'{mod} {img_path} {new_path}'
+ print(idx, cmd)
+ os.system(cmd)
+ return 1
+
+ exe = MultiThreading(sampled, 12)
+ res = exe.run(mv_img)
+ print(sum(list(res)))
+ return sampled
+
+
+def pick_files(label_file, data_dir, save_dir=None, label_cond=None, os_cmd='cp'):
+ label_dict = load_labels(label_file)
+ if save_dir is None:
+ save_dir = data_dir.rstrip('/') + '_pick'
+ os.makedirs(save_dir, exist_ok=True)
+
+ if label_cond is None:
+ picked = list(label_dict.keys())
+ else:
+ picked = [n for n, l in label_dict.items() if l == label_cond]
+ # picked = [n for n, l in label_dict.items() if l[5] == 1]
+ # picked = [n for n, l in label_dict.items()]
+
+ assert os_cmd in ['cp', 'mv']
+
+ def _pick(img_name):
+ img_path = os.path.join(data_dir, img_name)
+ if not os.path.exists(img_path):
+ log_error(f"{img_path} not exist.")
+ return 0
+ new_path = os.path.join(save_dir, img_name)
+ cmd = f'{os_cmd} {img_path} {new_path}'
+ os.system(cmd)
+ return 1
+
+ exe = MultiThreading(picked, workers=10)
+ res = exe.run(_pick)
+ print(f"total {len(picked)} items, {os_cmd} {res.count(1)} items, {res.count(0)} not exist.")
+
+
+def load_images(input_data, workers=12):
+ if input_data and isinstance(input_data, list) and os.path.isfile(input_data[0]):
+ img_paths = input_data
+ elif os.path.isdir(input_data):
+ img_paths = get_file_paths(input_data, sort=False)
+ else:
+ raise NotImplementedError
+
+ def load(img_path):
+ try:
+ img = read_img(img_path)
+ return img_path, img
+ except Exception:
+ return img_path, None
+
+ exe = MultiThreading(img_paths, max(workers, 8))
+ res = exe.run(load)
+ out_paths = [r[0] for r in res]
+ assert img_paths == out_paths
+
+ cache_images = [r[1] for r in res if r[1] is not None]
+    assert len(cache_images) == len(img_paths), \
+        f"Incomplete load! Input paths length {len(img_paths)} != loaded length {len(cache_images)}"
+
+ return tuple(cache_images)
+
+
+def export_binary_files(inp_list, save_dir):
+ os.makedirs(save_dir, exist_ok=True)
+ save_file = save_dir.rstrip('/') + '.txt'
+
+ def export(info):
+ img_name, raw_img = info
+ new_name = os.path.splitext(img_name)[0] + '.raw'
+ save_path = os.path.join(save_dir, new_name)
+ relative_path = os.path.join(os.path.basename(save_dir), new_name)
+ raw_img.tofile(save_path)
+ return relative_path
+
+ exe = MultiThreading(inp_list)
+ res = exe.run(export)
+
+ with open(save_file, 'w') as f:
+ f.write('\n'.join(res))
+
+ print(f"Save {len(res)} binary files in '{save_file}'")
+
+
+def sequence_to_file(lst, save_file):
+ if not lst:
+ return
+ assert isinstance(lst, (list, tuple, set))
+
+ lst = [str(l) for l in lst]
+ with open(save_file, "w") as f:
+ f.write('\n'.join(lst))
+ print(f"save {len(lst)} items in '{save_file}'")
+
+
+if __name__ == '__main__':
+ img_dir = '/nfs/volume-236-2/qilongyu/person_seats/classify/images'
+ # copy_minute_images(img_dir, save_dir='/nfs/volume-236-2/iov_ann/minute_images')
+ # sample_occ_images(img_dir, sample=1, mod='cp', save_dir='/Users/didi/Desktop/视频分类/dataset/ps_occ')
+ # sample_images(img_dir, 1000, save_dir='/Users/didi/Desktop/人数座次/personData/fhm/fhm_1', mod='mv')
+ # data_dir = '/Users/didi/Desktop/视频分类/dataset/inner'
+ # save_dir = '/Users/didi/Desktop/视频分类/dataset/min_1208'
+ # for sub in os.listdir(data_dir):
+ # if sub.startswith('20201208'):
+ # cmd = f"mv {data_dir}/{sub} {save_dir}/{sub}"
+ # print(cmd)
+ # os.system(cmd)
+ # data_dir = '/Users/didi/Desktop/视频分类/dataset/anomaly_video2/ppp'
+ # print(os.listdir(data_dir))
+ # get_files_paths_batch(data_dir)
+ # download_video('http://100.69.239.80:8002/static/tac_permanent_ns_inner/20220629163931_991745085639127040_338345_173696_.mp4'
+ # )
+ # load_images('/Users/didi/Desktop/MTL/dataset/images')
+
+ check_images("/mnt/qly/dms3/images")
diff --git a/utils/plt_util.py b/utils/plt_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fbc7a9e93e2540d9f09fa088c022387e1c60fb1
--- /dev/null
+++ b/utils/plt_util.py
@@ -0,0 +1,669 @@
+# -*- coding:utf-8 -*-
+import collections
+import io
+import os
+
+import cv2
+import matplotlib.font_manager as fm
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from PIL import Image, ImageFont, ImageDraw
+from matplotlib.pyplot import MultipleLocator
+
+from utils.images import put_chinese_text, paste
+from utils.common import Color, cprint
+
+ROOT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+# FONT = fm.FontProperties(fname=os.path.join(ROOT_DIR, "files/simhei.ttf"), size=14)
+FRAME = cv2.imread(f'{ROOT_DIR}/files/frame_b.png', -1)
+INFO = cv2.imread(f'{ROOT_DIR}/files/info_box.png', -1)
+# plt.rcParams['font.sans-serif'] = ['SimHei']  # Chinese-capable font for labels
+plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
+# plt.rc('axes', axisbelow=True)
+
+
+# Plot a bar chart
+def plot_bar(values, xlabel=None, ylabel=None, title='res', color="blue", width=0.3, xticks=None, save_path=None):
+    """
+    Plot a bar chart from pre-computed values.
+    :param values: value of each bar
+    :param xlabel: x-axis label
+    :param ylabel: y-axis label
+    :param title: figure title
+    :param color: bar color (default "blue")
+    :param width: bar width
+    :param xticks: x-axis tick labels
+    :param save_path: optional path to save the figure
+    :return: None
+    """
+    # create a figure scaled to the number of bars, at 80 dpi
+    plt.figure(figsize=(2*len(values), 6), dpi=80)
+    # optionally use a single 1 x 1 subplot
+    # plt.subplot(1, 1, 1)
+    # total number of bars
+    n = len(values)
+    # index of each bar
+    index = np.arange(n)
+    plt.bar(index, values, width, label="num", color=color)
+
+    # x-axis label
+    if xlabel:
+        plt.xlabel(xlabel, fontsize=14)
+    # y-axis label
+    if ylabel:
+        plt.ylabel(ylabel, fontsize=14)
+    # title
+    plt.title(title, fontsize=16)
+    # x-axis tick labels
+    if xticks:
+        plt.xticks(index, xticks)
+    plt.tick_params(axis='both', which='major', labelsize=16)
+    # legend, optional save, then show
+    plt.legend(loc="upper right")
+    if save_path:
+        plt.savefig(save_path)
+    plt.show()
+
+
+# Plot a pie chart
+def plot_pie(values, labels, title='distribute', explode=None, colors=None):
+    # plotting theme
+    plt.style.use("ggplot")
+    # build the data
+    if explode is None:
+        explode = [0, 0.05, 0, 0, 0][:len(values)]  # slightly offset the second wedge for emphasis
+
+    if colors is None:
+        colors = ['#9999ff', '#ff9999', '#7777aa', '#2442aa', '#dd5555'][:len(values)]
+
+    # handle CJK fonts and minus signs
+    plt.rcParams['font.sans-serif'] = ['KaiTi']
+    plt.rcParams['axes.unicode_minus'] = False
+    # equal aspect ratio so the pie is a circle rather than an ellipse
+    plt.axes(aspect="equal")
+
+    # x / y axis ranges
+    plt.xlim(0, 10)
+    plt.ylim(0, 10)
+
+    # draw the pie
+    plt.pie(x=values,  # data
+            explode=explode,  # wedge offsets
+            labels=labels,  # wedge labels
+            colors=colors,  # custom fill colors
+            autopct='%.1f%%',  # percentage format, one decimal place
+            pctdistance=0.5,  # distance of percentage labels from the center
+            labeldistance=1.1,  # label distance, 1.1 x radius
+            startangle=180,  # starting angle
+            radius=3.3,  # pie radius
+            counterclock=False,  # draw clockwise
+            wedgeprops={'linewidth': 0.1, 'edgecolor': 'green'},  # wedge edge properties
+            textprops={'fontsize': 10, 'color': 'k'},  # text label properties
+            center=(4, 4),  # pie center
+            shadow=False,  # no drop shadow
+            frame=1)  # show the axes frame
+    # remove x / y ticks
+    plt.xticks(())
+    plt.yticks(())
+    plt.legend()
+    # title
+    plt.title(title)
+    plt.show()  # display the figure
+
+
+def plot_hist(data, columns, cols=3):
+ data = np.array(data)
+ assert isinstance(columns, (tuple, list)) and data.shape[1] == len(columns)
+
+ df = pd.DataFrame(data, columns=tuple(columns))
+
+    cprint('Mean:', Color.BLUE)
+    print(df.mean())
+    cprint('Std:', Color.BLUE)
+    print(df.std())
+    cprint('Min:', Color.BLUE)
+    print(df.min())
+    cprint('Max:', Color.BLUE)
+    print(df.max())
+    cprint('99.5% quantile:', Color.BLUE)
+    print(df.quantile(0.995))
+    cprint('0.5% quantile:', Color.BLUE)
+    print(df.quantile(0.005))
+
+ cols = cols if len(columns) >= cols else len(columns)
+ rows, mod = divmod(len(columns), cols)
+ rows += mod != 0
+ for i, col in enumerate(columns):
+ plt.subplot(rows, cols, i + 1)
+ bin_size = (df[col].max() - df[col].min()) / 11
+ df[col].hist(bins=np.arange(df[col].min(), 1.01 * df[col].max(), bin_size))
+ plt.title(col)
+ plt.tight_layout()
+ plt.grid(False)
+ plt.show()
+
+
+class DrawRes:
+
+ color_map = {
+ 'r': (0, 0, 255),
+ 'g': (0, 255, 0),
+ 'b': (255, 0, 0),
+ 'y': (0, 255, 255),
+ 'f': (255, 0, 255),
+ 'w': (255, 255, 255)
+ }
+
+ def __init__(self, project, classes=None):
+ self.colors = ['g', 'r', 'y', 'b', 'w', 'f']
+ self.project = project.lower()
+ if project == 'ps':
+ self.classes = ['driver', 'front', 'rear', 'rear2', 'cover', 'sleep']
+ driver_loc = [0.65, 0.35, 1, 1]
+ front_loc = [0, 0.35, 0.35, 1]
+ rear_loc = [0, 0, 1, 0.3]
+ cover_loc = [0.35, 0.5, 0.7, 1]
+ sleep_loc = [0.35, 0.2, 0.7, 0.6]
+ self.location = [driver_loc, front_loc, rear_loc, cover_loc, sleep_loc]
+ elif project == 'psr':
+ self.classes = ["drive", "front", "rear", "rear2", "left", "middle", "right", "sleep", "cover"]
+ driver_loc = [0.6, 0.3, 1, 1]
+ front_loc = [0, 0.3, 0.4, 1]
+ rear_loc = [0, 0, 1, 0.1]
+ left = [0, 0.1, 0.4, 0.3]
+ mid = [0.4, 0.1, 0.6, 0.3]
+ right = [0.6, 0.1, 1, 0.3]
+ sleep_loc = [0.4, 0.3, 0.6, 0.4]
+ cover_loc = [0.4, 0.4, 0.6, 0.6]
+ self.location = [driver_loc, front_loc, rear_loc, left, mid, right, sleep_loc, cover_loc]
+ else:
+ draw_loc = [0.35, 0.35, 0.6, 0.6]
+ self.location = [draw_loc]
+ self.classes = classes
+ assert self.classes is not None
+ num_classes = len(self.classes)
+ self.colors = self.colors + ['r'] * (num_classes - 6) if num_classes >= 6 else self.colors
+
+ def draw_pred(self, img, pred, prob=None, crop_box=None, use_mask=True, use_cn=False):
+ drawn = img.copy()
+ if crop_box is not None and use_mask:
+ cx1, cy1, cx2, cy2 = crop_box
+ mask = np.full_like(img, fill_value=128, dtype=np.uint8)
+ mask[cy1: cy2, cx1:cx2, :] = 0
+ drawn = cv2.add(drawn, mask)
+
+ if self.project.startswith('ps'):
+ assert len(pred) == len(self.classes), f'person seats {len(self.classes)} result'
+ if prob is not None:
+ assert len(prob) == len(self.classes), f'person seats {len(self.classes)} result'
+ prob = prob[:2] + [prob[3] if (pred[2] + pred[3]) == 2 else prob[2]] + prob[4:]
+ pred = pred[:2] + [pred[2] + pred[3]] + pred[4:]
+ for i, label in enumerate(pred):
+ if prob is None:
+ text = f'{self.classes[i]}: {label}' if i < 3 else f'{self.classes[i+1]}: {label}'
+ else:
+ text = f'{self.classes[i]}: {label} {prob[i]:.3f}' if i < 3 \
+ else f'{self.classes[i+1]}: {label} {prob[i]:.3f}'
+ drawn = self.draw(drawn, loc=i, text=text, use_cn=use_cn)
+ else:
+ text = f"{self.classes[pred]}" if prob is None else f"{self.classes[pred]}: {prob:.3f}"
+ drawn = self.draw(drawn, loc=0, text=text, txt_color=self.colors[pred], use_cn=use_cn)
+
+ return drawn
+
+ def draw(self, img, text, loc, box_color='r', txt_color='r', use_cn=False):
+ if loc < 0 or loc >= len(self.location):
+ return img
+ h, w = img.shape[:2]
+ factor = w / 1280
+
+ left, top = int(self.location[loc][0] * w), int(self.location[loc][1] * h)
+ right, down = int(self.location[loc][2] * w), int(self.location[loc][3] * h)
+ cur_w = right - left
+ cur_h = down - top
+
+ if use_cn:
+ drawn = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+ drawer = ImageDraw.Draw(drawn)
+ font_size, thickness = int(40 * factor), int(1*factor)
+ font = ImageFont.truetype(os.path.join(ROOT_DIR, "files/simsun.ttc"), size=font_size, encoding="utf-8")
+ txt_size = drawer.textsize(text, font)
+ else:
+ drawn = img.copy()
+ font_size, thickness = 1*factor, int(2*factor)
+ font = cv2.FONT_HERSHEY_TRIPLEX
+ txt_size = cv2.getTextSize(text, font, font_size, thickness)[0]
+
+ if self.project.startswith('ps'):
+ if loc < len(self.location) - 2:
+ if use_cn:
+ drawer.rectangle(((left, top), (right, down)), outline=self.color_map[box_color][::-1], width=2*thickness)
+ else:
+ cv2.rectangle(drawn, pt1=(left, top), pt2=(right, down), color=self.color_map[box_color], thickness=thickness)
+ text_loc = (left + (cur_w - txt_size[0]) // 2, top + (cur_h - txt_size[1]//2) // 2)
+ else:
+ text_loc = ((w - txt_size[0]) // 2, (top + down) // 2)
+ else:
+ text_loc = ((w - txt_size[0]) // 2, (h - txt_size[1]//2) // 2)
+
+ if use_cn:
+ drawn = put_chinese_text(drawn, text, text_loc, font, self.color_map[txt_color][::-1])
+ else:
+ cv2.putText(drawn, text, text_loc, fontFace=font, fontScale=font_size, color=self.color_map[txt_color], thickness=thickness)
+
+ return drawn
+
+ def old_draw_top2(self, rect, probs, factor=1, use_cn=False):
+        drawn = cv2.resize(rect, dsize=None, fx=2*factor, fy=2*factor)  # factor is 1 for 640-wide input, 2 for 1280
+ b, g, r, a = cv2.split(drawn)
+ drawn = cv2.merge([b, g, r])
+ h, w = drawn.shape[:2]
+ id1, id2 = np.argsort(probs)[::-1][:2]
+ class1, class2 = self.classes[id1], self.classes[id2]
+ max_len = max(len(class1), len(class2))
+ text1 = f"{class1:{max_len}s}: {probs[id1]:.3f}"
+ text2 = f"{class2:{max_len}s}: {probs[id2]:.3f}"
+
+ if use_cn:
+ font_size, thickness = int(50*factor), 1
+ font = ImageFont.truetype(os.path.join(ROOT_DIR, "files/simhei.ttf"), size=font_size, encoding="utf-8")
+ drawn = Image.fromarray(cv2.cvtColor(drawn, cv2.COLOR_BGR2RGB))
+ drawer = ImageDraw.Draw(drawn)
+ txt_size1 = drawer.textsize(text1, font)
+ txt_size2 = drawer.textsize(text2, font)
+ start_w = (w - txt_size1[0]) // 2
+ start_h = (h - txt_size1[1] - txt_size2[1]) // 3
+ txt_loc1 = (start_w, start_h)
+ txt_loc2 = (start_w, h - start_h - txt_size2[1])
+ drawn = put_chinese_text(drawn, text1, txt_loc1, font, self.color_map['g'][::-1])
+ drawn = put_chinese_text(drawn, text2, txt_loc2, font, self.color_map['w'][::-1])
+ else:
+ font = cv2.FONT_HERSHEY_TRIPLEX
+ font_size, thickness = 1.2*factor, 1
+ txt_size1 = cv2.getTextSize(text1, font, font_size, thickness)[0]
+ txt_size2 = cv2.getTextSize(text2, font, font_size, thickness)[0]
+ start_w = (w - txt_size1[0]) // 2
+ start_h = (h - txt_size1[1] - txt_size2[1]) // 3
+ txt_loc1 = (start_w, start_h+txt_size1[1])
+ txt_loc2 = (start_w, h - start_h)
+ cv2.putText(drawn, text1, txt_loc1, fontFace=font, fontScale=font_size,
+ color=self.color_map['g'], thickness=thickness)
+ cv2.putText(drawn, text2, txt_loc2, fontFace=font, fontScale=font_size,
+ color=self.color_map['w'], thickness=thickness)
+
+ drawn = cv2.merge([drawn, a])
+ return drawn
+
+ def draw_top2(self, img, probs, crop_box=None, use_mask=True, use_frame=True, use_cn=False, txt_loc=None):
+ id1, id2 = np.argsort(probs)[::-1][:2]
+ class1, class2 = self.classes[id1], self.classes[id2]
+ max_len = max(len(class1), len(class2))
+ text1 = f"{class1:{max_len}s}: {probs[id1]:.3f}"
+ text2 = f"{class2:{max_len}s}: {probs[id2]:.3f}"
+ drawn = draw_texts(img, [text1, text2], crop_box=crop_box, use_mask=use_mask, use_frame=use_frame,
+ use_cn=use_cn, txt_colors=[self.color_map['g'], self.color_map['w']], txt_rect_loc=txt_loc)
+ return drawn
+
+
+class DrawMTL:
+ def __init__(self, project, tasks, task_types, classes):
+ self.project = project.lower()
+ self.tasks = tasks
+ self.task_types = task_types
+ self.classes = classes
+
+ def draw_result(self, img, preds, probs, box=None, crop_box=None, use_mask=True, use_frame=True, use_cn=False,
+ box_color=(0, 255, 0), txt_color=(100, 255, 0), font=cv2.FONT_HERSHEY_TRIPLEX):
+ assert len(preds) == len(self.tasks) and len(preds) == len(probs)
+ text_list = []
+ for ti, (pred, sub_probs) in enumerate(zip(preds, probs)):
+ assert len(sub_probs) == len(self.classes[ti]), f"{sub_probs} length != {self.classes[ti]} length"
+ if self.task_types[ti] == 0:
+ text_list.append(f'{self.tasks[ti].upper()}: {self.classes[ti][pred]} {sub_probs[pred]:.3f}')
+ else:
+ text_list.append(f'{self.tasks[ti].upper()}: {pred:.3f}')
+ drawn = draw_texts(img, text_list, box, crop_box, use_mask, use_frame, use_cn, box_color, txt_color, None, font)
+ return drawn
+
+ @staticmethod
+ def draw_ind(img, txt='No Driver!', use_frame=True, use_cn=False,
+ color=(0, 0, 255), font=cv2.FONT_HERSHEY_TRIPLEX):
+ drawn = draw_texts(img, txt, use_frame=use_frame, use_cn=use_cn, txt_colors=color, font=font)
+ return drawn
+
+
+def draw_texts(img, text_list, box=None, base_box=None, use_mask=True, use_frame=True, use_cn=False,
+ box_color=(0, 255, 0), txt_colors=(100, 255, 0), txt_rect_loc=None, font=cv2.FONT_HERSHEY_TRIPLEX):
+
+ h, w = img.shape[:2]
+ drawn = img.copy()
+
+ if box is not None:
+ x1, y1, x2, y2 = box
+ cv2.rectangle(drawn, (x1, y1), (x2, y2), (0, 0, 255), thickness=w//640)
+
+ if base_box is not None:
+ cx1, cy1, cx2, cy2 = base_box
+ cv2.rectangle(drawn, (cx1, cy1), (cx2, cy2), box_color, thickness=w//640+1)
+ if use_mask:
+            # drawn[60:300, 0:210, :] = cv2.GaussianBlur(drawn[60:300, 0:210, :], (51, 51), 0)  # Gaussian blur (kernel size must be odd)
+ # mask = np.full_like(img, fill_value=128, dtype=np.uint8)
+ # mask[cy1: cy2, cx1:cx2, :] = 0
+ # drawn = cv2.subtract(drawn, mask)
+ mask = np.full_like(img, fill_value=128, dtype=np.uint8)
+ mask[cy1: cy2, cx1:cx2, :] = 0
+ drawn = cv2.add(drawn, mask)
+
+    if isinstance(text_list, str) or not isinstance(text_list, collections.abc.Sequence):
+ text_list = [text_list]
+ text_list = [str(x) for x in text_list]
+ max_len_txt = sorted(text_list, key=lambda x: -len(x))[0]
+
+ if isinstance(txt_colors[0], int):
+ txt_colors = [txt_colors for _ in text_list]
+ if len(txt_colors) < len(text_list):
+ txt_colors = txt_colors * (1 + len(text_list)//len(txt_colors))
+
+ factor = w / 1280
+ if use_cn:
+ drawer = ImageDraw.Draw(Image.fromarray(cv2.cvtColor(drawn, cv2.COLOR_BGR2RGB)))
+ font_size, thickness, space = int((40 + round(15/len(text_list))) * factor), 1, int(10 * factor)
+ font = ImageFont.truetype(os.path.join(ROOT_DIR, "files/simhei.ttf"), size=font_size, encoding="utf-8")
+ txt_size = drawer.textsize(max_len_txt, font)
+ else:
+ font_size, thickness, space = (1 + round(1/len(text_list))) * factor, 1 + round(1/len(text_list)), int(20 * factor)
+ txt_size = cv2.getTextSize(max_len_txt, font, font_size, thickness)[0]
+
+ info_rect = cv2.resize(INFO, (txt_size[0] + 2 * space, len(text_list) * txt_size[1] + (len(text_list) + 1) * space))
+ b, g, r, a = cv2.split(info_rect)
+ info_rect = cv2.merge([b, g, r])
+ for i, text in enumerate(text_list):
+ if use_cn:
+ info_rect = put_chinese_text(info_rect, text, (space, i*(txt_size[1]+space)+space), font, txt_colors[i][::-1])
+ else:
+ cv2.putText(info_rect, text, (space, (i + 1) * (txt_size[1] + space)), font, font_size, txt_colors[i], thickness)
+ info_rect = cv2.merge([info_rect, a])
+
+    if use_frame:  # add the outer frame overlay
+ frame_side = cv2.resize(FRAME, (w, h))
+ drawn = paste(drawn, frame_side, (w // 2, h // 2), 1)
+
+    if txt_rect_loc is None:  # default text-box location
+ txt_rect_loc = (w//2, h//2) if base_box is None else (base_box[0] // 2, (base_box[1] + base_box[3]) // 2)
+ drawn = paste(drawn, info_rect, txt_rect_loc, 0.9, refine=True)
+
+ return drawn
+
+
+def plot_scores(classes, score_array, title='result', x_label='frame', y_label='score', save_path=None, show=True):
+ num, num_class = score_array.shape
+ assert len(classes) == num_class
+ colors = ['green', 'red', 'blue', 'aqua', 'yellow', 'hotpink']
+ if num_class > len(colors):
+ colors = colors * (1 + num_class // len(colors))
+
+ cols = 3 if num_class >= 3 else num_class
+ rows, mod = divmod(num_class, cols)
+ rows += mod != 0
+ plt.figure(figsize=(9, 2*rows))
+ plt.title(title, fontsize=11)
+ plt.xlabel(x_label, fontsize=12)
+ plt.ylabel(y_label, fontsize=12)
+
+ x = list(range(0, num))
+ for i in range(num_class):
+ plt.subplot(rows, cols, i+1)
+ plt.plot(x, score_array[:, i], color=colors[i], label=classes[i], linewidth=0.5)
+ plt.tick_params(axis='both', which='major', labelsize=5)
+        x_major_locator = MultipleLocator(50)  # x-axis major tick every 50 frames
+        y_major_locator = MultipleLocator(0.05)  # y-axis major tick every 0.05
+        ax = plt.gca()  # current axes instance
+        ax.xaxis.set_major_locator(x_major_locator)  # apply the x-axis major locator
+        ax.yaxis.set_major_locator(y_major_locator)
+        plt.xlim(-10, num+10)  # extend the x range slightly past the data so the edges are not clipped
+        plt.ylim(-0.01, 1.01)
+        plt.legend()  # show legend
+ plt.grid()
+
+ if save_path:
+ plt.savefig(save_path)
+
+ # if show:
+ # plt.show()
+
+
+def plot_scores_mtl(tasks, task_types, classes, score_array, title='result', x_label='frame', y_label='score',
+ save_dir=None, save_path=None, show=True):
+
+ if save_path is not None:
+ save_fig_path = save_path
+ os.makedirs(save_fig_path, exist_ok=True)
+ elif save_dir is not None:
+ save_fig_path = f"{save_dir}/{os.path.splitext(title)[0]}"
+ os.makedirs(save_fig_path, exist_ok=True)
+ else:
+ save_fig_path = None
+
+ num, total_num_class = score_array.shape
+ sub_class_length = [len(sub) for sub in classes]
+ assert len(classes) == len(tasks)
+ assert sum(sub_class_length) == total_num_class
+
+ colors = ['green', 'red', 'blue', 'aqua', 'yellow', 'hotpink']
+ if max(sub_class_length) > len(colors):
+ colors = colors * (1 + max(sub_class_length)//len(colors))
+
+    # split the flat score array into per-task blocks
+ array_list, st_idx = [], 0
+ for sub_classes in classes:
+ array_list.append(score_array[:, st_idx: st_idx + len(sub_classes)])
+ st_idx += len(sub_classes)
+
+ x = list(range(0, num))
+ for ti, task in enumerate(tasks):
+ plt.figure(figsize=(12, 7))
+ sub_array, sub_classes = array_list[ti], classes[ti]
+ for i, sc in enumerate(sub_classes):
+ plt.plot(x, sub_array[:, i], color=colors[i], label=sc, linewidth=0.5)
+ plt.tick_params(axis='both', which='major', labelsize=5)
+        plt.xlim(-10, num + 10)  # extend the x range slightly past the data
+        ax = plt.gca()  # current axes instance
+        x_major_locator = MultipleLocator(50)  # x-axis major tick every 50 frames
+        ax.xaxis.set_major_locator(x_major_locator)  # apply the x-axis major locator
+        if task_types[ti] == 0:
+            plt.ylim(-0.01, 1.01)
+            y_major_locator = MultipleLocator(0.05)  # y-axis major tick every 0.05
+            ax.yaxis.set_major_locator(y_major_locator)
+ plt.title(task.upper(), fontsize=11)
+ plt.xlabel(x_label, fontsize=12)
+ plt.ylabel(y_label, fontsize=12)
+        plt.legend()  # show legend
+ plt.grid()
+
+ if save_fig_path is not None:
+ plt.savefig(f"{save_fig_path}/{task}.jpg")
+
+
+def syn_plot_scores(classes, indexes, score_array, cur_time='', width=320, height=480, factor=1, window_length=None,
+ title='result', x_label='frame', y_label='score'):
+ # plt.style.use('dark_background')
+ backgroud = np.zeros((height, width, 3), dtype=np.uint8)
+ if cur_time:
+ font = cv2.FONT_HERSHEY_DUPLEX
+ font_size, thickness = 1.4*factor, 1
+ txt_size = cv2.getTextSize(cur_time, font, font_size, thickness)[0]
+ txt_loc = ((width-txt_size[0])//2, 30+txt_size[1])
+ cv2.putText(backgroud, cur_time, txt_loc, fontFace=font, fontScale=font_size,
+ color=(255, 255, 255), thickness=thickness)
+
+ num, num_class = score_array.shape
+ assert len(classes) == num_class
+
+ colors = ['green', 'red', 'blue', 'aqua', 'yellow', 'hotpink']
+ pots = ['*', 'o', 's', '^', '+', 'x']
+ if num_class > len(colors):
+ colors = colors * (1 + num_class // len(colors))
+ if num_class > len(pots):
+ pots = pots * (1 + num_class // len(pots))
+
+ use_array = score_array.copy()
+ if isinstance(window_length, int) and num > window_length:
+ use_array = use_array[-window_length:]
+ indexes = indexes[-window_length:]
+ num = window_length
+
+ space = width // 10
+
+ cols = 1
+ rows, mod = divmod(num_class, 4)
+ rows += mod != 0
+ fig = plt.figure(figsize=(8, 4 * rows), dpi=80, frameon=True)
+ # plt.title(title)
+ # plt.xlabel(x_label)
+ # plt.ylabel(y_label)
+
+ for i in range(rows):
+ if rows > 1:
+ plt.subplot(rows, cols, i + 1)
+ sub_array = use_array[:, 4*i:4*(i+1)] if i < rows-1 else use_array[:, 4*i:]
+ for j in range(sub_array.shape[1]):
+ plt.plot(indexes, sub_array[:, j], color=colors[j], label=classes[4*i+j],
+ linewidth=0.5, marker=pots[j], ms=6, zorder=3)
+        # x_major_locator = MultipleLocator(space)  # x-axis major tick interval
+        # y_major_locator = MultipleLocator(0.05)  # y-axis major tick every 0.05
+        # ax = plt.gca()  # current axes instance
+        # ax.xaxis.set_major_locator(x_major_locator)  # apply the x-axis major locator
+ # ax.yaxis.set_major_locator(y_major_locator)
+ plt.tick_params(axis='both', which='major', labelsize=int(18*factor))
+ plt.yticks(fontsize=int(24*factor))
+ plt.xticks(rotation=90, fontsize=int(16*factor))
+        plt.xlim(-space // 3, num + space // 3)  # x-axis range
+        plt.ylim(-0.01, 1.01)
+        plt.legend(shadow=True)  # show legend
+ # plt.margins(x=0)
+ plt.axis('tight')
+ ax = plt.gca()
+ # ax.spines['top'].set_color('none')
+ ax.spines['left'].set_color('none')
+ ax.spines['right'].set_color('none')
+
+ buffer = io.BytesIO()
+ canvas = fig.canvas
+ canvas.print_png(buffer)
+ data = buffer.getvalue()
+ buffer.write(data)
+ chart = Image.open(buffer)
+ chart = np.array(chart)
+ chart = cv2.cvtColor(chart, cv2.COLOR_RGB2BGR)
+ h, w = chart.shape[:2]
+ chart = chart[30:h, 30:w-30, :3]
+ h, w = chart.shape[:2]
+ f = max(w/width, h/height)
+ n_w, n_h = int(w/f), int(h/f)
+ chart = cv2.resize(chart, (n_w, n_h))
+ backgroud[(height-n_h+30)//2:(height-n_h+30)//2+n_h, (width-n_w)//2: (width-n_w)//2+n_w] = chart
+
+ return backgroud
+
+
+def syn_plot_scores_mtl(tasks, task_types_index, classes, indexes, score_array, width=320, height=480, window_length=None,
+ factor=1, cur_time='', title='result', x_label='frame', y_label='score'):
+
+ # plt.style.use('dark_background')
+ backgroud = np.zeros((height, width, 3), dtype=np.uint8)
+
+ num, total_num_class = score_array.shape
+ sub_class_length = [len(sub) for sub in classes]
+ assert len(classes) == len(tasks)
+ assert sum(sub_class_length) == total_num_class
+
+ colors = ['green', 'red', 'blue', 'aqua', 'yellow', 'hotpink']
+ pots = ['*', 'o', 's', '^', '+', 'x']
+ if max(sub_class_length) > len(colors):
+ colors = colors * (1 + max(sub_class_length) // len(colors))
+ if max(sub_class_length) > len(pots):
+ pots = pots * (1 + max(sub_class_length) // len(pots))
+
+ use_array = score_array.copy()
+ if isinstance(window_length, int) and num > window_length:
+ use_array = use_array[-window_length:]
+ indexes = indexes[-window_length:]
+ num = window_length
+
+    # merge the two eyelid-distance tasks into a single 'distance' plot
+ ri, li = task_types_index[1]
+ tasks = tasks[:ri] + ['distance']
+ classes = classes[:ri] + [['dist_r', 'dist_l']]
+ array_list, st_idx = [], 0
+ for sub_classes in classes:
+ array_list.append(use_array[:, st_idx: st_idx+len(sub_classes)])
+ st_idx += len(sub_classes)
+
+ focus_task_ids = [0, 1, ri]
+ fig = plt.figure(figsize=(8, 4 * len(focus_task_ids)), dpi=80, frameon=True)
+ # plt.title(title)
+ # plt.xlabel(x_label)
+ # plt.ylabel(y_label)
+ for i, idx in enumerate(focus_task_ids):
+ plt.subplot(len(focus_task_ids), 1, i + 1)
+ sub_array, sub_classes, task = array_list[idx], classes[idx], tasks[idx]
+ for j in range(sub_array.shape[1]):
+ plt.plot(indexes, sub_array[:, j], color=colors[j], label=sub_classes[j],
+ linewidth=0.5, marker=pots[j], ms=6, zorder=3)
+        # x_major_locator = MultipleLocator(space)  # x-axis major tick interval
+        # y_major_locator = MultipleLocator(0.05)  # y-axis major tick every 0.05
+        # ax = plt.gca()  # current axes instance
+        # ax.xaxis.set_major_locator(x_major_locator)  # apply the x-axis major locator
+ # ax.yaxis.set_major_locator(y_major_locator)
+ plt.title(task.upper())
+ plt.tick_params(axis='both', which='major', labelsize=int(18*factor))
+ plt.yticks(fontsize=int(24*factor))
+ plt.xticks(rotation=90, fontsize=int(16*factor))
+        plt.xlim(-width // 30, num + width // 30)  # x-axis range
+        plt.ylim(-0.01, 1.01)
+        plt.legend(shadow=True)  # show legend
+ # plt.margins(x=0)
+ plt.axis('tight')
+ ax = plt.gca()
+ # ax.spines['top'].set_color('none')
+ ax.spines['left'].set_color('none')
+ ax.spines['right'].set_color('none')
+
+ buffer = io.BytesIO()
+ canvas = fig.canvas
+ canvas.print_png(buffer)
+ data = buffer.getvalue()
+ buffer.write(data)
+ chart = Image.open(buffer)
+ chart = cv2.cvtColor(np.array(chart), cv2.COLOR_RGB2BGR)
+ h, w = chart.shape[:2]
+ chart = chart[60:h-60, 40:w-40, :3]
+ h, w = chart.shape[:2]
+ f = max(w/width, h/height)
+ n_w, n_h = int(w/f), int(h/f)
+ chart = cv2.resize(chart, (n_w, n_h))
+ backgroud[(height-n_h)//2: (height-n_h)//2+n_h, (width-n_w)//2: (width-n_w)//2+n_w] = chart
+
+ # if cur_time:
+ # font = cv2.FONT_HERSHEY_DUPLEX
+ # font_size, thickness = 1.4*factor, 1
+ # txt_size = cv2.getTextSize(cur_time, font, font_size, thickness)[0]
+ # txt_loc = ((width-txt_size[0])//2, 30+txt_size[1])
+ # cv2.putText(backgroud, cur_time, txt_loc, fontFace=font, fontScale=font_size,
+ # color=(255, 255, 255), thickness=thickness)
+
+ return backgroud
+
+
+if __name__ == '__main__':
+ plot_bar(values=[4310, 4069, 4104, 4455, 4152, 4120, 4040], xlabel='', ylabel='', title='B100 dataset copy',
+ xticks=['normal', 'head_right', 'head_left', 'head_down', 'eye_left', "eye_down", 'invalid'])
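+
+    # plot_hist sketch (illustrative random data only; assumes an interactive matplotlib backend):
+    # rng = np.random.default_rng(0)
+    # plot_hist(rng.normal(size=(500, 2)), columns=['eyelid_r', 'eyelid_l'])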
+
+ # lst = [42791, 18309]
+ # for item in lst:
+ # print(item/sum(lst))
+
+ # ratio = [0.3962, 0.6013, 0.0025]
+ # plot_pie(values=ratio, labels=['B200', 'B100', 'Other'])
+ #
+ # ratio = [0.48, 0.52]
+ # plot_pie(values=ratio, labels=['Day', 'Night'], title='time distribute')
diff --git a/utils/time_util.py b/utils/time_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e2b9c7642a6491358af9e0bfec5e847878aceff
--- /dev/null
+++ b/utils/time_util.py
@@ -0,0 +1,156 @@
+# -*- coding:utf-8 -*-
+import time
+from datetime import datetime, timedelta
+
+
+def convert_input_time(input_time, digit=13):
+ if isinstance(input_time, (int, float)):
+ stamp = stamp_pad(input_time, digit)
+ return stamp
+ elif isinstance(input_time, str):
+ format_str = "%Y-%m-%d %H:%M:%S"
+ if '/' in input_time:
+ format_str = "%Y/%m/%d %H:%M:%S"
+ elif ' ' not in input_time and ':' not in input_time:
+ format_str = '%Y%m%d%H%M%S'
+ stamp = convert_time_str(input_time, format_str, digit)
+ return stamp
+ else:
+        raise ValueError('Unsupported input type')
+
+
+def stamp_pad(stamp, digit=10):
+ stamp = int(stamp)
+ length = len(str(stamp))
+ if length < digit:
+ stamp = int(str(stamp) + '0'*(digit-length))
+ if length > digit:
+ stamp = int(str(stamp)[:digit])
+ return stamp
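+
+
+# stamp_pad examples (sketch): pad or truncate to the requested number of digits.
+#   stamp_pad(1625465966, 13)    -> 1625465966000
+#   stamp_pad(1625465966842, 10) -> 1625465966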
+
+
+def convert_stamp(stamp, format_str="%Y-%m-%d %H:%M:%S"):
+ date_time = time.strftime(format_str, time.localtime(stamp_pad(stamp, digit=10)))
+ return date_time
+
+
+def convert_time_str(time_str, format_str="%Y-%m-%d %H:%M:%S", digit=13):
+ t = datetime.strptime(time_str, format_str).timestamp()
+ return stamp_pad(t, digit)
+
+
+def transform_format(time_str, dst_fmt="%Y-%m-%d %H:%M:%S"):
+ stamp = convert_input_time(time_str, digit=10)
+ return convert_stamp(stamp, dst_fmt)
+
+
+# @clock_custom('[{elapsed:0.8f}s] {name}()')
+# Compute the signed difference, in seconds, between two timestamps (positive when stamp2 is later)
+def cal_stamps_diff(stamp1, stamp2):
+    stamp1 = stamp_pad(stamp1)
+    stamp2 = stamp_pad(stamp2)
+    # total_seconds() keeps the sign and counts whole days, unlike timedelta.seconds
+    diff = datetime.fromtimestamp(stamp2) - datetime.fromtimestamp(stamp1)
+    return int(diff.total_seconds())
+
+
+# Faster variant: plain integer subtraction
+# @clock_custom('[{elapsed:0.8f}s] {name}()')
+def cal_diff_seconds(stamp1, stamp2):
+ stamp1 = stamp_pad(stamp1)
+ stamp2 = stamp_pad(stamp2)
+ return stamp2-stamp1
+
+
+# Timestamp from delta_hours hours ago
+def cal_yesterday_stamp(delta_hours=24):
+ now = datetime.now()
+ delta = timedelta(hours=delta_hours)
+ before = (now-delta).timestamp()
+ return stamp_pad(before)
+
+
+def get_now_date():
+ return time.strftime('%m%d', time.localtime(time.time()))
+
+
+def get_today():
+ return datetime.now().strftime('%Y%m%d')
+
+
+def get_days_before(days=5):
+ return (datetime.now() - timedelta(days=days)).strftime('%Y%m%d')
+
+
+def get_date(day=0, time_str='%Y%m%d'):
+ if day == 0:
+ return datetime.now().strftime(time_str)
+ elif day > 0:
+ return (datetime.now() - timedelta(days=day)).strftime(time_str)
+
+
+def get_season():
+ month = int(datetime.now().strftime('%m'))
+ if month <= 2 or month == 12:
+ return 4
+ elif month <= 5:
+ return 1
+ elif month <= 8:
+ return 2
+ else:
+ return 3
+
+
+class Timer:
+ def __init__(self):
+ # store the start time, end time, and total number of frames
+ # that were examined between the start and end intervals
+ self._start = None
+ self._end = None
+ self._numFrames = 0
+
+ def start(self):
+ # start the timer
+ self._start = datetime.now()
+ return self
+
+ def stop(self):
+ # stop the timer
+ self._end = datetime.now()
+
+ def update(self):
+ # increment the total number of frames examined during the
+ # start and end intervals
+ self._numFrames += 1
+
+ def elapsed(self):
+ # return the total number of seconds between the start and
+ # end interval
+ return (self._end - self._start).total_seconds()
+
+ def fps(self):
+ # compute the (approximate) frames per second
+ return self._numFrames / self.elapsed()
+
+
+if __name__ == '__main__':
+ print(convert_stamp(1625465966842))
+ print(convert_time_str('2020-10-02 05:36:00'))
+ print(convert_time_str('2020-10-02 06:46:00'))
+ print(convert_input_time('20201208102335', digit=10))
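+
+    # Timer usage sketch (illustrative): time a short loop and report the frame rate.
+    timer = Timer().start()
+    for _ in range(5):
+        time.sleep(0.01)
+        timer.update()
+    timer.stop()
+    print(f"{timer.fps():.1f} frames/s over {timer.elapsed():.3f}s")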
+ # print(convert_stamp(1595945099609, "%Y%m%d%H%M%S"))
+ # stamp1 = 1594288522155
+ # stamp2 = 1594288582019
+ # cal_diff_seconds(stamp1, stamp2)
+ # cal_stamps_diff(stamp1, stamp2)
+ # print(get_now_date())
+ # print((datetime.datetime.now()-datetime.timedelta(days=5)).strftime('%Y%m%d'))
+"""
+{"detectId":11,"startTime":1595945079609,"endTime":1595945119609,"commandId":"738411578604601345","deviceId":"03agptmyb2be9a2c","content":"{\"content\":\"\",\"fileType\":5,\"lat\":23.012758891079372,\"lng\":113.22142671753562,\"startTime\":1595945099670}","timestamp":1595945099609}
+{"detectId":11,"startTime":1595945088792,"endTime":1595945128792,"commandId":"738412534595338240","deviceId":"03agptmyb2be9a2c","content":"{\"content\":\"\",\"fileType\":5,\"lat\":23.012758891079372,\"lng\":113.22142671753562,\"startTime\":1595945108854}","timestamp":1595945108792}
+{"detectId":11,"startTime":1595945115257,"endTime":1595945155257,"commandId":"738414926988763136","deviceId":"03agph2zhgb56534","content":"{\"content\":\"\",\"fileType\":5,\"lat\":23.105039173700952,\"lng\":113.136789077983,\"startTime\":1595945135322}","timestamp":1595945135257}
+""" \
+""