George commited on
Commit
a22775d
1 Parent(s): bb98044

Upload 47 files

Browse files
Files changed (47) hide show
  1. .gitattributes +1 -0
  2. .gitignore +4 -0
  3. README.md +35 -13
  4. demo.py +63 -0
  5. images/Coccinella_septempunctata.jpg +3 -0
  6. insectid/base.py +51 -0
  7. insectid/detector.py +56 -0
  8. insectid/identifier.py +71 -0
  9. insectid/models/quarrying_insect_detector.onnx +3 -0
  10. insectid/models/quarrying_insect_identifier.onnx +3 -0
  11. insectid/models/quarrying_insectid_label_map.txt +0 -0
  12. khandy/boxes/boxes_and_indices.py +68 -0
  13. khandy/boxes/boxes_clip.py +34 -0
  14. khandy/boxes/boxes_coder.py +69 -0
  15. khandy/boxes/boxes_convert.py +101 -0
  16. khandy/boxes/boxes_filter.py +113 -0
  17. khandy/boxes/boxes_overlap.py +166 -0
  18. khandy/boxes/boxes_transform_flip.py +135 -0
  19. khandy/boxes/boxes_transform_rotate.py +140 -0
  20. khandy/boxes/boxes_transform_scale.py +86 -0
  21. khandy/boxes/boxes_transform_translate.py +136 -0
  22. khandy/boxes/boxes_utils.py +28 -0
  23. khandy/dict_utils.py +168 -0
  24. khandy/draw_utils.py +148 -0
  25. khandy/feature_utils.py +62 -0
  26. khandy/file_io_utils.py +87 -0
  27. khandy/fs_utils.py +375 -0
  28. khandy/hash_utils.py +25 -0
  29. khandy/image/align_and_crop.py +60 -0
  30. khandy/image/crop_or_pad.py +138 -0
  31. khandy/image/flip.py +72 -0
  32. khandy/image/image_hash.py +69 -0
  33. khandy/image/misc.py +329 -0
  34. khandy/image/resize.py +177 -0
  35. khandy/image/rotate.py +72 -0
  36. khandy/image/translate.py +57 -0
  37. khandy/label/detect.py +582 -0
  38. khandy/list_utils.py +68 -0
  39. khandy/misc.py +245 -0
  40. khandy/numpy_utils.py +173 -0
  41. khandy/points/pts_letterbox.py +19 -0
  42. khandy/points/pts_transform_scale.py +33 -0
  43. khandy/split_utils.py +73 -0
  44. khandy/text_utils.py +33 -0
  45. khandy/time_utils.py +101 -0
  46. khandy/version.py +3 -0
  47. requirements.txt +6 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ images/Coccinella_septempunctata.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __pycache__/
2
+ _local/
3
+ *.pyc
4
+ local_models_*/
README.md CHANGED
@@ -1,13 +1,35 @@
1
- ---
2
- title: Insecta
3
- emoji: 📉
4
- colorFrom: purple
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 3.42.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 特性
2
+ - 支持 2037 类 (可能是目, 科, 属或种等) 昆虫或其他节肢动物
3
+ - 模型开源, 持续更新.
4
+
5
+ # 安装
6
+ 先安装 Anaconda, 然后执行
7
+ ```
8
+ git clone https://github.com/quarrying/quarrying-insect-id.git
9
+ cd quarrying-insect-id
10
+ conda create -n insectid python=3.6 -y
11
+ conda activate insectid
12
+ pip install -r requirements.txt
13
+ ```
14
+
15
+ # 用法
16
+
17
+ 参考 [demo.py](<demo.py>), 也可以在我的个人网站 (<https://www.quarryman.cn/insect>) 体验识别效果.
18
+
19
+
20
+ # ChangeLog
21
+
22
+ - 20211204 更新识别模型, 支持 2037 个昆虫分类单元, top1/top5 准确率为 0.922/0.981.
23
+ - 20211125 更新检测模型.
24
+ - 20211018 更新检测模型.
25
+ - 20211011 更新检测模型.
26
+ - 20211009 更新识别模型, 支持 1702 个昆虫分类单元, top1/top5 准确率为 0.915/0.973.
27
+ - 20210920 更新识别模型, 支持 1534 个昆虫分类单元.
28
+ - 20210908 更新识别模型, 支持 1372 个昆虫分类单元.
29
+ - 20210825 更新识别模型, 支持 1234 个昆虫分类单元.
30
+ - 20210815 更新识别模型, 支持 1068 个昆虫分类单元.
31
+ - 20210801 更新识别模型, 支持 868 个昆虫分类单元.
32
+ - 20210713 更新检测模型.
33
+ - 20210712 更新识别模型, 支持 840 个昆虫分类单元.
34
+ - 20210704 更新识别模型, 支持 820 个昆虫分类单元.
35
+ - 20210701 发布第一版模型, 支持 786 个昆虫分类单元.
demo.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import cv2
5
+ import khandy
6
+ import numpy as np
7
+
8
+ from insectid import InsectDetector
9
+ from insectid import InsectIdentifier
10
+
11
+
12
+ if __name__ == '__main__':
13
+ src_dirs = [r'images']
14
+
15
+ detector = InsectDetector()
16
+ identifier = InsectIdentifier()
17
+ src_filenames = sum([khandy.list_files_in_dir(src_dir, True)
18
+ for src_dir in src_dirs], [])
19
+ src_filenames = sorted(
20
+ src_filenames, key=lambda t: os.stat(t).st_mtime, reverse=True)
21
+
22
+ for k, filename in enumerate(src_filenames):
23
+ print('[{}/{}] {}'.format(k+1, len(src_filenames), filename))
24
+ start_time = time.time()
25
+ image = khandy.imread(filename)
26
+ if image is None:
27
+ continue
28
+ if max(image.shape[:2]) > 1280:
29
+ image = khandy.resize_image_long(image, 1280)
30
+ image_for_draw = image.copy()
31
+ image_height, image_width = image.shape[:2]
32
+
33
+ boxes, confs, classes = detector.detect(image)
34
+ for box, conf, class_ind in zip(boxes, confs, classes):
35
+ box = box.astype(np.int32)
36
+ box_width = box[2] - box[0] + 1
37
+ box_height = box[3] - box[1] + 1
38
+ if box_width < 30 or box_height < 30:
39
+ continue
40
+
41
+ cropped = khandy.crop_or_pad(image, box[0], box[1], box[2], box[3])
42
+ results = identifier.identify(cropped)
43
+ print(results[0])
44
+ prob = results[0]['probability']
45
+ if prob < 0.10:
46
+ text = 'Unknown'
47
+ else:
48
+ text = '{}: {:.3f}'.format(
49
+ results[0]['chinese_name'], results[0]['probability'])
50
+ position = [box[0] + 2, box[1] - 20]
51
+ position[0] = min(max(position[0], 0), image_width)
52
+ position[1] = min(max(position[1], 0), image_height)
53
+ cv2.rectangle(image_for_draw,
54
+ (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
55
+ image_for_draw = khandy.draw_text(image_for_draw, text, position,
56
+ font='simsun.ttc', font_size=15)
57
+
58
+ print('Elapsed: {:.3f}s'.format(time.time() - start_time))
59
+ cv2.imshow('image', image_for_draw)
60
+ key = cv2.waitKey(0)
61
+ if key == 27:
62
+ cv2.destroyAllWindows()
63
+ break
images/Coccinella_septempunctata.jpg ADDED

Git LFS Details

  • SHA256: 831ddaf998d6a116fcc7aadc273df9a1f5436ecb0f8fecfafd4946d7b8081f13
  • Pointer size: 132 Bytes
  • Size of remote file: 3.97 MB
insectid/base.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime
2
+ import numpy as np
3
+
4
+
5
+ class OnnxModel(object):
6
+ def __init__(self, model_path):
7
+ sess_options = onnxruntime.SessionOptions()
8
+ # # Set graph optimization level to ORT_ENABLE_EXTENDED to enable bert optimization.
9
+ # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
10
+ # # Use OpenMP optimizations. Only useful for CPU, has little impact for GPUs.
11
+ # sess_options.intra_op_num_threads = multiprocessing.cpu_count()
12
+ onnx_gpu = (onnxruntime.get_device() == 'GPU')
13
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if onnx_gpu else ['CPUExecutionProvider']
14
+ self.sess = onnxruntime.InferenceSession(model_path, sess_options, providers=providers)
15
+ self._input_names = [item.name for item in self.sess.get_inputs()]
16
+ self._output_names = [item.name for item in self.sess.get_outputs()]
17
+
18
+ @property
19
+ def input_names(self):
20
+ return self._input_names
21
+
22
+ @property
23
+ def output_names(self):
24
+ return self._output_names
25
+
26
+ def forward(self, inputs):
27
+ to_list_flag = False
28
+ if not isinstance(inputs, (tuple, list)):
29
+ inputs = [inputs]
30
+ to_list_flag = True
31
+ input_feed = {name: input for name, input in zip(self.input_names, inputs)}
32
+ outputs = self.sess.run(self.output_names, input_feed)
33
+ if (len(self.output_names) == 1) and to_list_flag:
34
+ return outputs[0]
35
+ else:
36
+ return outputs
37
+
38
+
39
+ def check_image_dtype_and_shape(image):
40
+ if not isinstance(image, np.ndarray):
41
+ raise Exception(f'image is not np.ndarray!')
42
+
43
+ if isinstance(image.dtype, (np.uint8, np.uint16)):
44
+ raise Exception(f'Unsupported image dtype, only support uint8 and uint16, got {image.dtype}!')
45
+ if image.ndim not in {2, 3}:
46
+ raise Exception(f'Unsupported image dimension number, only support 2 and 3, got {image.ndim}!')
47
+ if image.ndim == 3:
48
+ num_channels = image.shape[-1]
49
+ if num_channels not in {1, 3, 4}:
50
+ raise Exception(f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!')
51
+
insectid/detector.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import khandy
4
+ import numpy as np
5
+
6
+ from .base import OnnxModel
7
+ from .base import check_image_dtype_and_shape
8
+
9
+
10
+ class InsectDetector(OnnxModel):
11
+ def __init__(self):
12
+ current_dir = os.path.dirname(os.path.abspath(__file__))
13
+ model_path = os.path.join(current_dir, 'models/quarrying_insect_detector.onnx')
14
+ self.input_width = 640
15
+ self.input_height = 640
16
+ super(InsectDetector, self).__init__(model_path)
17
+
18
+ def _preprocess(self, image):
19
+ check_image_dtype_and_shape(image)
20
+
21
+ # image size normalization
22
+ image, scale, pad_left, pad_top = khandy.letterbox_image(
23
+ image, self.input_width, self.input_height, 0, return_scale=True)
24
+ # image channel normalization
25
+ image = khandy.normalize_image_channel(image, swap_rb=True)
26
+ # image dtype normalization
27
+ image = khandy.rescale_image(image, 'auto', np.float32)
28
+ # to tensor
29
+ image = np.transpose(image, (2,0,1))
30
+ image = np.expand_dims(image, axis=0)
31
+ return image, scale, pad_left, pad_top
32
+
33
+ def _post_process(self, outputs_list, scale, pad_left, pad_top, conf_thresh, iou_thresh):
34
+ pred = outputs_list[0][0]
35
+ pass_t = pred[:, 4] > conf_thresh
36
+ pred = pred[pass_t]
37
+
38
+ boxes = khandy.convert_boxes_format(pred[:, :4], 'cxcywh', 'xyxy')
39
+ boxes = khandy.unletterbox_2d_points(boxes, scale, pad_left, pad_top, False)
40
+ confs = np.max(pred[:, 5:] * pred[:, 4:5], axis=-1)
41
+ classes = np.argmax(pred[:, 5:] * pred[:, 4:5], axis=-1)
42
+ keep = khandy.non_max_suppression(boxes, confs, iou_thresh)
43
+ return boxes[keep], confs[keep], classes[keep]
44
+
45
+ def detect(self, image, conf_thresh=0.5, iou_thresh=0.5):
46
+ image, scale, pad_left, pad_top = self._preprocess(image)
47
+ outputs_list = self.forward(image)
48
+ boxes, confs, classes = self._post_process(
49
+ outputs_list,
50
+ scale=scale,
51
+ pad_left=pad_left,
52
+ pad_top=pad_top,
53
+ conf_thresh=conf_thresh,
54
+ iou_thresh=iou_thresh)
55
+ return boxes, confs, classes
56
+
insectid/identifier.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ from collections import OrderedDict
4
+
5
+ import khandy
6
+ import numpy as np
7
+
8
+ from .base import OnnxModel
9
+ from .base import check_image_dtype_and_shape
10
+
11
+
12
+ class InsectIdentifier(OnnxModel):
13
+ def __init__(self):
14
+ current_dir = os.path.dirname(os.path.abspath(__file__))
15
+ model_path = os.path.join(current_dir, 'models/quarrying_insect_identifier.onnx')
16
+ label_map_path = os.path.join(current_dir, 'models/quarrying_insectid_label_map.txt')
17
+ super(InsectIdentifier, self).__init__(model_path)
18
+
19
+ self.label_name_dict = self._get_label_name_dict(label_map_path)
20
+ self.names = [self.label_name_dict[i]['chinese_name'] for i in range(len(self.label_name_dict))]
21
+ self.num_classes = len(self.label_name_dict)
22
+
23
+ @staticmethod
24
+ def _get_label_name_dict(filename):
25
+ records = khandy.load_list(filename)
26
+ label_name_dict = {}
27
+ for record in records:
28
+ label, chinese_name, latin_name = record.split(',')
29
+ label_name_dict[int(label)] = OrderedDict([('chinese_name', chinese_name),
30
+ ('latin_name', latin_name)])
31
+ return label_name_dict
32
+
33
+ @staticmethod
34
+ def _preprocess(image):
35
+ check_image_dtype_and_shape(image)
36
+
37
+ # image size normalization
38
+ image = khandy.letterbox_image(image, 224, 224)
39
+ # image channel normalization
40
+ image = khandy.normalize_image_channel(image, swap_rb=True)
41
+ # image dtype normalization
42
+ # image dtype and value range normalization
43
+ mean, stddev = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
44
+ image = khandy.normalize_image_value(image, mean, stddev, 'auto')
45
+ # to tensor
46
+ image = np.transpose(image, (2,0,1))
47
+ image = np.expand_dims(image, axis=0)
48
+ return image
49
+
50
+ def predict(self, image):
51
+ inputs = self._preprocess(image)
52
+ logits = self.forward(inputs)
53
+ probs = khandy.softmax(logits)
54
+ return probs
55
+
56
+ def identify(self, image, topk=5):
57
+ assert isinstance(topk, int)
58
+ if topk <= 0 or topk > self.num_classes:
59
+ topk = self.num_classes
60
+
61
+ probs = self.predict(image)
62
+ topk_probs, topk_indices = khandy.top_k(probs, topk)
63
+
64
+ results = []
65
+ for ind, prob in zip(topk_indices[0], topk_probs[0]):
66
+ one_result = copy.deepcopy(self.label_name_dict[ind])
67
+ one_result['probability'] = prob
68
+ results.append(one_result)
69
+ return results
70
+
71
+
insectid/models/quarrying_insect_detector.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d1c0615c8dc604248d1b8c48c8414a7b7a84d653ce54ba612ff928ba8f38745
3
+ size 28315428
insectid/models/quarrying_insect_identifier.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e9c9c498633cbb7797560009e1826b34a6ed2aa5b95c9e7b1e184ad8cbb2355
3
+ size 22675272
insectid/models/quarrying_insectid_label_map.txt ADDED
The diff for this file is too large to render. See raw diff
 
khandy/boxes/boxes_and_indices.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def _concat(arr_list, axis=0):
5
+ """Avoids a copy if there is only a single element in a list.
6
+ """
7
+ if len(arr_list) == 1:
8
+ return arr_list[0]
9
+ return np.concatenate(arr_list, axis)
10
+
11
+
12
+ def convert_boxes_list_to_boxes_and_indices(boxes_list):
13
+ """
14
+ Args:
15
+ boxes_list (np.ndarray): list or tuple of ndarray with shape (N_i, 4+K)
16
+
17
+ Returns:
18
+ boxes (ndarray): shape (M, 4+K) where M is sum of N_i.
19
+ indices (ndarray): shape (M, 1) where M is sum of N_i.
20
+
21
+ References:
22
+ `mmdet.core.bbox.bbox2roi` in mmdetection
23
+ `convert_boxes_to_roi_format` in TorchVision
24
+ `modeling.poolers.convert_boxes_to_pooler_format` in detectron2
25
+ """
26
+ assert isinstance(boxes_list, (list, tuple))
27
+ boxes = _concat(boxes_list, axis=0)
28
+
29
+ indices_list = [np.full((len(b), 1), i, boxes.dtype)
30
+ for i, b in enumerate(boxes_list)]
31
+ indices = _concat(indices_list, axis=0)
32
+ return boxes, indices
33
+
34
+
35
+ def convert_boxes_and_indices_to_boxes_list(boxes, indices, num_indices):
36
+ """
37
+ Args:
38
+ boxes (np.ndarray): shape (N, 4+K)
39
+ indices (np.ndarray): shape (N,) or (N, 1), maybe batch index
40
+ in mini-batch or class label index.
41
+ num_indices (int): number of index.
42
+
43
+ Returns:
44
+ list (ndarray): boxes list of each index
45
+
46
+ References:
47
+ `mmdet.core.bbox2result` in mmdetection
48
+ `mmdet.core.bbox.roi2bbox` in mmdetection
49
+ `convert_boxes_to_roi_format` in TorchVision
50
+ `modeling.poolers.convert_boxes_to_pooler_format` in detectron2
51
+ """
52
+ boxes = np.asarray(boxes)
53
+ indices = np.asarray(indices)
54
+ assert boxes.ndim == 2, "boxes ndim must be 2, got {}".format(boxes.ndim)
55
+ assert (indices.ndim == 1) or (indices.ndim == 2 and indices.shape[-1] == 1), \
56
+ "indices ndim must be 1 or 2 if last dimension size is 1, got shape {}".format(indices.shape)
57
+ assert boxes.shape[0] == indices.shape[0], "the 1st dimension size of boxes and indices "\
58
+ "must be the same, got {} != {}".format(boxes.shape[0], indices.shape[0])
59
+
60
+ if boxes.shape[0] == 0:
61
+ return [np.zeros((0, boxes.shape[1]), dtype=np.float32)
62
+ for i in range(num_indices)]
63
+ else:
64
+ if indices.ndim == 2:
65
+ indices = np.squeeze(indices, axis=-1)
66
+ return [boxes[indices == i, :] for i in range(num_indices)]
67
+
68
+
khandy/boxes/boxes_clip.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def clip_boxes(boxes, reference_box, copy=True):
5
+ """Clip boxes to reference box.
6
+
7
+ References:
8
+ `clip_to_window` in TensorFlow object detection API.
9
+ """
10
+ if copy:
11
+ boxes = boxes.copy()
12
+ ref_x_min, ref_y_min, ref_x_max, ref_y_max = reference_box[:4]
13
+ lower = np.array([ref_x_min, ref_y_min, ref_x_min, ref_y_min])
14
+ upper = np.array([ref_x_max, ref_y_max, ref_x_max, ref_y_max])
15
+ np.clip(boxes[..., :4], lower, upper, boxes[..., :4])
16
+ return boxes
17
+
18
+
19
+ def clip_boxes_to_image(boxes, image_width, image_height, subpixel=True, copy=True):
20
+ """Clip boxes to image boundaries.
21
+
22
+ References:
23
+ `clip_boxes` in py-faster-rcnn
24
+ `core.boxes_op_list.clip_to_window` in TensorFlow object detection API.
25
+ `structures.Boxes.clip` in detectron2
26
+
27
+ Notes:
28
+ Equivalent to `clip_boxes(boxes, [0,0,image_width-1,image_height-1], copy)`
29
+ """
30
+ if not subpixel:
31
+ image_width -= 1
32
+ image_height -= 1
33
+ reference_box = [0, 0, image_width, image_height]
34
+ return clip_boxes(boxes, reference_box, copy)
khandy/boxes/boxes_coder.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class FasterRcnnBoxCoder:
5
+ """Faster RCNN box coder.
6
+
7
+ Notes:
8
+ boxes should be in cxcywh format.
9
+ """
10
+
11
+ def __init__(self, stddevs=None):
12
+ """Constructor for FasterRcnnBoxCoder.
13
+
14
+ Args:
15
+ stddevs: List of 4 positive scalars to scale ty, tx, th and tw.
16
+ If set to None, does not perform scaling. For Faster RCNN,
17
+ the open-source implementation recommends using [0.1, 0.1, 0.2, 0.2].
18
+ """
19
+ if stddevs:
20
+ assert len(stddevs) == 4
21
+ for scalar in stddevs:
22
+ assert scalar > 0
23
+ self.stddevs = stddevs
24
+
25
+ def encode(self, boxes, reference_boxes, copy=True):
26
+ """Encode boxes with respect to reference boxes.
27
+ """
28
+ if copy:
29
+ boxes = boxes.copy()
30
+
31
+ boxes[..., 2:4] += 1e-8
32
+ reference_boxes[..., 2:4] += 1e-8
33
+
34
+ boxes[..., 0:2] -= reference_boxes[..., 0:2]
35
+ boxes[..., 0:2] /= reference_boxes[..., 2:4]
36
+ boxes[..., 2:4] /= reference_boxes[..., 2:4]
37
+ boxes[..., 2:4] = np.log(boxes[..., 2:4], boxes[..., 2:4])
38
+ if self.stddevs:
39
+ boxes[..., 0:4] /= self.stddevs
40
+ return boxes
41
+
42
+ def decode(self, rel_boxes, reference_boxes, copy=True):
43
+ """Decode relative codes to boxes.
44
+ """
45
+ if copy:
46
+ rel_boxes = rel_boxes.copy()
47
+
48
+ if self.stddevs:
49
+ rel_boxes[..., 0:4] *= self.stddevs
50
+
51
+ rel_boxes[..., 0:2] *= reference_boxes[..., 2:4]
52
+ rel_boxes[..., 0:2] += reference_boxes[..., 0:2]
53
+ rel_boxes[..., 2:4] = np.exp(rel_boxes[..., 2:4], rel_boxes[..., 2:4])
54
+ rel_boxes[..., 2:4] *= reference_boxes[..., 2:4]
55
+ return rel_boxes
56
+
57
+ def decode_points(self, rel_points, reference_boxes, copy=True):
58
+ """Decode relative codes to points.
59
+ """
60
+ if copy:
61
+ rel_points = rel_points.copy()
62
+ if self.stddevs:
63
+ rel_points[..., 0::2] *= self.stddevs[0]
64
+ rel_points[..., 1::2] *= self.stddevs[1]
65
+ rel_points[..., 0::2] *= reference_boxes[..., 2:3]
66
+ rel_points[..., 1::2] *= reference_boxes[..., 3:4]
67
+ rel_points[..., 0::2] += reference_boxes[..., 0:1]
68
+ rel_points[..., 1::2] += reference_boxes[..., 1:2]
69
+ return rel_points
khandy/boxes/boxes_convert.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def convert_xyxy_to_xywh(boxes, copy=True):
5
+ """Convert [x_min, y_min, x_max, y_max] format to [x_min, y_min, width, height] format.
6
+ """
7
+ if copy:
8
+ boxes = boxes.copy()
9
+ boxes[..., 2:4] -= boxes[..., 0:2]
10
+ return boxes
11
+
12
+
13
+ def convert_xywh_to_xyxy(boxes, copy=True):
14
+ """Convert [x_min, y_min, width, height] format to [x_min, y_min, x_max, y_max] format.
15
+ """
16
+ if copy:
17
+ boxes = boxes.copy()
18
+ boxes[..., 2:4] += boxes[..., 0:2]
19
+ return boxes
20
+
21
+
22
+ def convert_xywh_to_cxcywh(boxes, copy=True):
23
+ """Convert [x_min, y_min, width, height] format to [cx, cy, width, height] format.
24
+ """
25
+ if copy:
26
+ boxes = boxes.copy()
27
+ boxes[..., 0:2] += boxes[..., 2:4] * 0.5
28
+ return boxes
29
+
30
+
31
+ def convert_cxcywh_to_xywh(boxes, copy=True):
32
+ """Convert [cx, cy, width, height] format to [x_min, y_min, width, height] format.
33
+ """
34
+ if copy:
35
+ boxes = boxes.copy()
36
+ boxes[..., 0:2] -= boxes[..., 2:4] * 0.5
37
+ return boxes
38
+
39
+
40
+ def convert_xyxy_to_cxcywh(boxes, copy=True):
41
+ """Convert [x_min, y_min, x_max, y_max] format to [cx, cy, width, height] format.
42
+ """
43
+ if copy:
44
+ boxes = boxes.copy()
45
+ boxes[..., 2:4] -= boxes[..., 0:2]
46
+ boxes[..., 0:2] += boxes[..., 2:4] * 0.5
47
+ return boxes
48
+
49
+
50
+ def convert_cxcywh_to_xyxy(boxes, copy=True):
51
+ """Convert [cx, cy, width, height] format to [x_min, y_min, x_max, y_max] format.
52
+ """
53
+ if copy:
54
+ boxes = boxes.copy()
55
+ boxes[..., 0:2] -= boxes[..., 2:4] * 0.5
56
+ boxes[..., 2:4] += boxes[..., 0:2]
57
+ return boxes
58
+
59
+
60
+ def convert_boxes_format(boxes, in_fmt, out_fmt, copy=True):
61
+ """Converts boxes from given in_fmt to out_fmt.
62
+
63
+ Supported in_fmt and out_fmt are:
64
+ 'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.
65
+ 'xywh' : boxes are represented via corner, width and height, x1, y2 being top left, w, h being width and height.
66
+ 'cxcywh' : boxes are represented via centre, width and height, cx, cy being center of box, w, h
67
+ being width and height.
68
+
69
+ Args:
70
+ boxes: boxes which will be converted.
71
+ in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
72
+ out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh']
73
+
74
+ Returns:
75
+ boxes: Boxes into converted format.
76
+
77
+ References:
78
+ torchvision.ops.box_convert
79
+ """
80
+ allowed_fmts = ("xyxy", "xywh", "cxcywh")
81
+ if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
82
+ raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")
83
+ if copy:
84
+ boxes = boxes.copy()
85
+ if in_fmt == out_fmt:
86
+ return boxes
87
+
88
+ if (in_fmt, out_fmt) == ("xyxy", "xywh"):
89
+ boxes = convert_xyxy_to_xywh(boxes, copy=False)
90
+ elif (in_fmt, out_fmt) == ("xywh", "xyxy"):
91
+ boxes = convert_xywh_to_xyxy(boxes, copy=False)
92
+ elif (in_fmt, out_fmt) == ("xywh", "cxcywh"):
93
+ boxes = convert_xywh_to_cxcywh(boxes, copy=False)
94
+ elif (in_fmt, out_fmt) == ("cxcywh", "xywh"):
95
+ boxes = convert_cxcywh_to_xywh(boxes, copy=False)
96
+ elif (in_fmt, out_fmt) == ("xyxy", "cxcywh"):
97
+ boxes = convert_xyxy_to_cxcywh(boxes, copy=False)
98
+ elif (in_fmt, out_fmt) == ("cxcywh", "xyxy"):
99
+ boxes = convert_cxcywh_to_xyxy(boxes, copy=False)
100
+ return boxes
101
+
khandy/boxes/boxes_filter.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def filter_small_boxes(boxes, min_width, min_height):
5
+ """Filters all boxes with side smaller than min size.
6
+
7
+ Args:
8
+ boxes: a numpy array with shape [N, 4] holding N boxes.
9
+ min_width (float): minimum width
10
+ min_height (float): minimum height
11
+
12
+ Returns:
13
+ keep: indices of the boxes that have width larger than
14
+ min_width and height larger than min_height.
15
+
16
+ References:
17
+ `_filter_boxes` in py-faster-rcnn
18
+ `prune_small_boxes` in TensorFlow object detection API.
19
+ `structures.Boxes.nonempty` in detectron2
20
+ `ops.boxes.remove_small_boxes` in torchvision
21
+ """
22
+ widths = boxes[:, 2] - boxes[:, 0]
23
+ heights = boxes[:, 3] - boxes[:, 1]
24
+ # keep represents indices to keep,
25
+ # mask represents bool ndarray, so use mask here.
26
+ mask = (widths >= min_width)
27
+ mask &= (heights >= min_height)
28
+ return np.nonzero(mask)[0]
29
+
30
+
31
+ def filter_boxes_outside(boxes, reference_box):
32
+ """Filters bounding boxes that fall outside reference box.
33
+
34
+ References:
35
+ `prune_outside_window` in TensorFlow object detection API.
36
+ """
37
+ x_min, y_min, x_max, y_max = reference_box[:4]
38
+ mask = ((boxes[:, 0] >= x_min) & (boxes[:, 1] >= y_min) &
39
+ (boxes[:, 2] <= x_max) & (boxes[:, 3] <= y_max))
40
+ return np.nonzero(mask)[0]
41
+
42
+
43
+ def filter_boxes_completely_outside(boxes, reference_box):
44
+ """Filters bounding boxes that fall completely outside of reference box.
45
+
46
+ References:
47
+ `prune_completely_outside_window` in TensorFlow object detection API.
48
+ """
49
+ x_min, y_min, x_max, y_max = reference_box[:4]
50
+ mask = ((boxes[:, 0] < x_max) & (boxes[:, 1] < y_max) &
51
+ (boxes[:, 2] > x_min) & (boxes[:, 3] > y_min))
52
+ return np.nonzero(mask)[0]
53
+
54
+
55
+ def non_max_suppression(boxes, scores, thresh, classes=None, ratio_type="iou"):
56
+ """Greedily select boxes with high confidence
57
+ Args:
58
+ boxes: [[x_min, y_min, x_max, y_max], ...]
59
+ scores: object confidence
60
+ thresh: retain overlap_ratio <= thresh
61
+ classes: class labels
62
+
63
+ Returns:
64
+ indices to keep
65
+
66
+ References:
67
+ `py_cpu_nms` in py-faster-rcnn
68
+ torchvision.ops.nms
69
+ torchvision.ops.batched_nms
70
+ """
71
+
72
+ if boxes.size == 0:
73
+ return np.empty((0,), dtype=np.int64)
74
+ if classes is not None:
75
+ # strategy: in order to perform NMS independently per class,
76
+ # we add an offset to all the boxes. The offset is dependent
77
+ # only on the class idx, and is large enough so that boxes
78
+ # from different classes do not overlap
79
+ max_coordinate = np.max(boxes)
80
+ offsets = classes * (max_coordinate + 1)
81
+ boxes = boxes + offsets[:, None]
82
+
83
+ x_mins = boxes[:, 0]
84
+ y_mins = boxes[:, 1]
85
+ x_maxs = boxes[:, 2]
86
+ y_maxs = boxes[:, 3]
87
+ areas = (x_maxs - x_mins) * (y_maxs - y_mins)
88
+ order = scores.flatten().argsort()[::-1]
89
+
90
+ keep = []
91
+ while order.size > 0:
92
+ i = order[0]
93
+ keep.append(i)
94
+
95
+ max_x_mins = np.maximum(x_mins[i], x_mins[order[1:]])
96
+ max_y_mins = np.maximum(y_mins[i], y_mins[order[1:]])
97
+ min_x_maxs = np.minimum(x_maxs[i], x_maxs[order[1:]])
98
+ min_y_maxs = np.minimum(y_maxs[i], y_maxs[order[1:]])
99
+ widths = np.maximum(0, min_x_maxs - max_x_mins)
100
+ heights = np.maximum(0, min_y_maxs - max_y_mins)
101
+ intersect_areas = widths * heights
102
+
103
+ if ratio_type in ["union", 'iou']:
104
+ ratio = intersect_areas / (areas[i] + areas[order[1:]] - intersect_areas)
105
+ elif ratio_type == "min":
106
+ ratio = intersect_areas / np.minimum(areas[i], areas[order[1:]])
107
+ else:
108
+ raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))
109
+
110
+ inds = np.nonzero(ratio <= thresh)[0]
111
+ order = order[inds + 1]
112
+ return np.asarray(keep)
113
+
khandy/boxes/boxes_overlap.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def paired_intersection(boxes1, boxes2):
5
+ """Compute paired intersection areas between boxes.
6
+ Args:
7
+ boxes1: a numpy array with shape [N, 4] holding N boxes
8
+ boxes2: a numpy array with shape [N, 4] holding N boxes
9
+
10
+ Returns:
11
+ a numpy array with shape [N,] representing itemwise intersection area
12
+
13
+ References:
14
+ `core.box_list_ops.matched_intersection` in Tensorflow object detection API
15
+
16
+ Notes:
17
+ can called as itemwise_intersection, matched_intersection, aligned_intersection
18
+ """
19
+ max_x_mins = np.maximum(boxes1[:, 0], boxes2[:, 0])
20
+ max_y_mins = np.maximum(boxes1[:, 1], boxes2[:, 1])
21
+ min_x_maxs = np.minimum(boxes1[:, 2], boxes2[:, 2])
22
+ min_y_maxs = np.minimum(boxes1[:, 3], boxes2[:, 3])
23
+ intersect_widths = np.maximum(0, min_x_maxs - max_x_mins)
24
+ intersect_heights = np.maximum(0, min_y_maxs - max_y_mins)
25
+ return intersect_widths * intersect_heights
26
+
27
+
28
+ def pairwise_intersection(boxes1, boxes2):
29
+ """Compute pairwise intersection areas between boxes.
30
+
31
+ Args:
32
+ boxes1: a numpy array with shape [N, 4] holding N boxes.
33
+ boxes2: a numpy array with shape [M, 4] holding M boxes.
34
+
35
+ Returns:
36
+ a numpy array with shape [N, M] representing pairwise intersection area.
37
+
38
+ References:
39
+ `core.box_list_ops.intersection` in Tensorflow object detection API
40
+ `utils.box_list_ops.intersection` in Tensorflow object detection API
41
+ """
42
+ if boxes1.shape[0] * boxes2.shape[0] == 0:
43
+ return np.zeros((boxes1.shape[0], boxes2.shape[0]), dtype=boxes1.dtype)
44
+
45
+ swap = False
46
+ if boxes1.shape[0] > boxes2.shape[0]:
47
+ boxes1, boxes2 = boxes2, boxes1
48
+ swap = True
49
+ intersect_areas = np.empty((boxes1.shape[0], boxes2.shape[0]), dtype=boxes1.dtype)
50
+
51
+ for i in range(boxes1.shape[0]):
52
+ max_x_mins = np.maximum(boxes1[i, 0], boxes2[:, 0])
53
+ max_y_mins = np.maximum(boxes1[i, 1], boxes2[:, 1])
54
+ min_x_maxs = np.minimum(boxes1[i, 2], boxes2[:, 2])
55
+ min_y_maxs = np.minimum(boxes1[i, 3], boxes2[:, 3])
56
+ intersect_widths = np.maximum(0, min_x_maxs - max_x_mins)
57
+ intersect_heights = np.maximum(0, min_y_maxs - max_y_mins)
58
+ intersect_areas[i, :] = intersect_widths * intersect_heights
59
+ if swap:
60
+ intersect_areas = intersect_areas.T
61
+ return intersect_areas
62
+
63
+
64
+ def paired_overlap_ratio(boxes1, boxes2, ratio_type='iou'):
65
+ """Compute paired overlap ratio between boxes.
66
+
67
+ Args:
68
+ boxes1: a numpy array with shape [N, 4] holding N boxes
69
+ boxes2: a numpy array with shape [N, 4] holding N boxes
70
+ ratio_type:
71
+ iou: Intersection-over-union (iou).
72
+ ioa: Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
73
+ their intersection area over box2's area. Note that ioa is not symmetric,
74
+ that is, IOA(box1, box2) != IOA(box2, box1).
75
+ min: Compute the ratio as the area of intersection between box1 and box2,
76
+ divided by the minimum area of the two bounding boxes.
77
+
78
+ Returns:
79
+ a numpy array with shape [N,] representing itemwise overlap ratio.
80
+
81
+ References:
82
+ `core.box_list_ops.matched_iou` in Tensorflow object detection API
83
+ `structures.boxes.matched_boxlist_iou` in detectron2
84
+ `mmdet.core.bbox.bbox_overlaps`, see https://mmdetection.readthedocs.io/en/v2.17.0/api.html#mmdet.core.bbox.bbox_overlaps
85
+ """
86
+ intersect_areas = paired_intersection(boxes1, boxes2)
87
+ areas1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
88
+ areas2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
89
+
90
+ if ratio_type in ['union', 'iou', 'giou']:
91
+ union_areas = areas1 - intersect_areas
92
+ union_areas += areas2
93
+ intersect_areas /= union_areas
94
+ elif ratio_type == 'min':
95
+ min_areas = np.minimum(areas1, areas2)
96
+ intersect_areas /= min_areas
97
+ elif ratio_type == 'ioa':
98
+ intersect_areas /= areas2
99
+ else:
100
+ raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))
101
+
102
+ if ratio_type == 'giou':
103
+ min_xy_mins = np.minimum(boxes1[:, 0:2], boxes2[:, 0:2])
104
+ max_xy_mins = np.maximum(boxes1[:, 2:4], boxes2[:, 2:4])
105
+ # mebb = minimum enclosing bounding boxes
106
+ mebb_whs = np.maximum(0, max_xy_mins - min_xy_mins)
107
+ mebb_areas = mebb_whs[:, 0] * mebb_whs[:, 1]
108
+ union_areas -= mebb_areas
109
+ union_areas /= mebb_areas
110
+ intersect_areas += union_areas
111
+ return intersect_areas
112
+
113
+
114
+ def pairwise_overlap_ratio(boxes1, boxes2, ratio_type='iou'):
115
+ """Compute pairwise overlap ratio between boxes.
116
+
117
+ Args:
118
+ boxes1: a numpy array with shape [N, 4] holding N boxes
119
+ boxes2: a numpy array with shape [M, 4] holding M boxes
120
+ ratio_type:
121
+ iou: Intersection-over-union (iou).
122
+ ioa: Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
123
+ their intersection area over box2's area. Note that ioa is not symmetric,
124
+ that is, IOA(box1, box2) != IOA(box2, box1).
125
+ min: Compute the ratio as the area of intersection between box1 and box2,
126
+ divided by the minimum area of the two bounding boxes.
127
+
128
+ Returns:
129
+ a numpy array with shape [N, M] representing pairwise overlap ratio.
130
+
131
+ References:
132
+ `utils.np_box_ops.iou` in Tensorflow object detection API
133
+ `utils.np_box_ops.ioa` in Tensorflow object detection API
134
+ `utils.np_box_ops.giou` in Tensorflow object detection API
135
+ `mmdet.core.bbox.bbox_overlaps`, see https://mmdetection.readthedocs.io/en/v2.17.0/api.html#mmdet.core.bbox.bbox_overlaps
136
+ `torchvision.ops.box_iou`, see https://pytorch.org/vision/stable/ops.html#torchvision.ops.box_iou
137
+ `torchvision.ops.generalized_box_iou`, see https://pytorch.org/vision/stable/ops.html#torchvision.ops.generalized_box_iou
138
+ http://ww2.mathworks.cn/help/vision/ref/bboxoverlapratio.html
139
+ """
140
+ intersect_areas = pairwise_intersection(boxes1, boxes2)
141
+ areas1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
142
+ areas2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
143
+
144
+ if ratio_type in ['union', 'iou', 'giou']:
145
+ union_areas = np.expand_dims(areas1, axis=1) - intersect_areas
146
+ union_areas += np.expand_dims(areas2, axis=0)
147
+ intersect_areas /= union_areas
148
+ elif ratio_type == 'min':
149
+ min_areas = np.minimum(np.expand_dims(areas1, axis=1), np.expand_dims(areas2, axis=0))
150
+ intersect_areas /= min_areas
151
+ elif ratio_type == 'ioa':
152
+ intersect_areas /= np.expand_dims(areas2, axis=0)
153
+ else:
154
+ raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))
155
+
156
+ if ratio_type == 'giou':
157
+ min_xy_mins = np.minimum(boxes1[:, None, 0:2], boxes2[:, 0:2])
158
+ max_xy_mins = np.maximum(boxes1[:, None, 2:4], boxes2[:, 2:4])
159
+ # mebb = minimum enclosing bounding boxes
160
+ mebb_whs = np.maximum(0, max_xy_mins - min_xy_mins)
161
+ mebb_areas = mebb_whs[:, :, 0] * mebb_whs[:, :, 1]
162
+ union_areas -= mebb_areas
163
+ union_areas /= mebb_areas
164
+ intersect_areas += union_areas
165
+ return intersect_areas
166
+
khandy/boxes/boxes_transform_flip.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from .boxes_utils import assert_and_normalize_shape
3
+
4
+
5
+ def flip_boxes(boxes, x_center=0, y_center=0, direction='h'):
6
+ """
7
+ Args:
8
+ boxes: (N, 4+K)
9
+ x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
10
+ y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
11
+ direction: str
12
+ """
13
+ assert direction in ['x', 'h', 'horizontal',
14
+ 'y', 'v', 'vertical',
15
+ 'o', 'b', 'both']
16
+ boxes = np.asarray(boxes, np.float32)
17
+ ret_boxes = boxes.copy()
18
+
19
+ x_center = np.asarray(x_center, np.float32)
20
+ y_center = np.asarray(y_center, np.float32)
21
+ x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
22
+ y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
23
+
24
+ if direction in ['o', 'b', 'both', 'x', 'h', 'horizontal']:
25
+ ret_boxes[:, 0] = 2 * x_center - boxes[:, 2]
26
+ ret_boxes[:, 2] = 2 * x_center - boxes[:, 0]
27
+ if direction in ['o', 'b', 'both', 'y', 'v', 'vertical']:
28
+ ret_boxes[:, 1] = 2 * y_center - boxes[:, 3]
29
+ ret_boxes[:, 3] = 2 * y_center - boxes[:, 1]
30
+ return ret_boxes
31
+
32
+
33
+ def fliplr_boxes(boxes, x_center=0, y_center=0):
34
+ """
35
+ Args:
36
+ boxes: (N, 4+K)
37
+ x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
38
+ y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
39
+ """
40
+ boxes = np.asarray(boxes, np.float32)
41
+ ret_boxes = boxes.copy()
42
+
43
+ x_center = np.asarray(x_center, np.float32)
44
+ y_center = np.asarray(y_center, np.float32)
45
+ x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
46
+ y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
47
+
48
+ ret_boxes[:, 0] = 2 * x_center - boxes[:, 2]
49
+ ret_boxes[:, 2] = 2 * x_center - boxes[:, 0]
50
+ return ret_boxes
51
+
52
+
53
+ def flipud_boxes(boxes, x_center=0, y_center=0):
54
+ """
55
+ Args:
56
+ boxes: (N, 4+K)
57
+ x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
58
+ y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
59
+ """
60
+ boxes = np.asarray(boxes, np.float32)
61
+ ret_boxes = boxes.copy()
62
+
63
+ x_center = np.asarray(x_center, np.float32)
64
+ y_center = np.asarray(y_center, np.float32)
65
+ x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
66
+ y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
67
+
68
+ ret_boxes[:, 1] = 2 * y_center - boxes[:, 3]
69
+ ret_boxes[:, 3] = 2 * y_center - boxes[:, 1]
70
+ return ret_boxes
71
+
72
+
73
+ def transpose_boxes(boxes, x_center=0, y_center=0):
74
+ """
75
+ Args:
76
+ boxes: (N, 4+K)
77
+ x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
78
+ y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
79
+ """
80
+ boxes = np.asarray(boxes, np.float32)
81
+ ret_boxes = boxes.copy()
82
+
83
+ x_center = np.asarray(x_center, np.float32)
84
+ y_center = np.asarray(y_center, np.float32)
85
+ x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
86
+ y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
87
+
88
+ shift = x_center - y_center
89
+ ret_boxes[:, 0] = boxes[:, 1] + shift
90
+ ret_boxes[:, 1] = boxes[:, 0] - shift
91
+ ret_boxes[:, 2] = boxes[:, 3] + shift
92
+ ret_boxes[:, 3] = boxes[:, 2] - shift
93
+ return ret_boxes
94
+
95
+
96
+ def flip_boxes_in_image(boxes, image_width, image_height, direction='h'):
97
+ """
98
+ Args:
99
+ boxes: (N, 4+K)
100
+ image_width: int
101
+ image_width: int
102
+ direction: str
103
+
104
+ References:
105
+ `core.bbox.bbox_flip` in mmdetection
106
+ `datasets.pipelines.RandomFlip.bbox_flip` in mmdetection
107
+ """
108
+ x_center = (image_width - 1) * 0.5
109
+ y_center = (image_height - 1) * 0.5
110
+ ret_boxes = flip_boxes(boxes, x_center, y_center, direction=direction)
111
+ return ret_boxes
112
+
113
+
114
+ def rot90_boxes_in_image(boxes, image_width, image_height, n=1):
115
+ """Rotate boxes counter-clockwise by 90 degrees.
116
+
117
+ References:
118
+ np.rot90
119
+ cv2.rotate
120
+ tf.image.rot90
121
+ """
122
+ n = n % 4
123
+ if n == 0:
124
+ ret_boxes = boxes.copy()
125
+ elif n == 1:
126
+ ret_boxes = transpose_boxes(boxes)
127
+ ret_boxes = flip_boxes_in_image(ret_boxes, image_width, image_height, 'v')
128
+ elif n == 2:
129
+ ret_boxes = flip_boxes_in_image(boxes, image_width, image_height, 'o')
130
+ else:
131
+ ret_boxes = transpose_boxes(boxes)
132
+ ret_boxes = flip_boxes_in_image(ret_boxes, image_width, image_height, 'h');
133
+ return ret_boxes
134
+
135
+
khandy/boxes/boxes_transform_rotate.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from .boxes_utils import assert_and_normalize_shape
3
+
4
+
5
+ def rotate_boxes(boxes, angle, x_center=0, y_center=0, scale=1,
6
+ degrees=True, return_rotated_boxes=False):
7
+ """
8
+ Args:
9
+ boxes: (N, 4+K)
10
+ angle: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
11
+ x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
12
+ y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
13
+ scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
14
+ scale factor in x and y dimension
15
+ degrees: bool
16
+ return_rotated_boxes: bool
17
+ """
18
+ boxes = np.asarray(boxes, np.float32)
19
+
20
+ angle = np.asarray(angle, np.float32)
21
+ x_center = np.asarray(x_center, np.float32)
22
+ y_center = np.asarray(y_center, np.float32)
23
+ scale = np.asarray(scale, np.float32)
24
+
25
+ angle = assert_and_normalize_shape(angle, boxes.shape[0])
26
+ x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
27
+ y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
28
+ scale = assert_and_normalize_shape(scale, boxes.shape[0])
29
+
30
+ if degrees:
31
+ angle = np.deg2rad(angle)
32
+ cos_val = scale * np.cos(angle)
33
+ sin_val = scale * np.sin(angle)
34
+ x_shift = x_center - x_center * cos_val + y_center * sin_val
35
+ y_shift = y_center - x_center * sin_val - y_center * cos_val
36
+
37
+ x_mins, y_mins = boxes[:,0], boxes[:,1]
38
+ x_maxs, y_maxs = boxes[:,2], boxes[:,3]
39
+ x00 = x_mins * cos_val - y_mins * sin_val + x_shift
40
+ x10 = x_maxs * cos_val - y_mins * sin_val + x_shift
41
+ x11 = x_maxs * cos_val - y_maxs * sin_val + x_shift
42
+ x01 = x_mins * cos_val - y_maxs * sin_val + x_shift
43
+
44
+ y00 = x_mins * sin_val + y_mins * cos_val + y_shift
45
+ y10 = x_maxs * sin_val + y_mins * cos_val + y_shift
46
+ y11 = x_maxs * sin_val + y_maxs * cos_val + y_shift
47
+ y01 = x_mins * sin_val + y_maxs * cos_val + y_shift
48
+
49
+ rotated_boxes = np.stack([x00, y00, x10, y10, x11, y11, x01, y01], axis=-1)
50
+ ret_x_mins = np.min(rotated_boxes[:,0::2], axis=1)
51
+ ret_y_mins = np.min(rotated_boxes[:,1::2], axis=1)
52
+ ret_x_maxs = np.max(rotated_boxes[:,0::2], axis=1)
53
+ ret_y_maxs = np.max(rotated_boxes[:,1::2], axis=1)
54
+
55
+ if boxes.ndim == 4:
56
+ ret_boxes = np.stack([ret_x_mins, ret_y_mins, ret_x_maxs, ret_y_maxs], axis=-1)
57
+ else:
58
+ ret_boxes = boxes.copy()
59
+ ret_boxes[:, :4] = np.stack([ret_x_mins, ret_y_mins, ret_x_maxs, ret_y_maxs], axis=-1)
60
+
61
+ if not return_rotated_boxes:
62
+ return ret_boxes
63
+ else:
64
+ return ret_boxes, rotated_boxes
65
+
66
+
67
+ def rotate_boxes_wrt_centers(boxes, angle, scale=1, degrees=True,
68
+ return_rotated_boxes=False):
69
+ """
70
+ Args:
71
+ boxes: (N, 4+K)
72
+ angle: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
73
+ scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
74
+ scale factor in x and y dimension
75
+ degrees: bool
76
+ return_rotated_boxes: bool
77
+ """
78
+ boxes = np.asarray(boxes, np.float32)
79
+
80
+ angle = np.asarray(angle, np.float32)
81
+ scale = np.asarray(scale, np.float32)
82
+ angle = assert_and_normalize_shape(angle, boxes.shape[0])
83
+ scale = assert_and_normalize_shape(scale, boxes.shape[0])
84
+
85
+ if degrees:
86
+ angle = np.deg2rad(angle)
87
+ cos_val = scale * np.cos(angle)
88
+ sin_val = scale * np.sin(angle)
89
+
90
+ x_centers = boxes[:, 2] + boxes[:, 0]
91
+ y_centers = boxes[:, 3] + boxes[:, 1]
92
+ x_centers *= 0.5
93
+ y_centers *= 0.5
94
+
95
+ half_widths = boxes[:, 2] - boxes[:, 0]
96
+ half_heights = boxes[:, 3] - boxes[:, 1]
97
+ half_widths *= 0.5
98
+ half_heights *= 0.5
99
+
100
+ half_widths_cos = half_widths * cos_val
101
+ half_widths_sin = half_widths * sin_val
102
+ half_heights_cos = half_heights * cos_val
103
+ half_heights_sin = half_heights * sin_val
104
+
105
+ x00 = -half_widths_cos + half_heights_sin
106
+ x10 = half_widths_cos + half_heights_sin
107
+ x11 = half_widths_cos - half_heights_sin
108
+ x01 = -half_widths_cos - half_heights_sin
109
+ x00 += x_centers
110
+ x10 += x_centers
111
+ x11 += x_centers
112
+ x01 += x_centers
113
+
114
+ y00 = -half_widths_sin - half_heights_cos
115
+ y10 = half_widths_sin - half_heights_cos
116
+ y11 = half_widths_sin + half_heights_cos
117
+ y01 = -half_widths_sin + half_heights_cos
118
+ y00 += y_centers
119
+ y10 += y_centers
120
+ y11 += y_centers
121
+ y01 += y_centers
122
+
123
+ rotated_boxes = np.stack([x00, y00, x10, y10, x11, y11, x01, y01], axis=-1)
124
+ ret_x_mins = np.min(rotated_boxes[:,0::2], axis=1)
125
+ ret_y_mins = np.min(rotated_boxes[:,1::2], axis=1)
126
+ ret_x_maxs = np.max(rotated_boxes[:,0::2], axis=1)
127
+ ret_y_maxs = np.max(rotated_boxes[:,1::2], axis=1)
128
+
129
+ if boxes.ndim == 4:
130
+ ret_boxes = np.stack([ret_x_mins, ret_y_mins, ret_x_maxs, ret_y_maxs], axis=-1)
131
+ else:
132
+ ret_boxes = boxes.copy()
133
+ ret_boxes[:, :4] = np.stack([ret_x_mins, ret_y_mins, ret_x_maxs, ret_y_maxs], axis=-1)
134
+
135
+ if not return_rotated_boxes:
136
+ return ret_boxes
137
+ else:
138
+ return ret_boxes, rotated_boxes
139
+
140
+
khandy/boxes/boxes_transform_scale.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from .boxes_utils import assert_and_normalize_shape
3
+
4
+
5
+ def scale_boxes(boxes, x_scale=1, y_scale=1, x_center=0, y_center=0, copy=True):
6
+ """Scale boxes coordinates in x and y dimensions.
7
+
8
+ Args:
9
+ boxes: (N, 4+K)
10
+ x_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
11
+ scale factor in x dimension
12
+ y_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
13
+ scale factor in y dimension
14
+ x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
15
+ y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
16
+
17
+ References:
18
+ `core.box_list_ops.scale` in TensorFlow object detection API
19
+ `utils.box_list_ops.scale` in TensorFlow object detection API
20
+ `datasets.pipelines.Resize._resize_bboxes` in mmdetection
21
+ `core.anchor.guided_anchor_target.calc_region` in mmdetection where comments may be misleading!
22
+ `layers.mask_ops.scale_boxes` in detectron2
23
+ `mmcv.bbox_scaling`
24
+ """
25
+ boxes = np.array(boxes, dtype=np.float32, copy=copy)
26
+
27
+ x_scale = np.asarray(x_scale, np.float32)
28
+ y_scale = np.asarray(y_scale, np.float32)
29
+ x_scale = assert_and_normalize_shape(x_scale, boxes.shape[0])
30
+ y_scale = assert_and_normalize_shape(y_scale, boxes.shape[0])
31
+
32
+ x_center = np.asarray(x_center, np.float32)
33
+ y_center = np.asarray(y_center, np.float32)
34
+ x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
35
+ y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
36
+
37
+ x_shift = 1 - x_scale
38
+ y_shift = 1 - y_scale
39
+ x_shift *= x_center
40
+ y_shift *= y_center
41
+
42
+ boxes[:, 0] *= x_scale
43
+ boxes[:, 1] *= y_scale
44
+ boxes[:, 2] *= x_scale
45
+ boxes[:, 3] *= y_scale
46
+ boxes[:, 0] += x_shift
47
+ boxes[:, 1] += y_shift
48
+ boxes[:, 2] += x_shift
49
+ boxes[:, 3] += y_shift
50
+ return boxes
51
+
52
+
53
+ def scale_boxes_wrt_centers(boxes, x_scale=1, y_scale=1, copy=True):
54
+ """
55
+ Args:
56
+ boxes: (N, 4+K)
57
+ x_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
58
+ scale factor in x dimension
59
+ y_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
60
+ scale factor in y dimension
61
+
62
+ References:
63
+ `core.anchor.guided_anchor_target.calc_region` in mmdetection where comments may be misleading!
64
+ `layers.mask_ops.scale_boxes` in detectron2
65
+ `mmcv.bbox_scaling`
66
+ """
67
+ boxes = np.array(boxes, dtype=np.float32, copy=copy)
68
+
69
+ x_scale = np.asarray(x_scale, np.float32)
70
+ y_scale = np.asarray(y_scale, np.float32)
71
+ x_scale = assert_and_normalize_shape(x_scale, boxes.shape[0])
72
+ y_scale = assert_and_normalize_shape(y_scale, boxes.shape[0])
73
+
74
+ x_factor = (x_scale - 1) * 0.5
75
+ y_factor = (y_scale - 1) * 0.5
76
+ x_deltas = boxes[:, 2] - boxes[:, 0]
77
+ y_deltas = boxes[:, 3] - boxes[:, 1]
78
+ x_deltas *= x_factor
79
+ y_deltas *= y_factor
80
+
81
+ boxes[:, 0] -= x_deltas
82
+ boxes[:, 1] -= y_deltas
83
+ boxes[:, 2] += x_deltas
84
+ boxes[:, 3] += y_deltas
85
+ return boxes
86
+
khandy/boxes/boxes_transform_translate.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from .boxes_utils import assert_and_normalize_shape
3
+
4
+
5
+ def translate_boxes(boxes, x_shift=0, y_shift=0, copy=True):
6
+ """translate boxes coordinates in x and y dimensions.
7
+
8
+ Args:
9
+ boxes: (N, 4+K)
10
+ x_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
11
+ shift in x dimension
12
+ y_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
13
+ shift in y dimension
14
+ copy: bool
15
+
16
+ References:
17
+ `datasets.pipelines.RandomCrop` in mmdetection
18
+ """
19
+ boxes = np.array(boxes, dtype=np.float32, copy=copy)
20
+
21
+ x_shift = np.asarray(x_shift, np.float32)
22
+ y_shift = np.asarray(y_shift, np.float32)
23
+
24
+ x_shift = assert_and_normalize_shape(x_shift, boxes.shape[0])
25
+ y_shift = assert_and_normalize_shape(y_shift, boxes.shape[0])
26
+
27
+ boxes[:, 0] += x_shift
28
+ boxes[:, 1] += y_shift
29
+ boxes[:, 2] += x_shift
30
+ boxes[:, 3] += y_shift
31
+ return boxes
32
+
33
+
34
+ def adjust_boxes(boxes, x_min_shift, y_min_shift, x_max_shift, y_max_shift, copy=True):
35
+ """
36
+ Args:
37
+ boxes: (N, 4+K)
38
+ x_min_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
39
+ shift (x_min, y_min) in x dimension
40
+ y_min_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
41
+ shift (x_min, y_min) in y dimension
42
+ x_max_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
43
+ shift (x_max, y_max) in x dimension
44
+ y_max_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
45
+ shift (x_max, y_max) in y dimension
46
+ copy: bool
47
+ """
48
+ boxes = np.array(boxes, dtype=np.float32, copy=copy)
49
+
50
+ x_min_shift = np.asarray(x_min_shift, np.float32)
51
+ y_min_shift = np.asarray(y_min_shift, np.float32)
52
+ x_max_shift = np.asarray(x_max_shift, np.float32)
53
+ y_max_shift = np.asarray(y_max_shift, np.float32)
54
+
55
+ x_min_shift = assert_and_normalize_shape(x_min_shift, boxes.shape[0])
56
+ y_min_shift = assert_and_normalize_shape(y_min_shift, boxes.shape[0])
57
+ x_max_shift = assert_and_normalize_shape(x_max_shift, boxes.shape[0])
58
+ y_max_shift = assert_and_normalize_shape(y_max_shift, boxes.shape[0])
59
+
60
+ boxes[:, 0] += x_min_shift
61
+ boxes[:, 1] += y_min_shift
62
+ boxes[:, 2] += x_max_shift
63
+ boxes[:, 3] += y_max_shift
64
+ return boxes
65
+
66
+
67
+ def inflate_or_deflate_boxes(boxes, width_delta=0, height_delta=0, copy=True):
68
+ """
69
+ Args:
70
+ boxes: (N, 4+K)
71
+ width_delta: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
72
+ height_delta: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
73
+ copy: bool
74
+ """
75
+ boxes = np.array(boxes, dtype=np.float32, copy=copy)
76
+
77
+ width_delta = np.asarray(width_delta, np.float32)
78
+ height_delta = np.asarray(height_delta, np.float32)
79
+
80
+ width_delta = assert_and_normalize_shape(width_delta, boxes.shape[0])
81
+ height_delta = assert_and_normalize_shape(height_delta, boxes.shape[0])
82
+
83
+ half_width_delta = width_delta * 0.5
84
+ half_height_delta = height_delta * 0.5
85
+ boxes[:, 0] -= half_width_delta
86
+ boxes[:, 1] -= half_height_delta
87
+ boxes[:, 2] += half_width_delta
88
+ boxes[:, 3] += half_height_delta
89
+ return boxes
90
+
91
+
92
+ def inflate_boxes_to_square(boxes, copy=True):
93
+ """Inflate boxes to square
94
+ Args:
95
+ boxes: (N, 4+K)
96
+ copy: bool
97
+ """
98
+ boxes = np.array(boxes, dtype=np.float32, copy=copy)
99
+
100
+ widths = boxes[:, 2] - boxes[:, 0]
101
+ heights = boxes[:, 3] - boxes[:, 1]
102
+ max_side_lengths = np.maximum(widths, heights)
103
+
104
+ width_deltas = np.subtract(max_side_lengths, widths, widths)
105
+ height_deltas = np.subtract(max_side_lengths, heights, heights)
106
+ width_deltas *= 0.5
107
+ height_deltas *= 0.5
108
+ boxes[:, 0] -= width_deltas
109
+ boxes[:, 1] -= height_deltas
110
+ boxes[:, 2] += width_deltas
111
+ boxes[:, 3] += height_deltas
112
+ return boxes
113
+
114
+
115
+ def deflate_boxes_to_square(boxes, copy=True):
116
+ """Deflate boxes to square
117
+ Args:
118
+ boxes: (N, 4+K)
119
+ copy: bool
120
+ """
121
+ boxes = np.array(boxes, dtype=np.float32, copy=copy)
122
+
123
+ widths = boxes[:, 2] - boxes[:, 0]
124
+ heights = boxes[:, 3] - boxes[:, 1]
125
+ min_side_lengths = np.minimum(widths, heights)
126
+
127
+ width_deltas = np.subtract(min_side_lengths, widths, widths)
128
+ height_deltas = np.subtract(min_side_lengths, heights, heights)
129
+ width_deltas *= 0.5
130
+ height_deltas *= 0.5
131
+ boxes[:, 0] -= width_deltas
132
+ boxes[:, 1] -= height_deltas
133
+ boxes[:, 2] += width_deltas
134
+ boxes[:, 3] += height_deltas
135
+ return boxes
136
+
khandy/boxes/boxes_utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def assert_and_normalize_shape(x, length):
5
+ """
6
+ Args:
7
+ x: ndarray
8
+ length: int
9
+ """
10
+ if x.ndim == 0:
11
+ return x
12
+ elif x.ndim == 1:
13
+ if len(x) == 1:
14
+ return x
15
+ elif len(x) == length:
16
+ return x
17
+ else:
18
+ raise ValueError('Incompatible shape!')
19
+ elif x.ndim == 2:
20
+ if x.shape == (1, 1):
21
+ return np.squeeze(x, axis=-1)
22
+ elif x.shape == (length, 1):
23
+ return np.squeeze(x, axis=-1)
24
+ else:
25
+ raise ValueError('Incompatible shape!')
26
+ else:
27
+ raise ValueError('Incompatible ndim!')
28
+
khandy/dict_utils.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from collections import OrderedDict
3
+
4
+
5
+ def get_dict_first_item(dict_obj):
6
+ for key in dict_obj:
7
+ return key, dict_obj[key]
8
+
9
+
10
+ def sort_dict(dict_obj, key=None, reverse=False):
11
+ return OrderedDict(sorted(dict_obj.items(), key=key, reverse=reverse))
12
+
13
+
14
+ def create_multidict(key_list, value_list):
15
+ assert len(key_list) == len(value_list)
16
+ multidict_obj = {}
17
+ for key, value in zip(key_list, value_list):
18
+ multidict_obj.setdefault(key, []).append(value)
19
+ return multidict_obj
20
+
21
+
22
+ def convert_multidict_to_list(multidict_obj):
23
+ key_list, value_list = [], []
24
+ for key, value in multidict_obj.items():
25
+ key_list += [key] * len(value)
26
+ value_list += value
27
+ return key_list, value_list
28
+
29
+
30
+ def convert_multidict_to_records(multidict_obj, key_map=None, raise_if_key_error=True):
31
+ records = []
32
+ if key_map is None:
33
+ for key in multidict_obj:
34
+ for value in multidict_obj[key]:
35
+ records.append('{},{}'.format(value, key))
36
+ else:
37
+ for key in multidict_obj:
38
+ if raise_if_key_error:
39
+ mapped_key = key_map[key]
40
+ else:
41
+ mapped_key = key_map.get(key, key)
42
+ for value in multidict_obj[key]:
43
+ records.append('{},{}'.format(value, mapped_key))
44
+ return records
45
+
46
+
47
+ def sample_multidict(multidict_obj, num_keys, num_per_key=None):
48
+ num_keys = min(num_keys, len(multidict_obj))
49
+ sub_keys = random.sample(list(multidict_obj), num_keys)
50
+ if num_per_key is None:
51
+ sub_mdict = {key: multidict_obj[key] for key in sub_keys}
52
+ else:
53
+ sub_mdict = {}
54
+ for key in sub_keys:
55
+ num_examples_inner = min(num_per_key, len(multidict_obj[key]))
56
+ sub_mdict[key] = random.sample(multidict_obj[key], num_examples_inner)
57
+ return sub_mdict
58
+
59
+
60
+ def split_multidict_on_key(multidict_obj, split_ratio, use_shuffle=False):
61
+ """Split multidict_obj on its key.
62
+ """
63
+ assert isinstance(multidict_obj, dict)
64
+ assert isinstance(split_ratio, (list, tuple))
65
+
66
+ pdf = [k / float(sum(split_ratio)) for k in split_ratio]
67
+ cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
68
+ indices = [int(round(len(multidict_obj) * k)) for k in cdf]
69
+ dict_keys = list(multidict_obj)
70
+ if use_shuffle:
71
+ random.shuffle(dict_keys)
72
+
73
+ be_split_list = []
74
+ for i in range(len(split_ratio)):
75
+ part_keys = dict_keys[indices[i]: indices[i + 1]]
76
+ part_dict = dict([(key, multidict_obj[key]) for key in part_keys])
77
+ be_split_list.append(part_dict)
78
+ return be_split_list
79
+
80
+
81
+ def split_multidict_on_value(multidict_obj, split_ratio, use_shuffle=False):
82
+ """Split multidict_obj on its value.
83
+ """
84
+ assert isinstance(multidict_obj, dict)
85
+ assert isinstance(split_ratio, (list, tuple))
86
+
87
+ pdf = [k / float(sum(split_ratio)) for k in split_ratio]
88
+ cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
89
+ be_split_list = [dict() for k in range(len(split_ratio))]
90
+ for key, value in multidict_obj.items():
91
+ indices = [int(round(len(value) * k)) for k in cdf]
92
+ cloned = value[:]
93
+ if use_shuffle:
94
+ random.shuffle(cloned)
95
+ for i in range(len(split_ratio)):
96
+ be_split_list[i][key] = cloned[indices[i]: indices[i + 1]]
97
+ return be_split_list
98
+
99
+
100
+ def get_multidict_info(multidict_obj, with_print=False, desc=None):
101
+ num_list = [len(val) for val in multidict_obj.values()]
102
+ num_keys = len(num_list)
103
+ num_values = sum(num_list)
104
+ max_values_per_key = max(num_list)
105
+ min_values_per_key = min(num_list)
106
+ if num_keys == 0:
107
+ avg_values_per_key = 0
108
+ else:
109
+ avg_values_per_key = num_values / num_keys
110
+ info = {
111
+ 'num_keys': num_keys,
112
+ 'num_values': num_values,
113
+ 'max_values_per_key': max_values_per_key,
114
+ 'min_values_per_key': min_values_per_key,
115
+ 'avg_values_per_key': avg_values_per_key,
116
+ }
117
+ if with_print:
118
+ desc = desc or '<unknown>'
119
+ print('{} key number: {}'.format(desc, info['num_keys']))
120
+ print('{} value number: {}'.format(desc, info['num_values']))
121
+ print('{} max number per-key: {}'.format(desc, info['max_values_per_key']))
122
+ print('{} min number per-key: {}'.format(desc, info['min_values_per_key']))
123
+ print('{} avg number per-key: {:.2f}'.format(desc, info['avg_values_per_key']))
124
+ return info
125
+
126
+
127
+ def filter_multidict_by_number(multidict_obj, lower, upper=None):
128
+ if upper is None:
129
+ return {key: value for key, value in multidict_obj.items()
130
+ if lower <= len(value) }
131
+ else:
132
+ assert lower <= upper, 'lower must not be greater than upper'
133
+ return {key: value for key, value in multidict_obj.items()
134
+ if lower <= len(value) <= upper }
135
+
136
+
137
+ def sort_multidict_by_number(multidict_obj, num_keys_to_keep=None, reverse=True):
138
+ """
139
+ Args:
140
+ reverse: sort in ascending order when is True.
141
+ """
142
+ if num_keys_to_keep is None:
143
+ num_keys_to_keep = len(multidict_obj)
144
+ else:
145
+ num_keys_to_keep = min(num_keys_to_keep, len(multidict_obj))
146
+ sorted_items = sorted(multidict_obj.items(), key=lambda x: len(x[1]), reverse=reverse)
147
+ filtered_dict = OrderedDict()
148
+ for i in range(num_keys_to_keep):
149
+ filtered_dict[sorted_items[i][0]] = sorted_items[i][1]
150
+ return filtered_dict
151
+
152
+
153
+ def merge_multidict(*mdicts):
154
+ merged_multidict = {}
155
+ for item in mdicts:
156
+ for key, value in item.items():
157
+ merged_multidict.setdefault(key, []).extend(value)
158
+ return merged_multidict
159
+
160
+
161
+ def invert_multidict(multidict_obj):
162
+ inverted_dict = {}
163
+ for key, value in multidict_obj.items():
164
+ for item in value:
165
+ inverted_dict.setdefault(item, []).append(key)
166
+ return inverted_dict
167
+
168
+
khandy/draw_utils.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import PIL
3
+ from PIL import Image
4
+ from PIL import ImageDraw
5
+ from PIL import ImageFont
6
+ from PIL import ImageColor
7
+
8
+
9
+ def _is_legal_color(color):
10
+ if color is None:
11
+ return True
12
+ if isinstance(color, str):
13
+ return True
14
+ return isinstance(color, (tuple, list)) and len(color) == 3
15
+
16
+
17
+ def _normalize_color(color, pil_mode, swap_rgb=False):
18
+ if color is None:
19
+ return color
20
+ if isinstance(color, str):
21
+ color = ImageColor.getrgb(color)
22
+ gray = color[0]
23
+ if swap_rgb:
24
+ color = (color[2], color[1], color[0])
25
+ if pil_mode == 'L':
26
+ color = gray
27
+ return color
28
+
29
+
30
+ def draw_text(image, text, position, color=(255,0,0), font=None, font_size=15):
31
+ """Draws text on given image.
32
+
33
+ Args:
34
+ image (ndarray).
35
+ text (str): text to be drawn.
36
+ position (Tuple[int, int]): position where to be drawn.
37
+ color (List[Union[str, Tuple[int, int, int]]]): text color.
38
+ font (str): A filename or file-like object containing a TrueType font. If the file is not found in this
39
+ filename, the loader may also search in other directories, such as the `fonts/` directory on Windows
40
+ or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
41
+ font_size (int): The requested font size in points.
42
+
43
+ References:
44
+ torchvision.utils.draw_bounding_boxes
45
+ """
46
+ if isinstance(image, np.ndarray):
47
+ # PIL.Image.fromarray fails with uint16 arrays
48
+ # https://github.com/python-pillow/Pillow/issues/1514
49
+ if (image.dtype == np.uint16) and (image.ndim != 2):
50
+ image = (image / 256).astype(np.uint8)
51
+ pil_image = Image.fromarray(image)
52
+ elif isinstance(image, PIL.Image.Image):
53
+ pil_image = image
54
+ else:
55
+ raise TypeError('Unsupported image type!')
56
+ assert pil_image.mode in ['L', 'RGB', 'RGBA']
57
+
58
+ assert _is_legal_color(color)
59
+ color = _normalize_color(color, pil_image.mode, isinstance(image, np.ndarray))
60
+
61
+ if font is None:
62
+ font_object = ImageFont.load_default()
63
+ else:
64
+ font_object = ImageFont.truetype(font, size=font_size)
65
+
66
+ draw = ImageDraw.Draw(pil_image)
67
+ draw.text((position[0], position[1]), text,
68
+ fill=color, font=font_object)
69
+
70
+ if isinstance(image, np.ndarray):
71
+ return np.asarray(pil_image)
72
+ return pil_image
73
+
74
+
75
+ def draw_bounding_boxes(image, boxes, labels=None, colors=None,
76
+ fill=False, width=1, font=None, font_size=15):
77
+ """Draws bounding boxes on given image.
78
+
79
+ Args:
80
+ image (ndarray).
81
+ boxes (ndarray): ndarray of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format.
82
+ labels (List[str]): List containing the labels of bounding boxes.
83
+ colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes or labels.
84
+ fill (bool): If `True` fills the bounding box with specified color.
85
+ width (int): Width of bounding box.
86
+ font (str): A filename or file-like object containing a TrueType font. If the file is not found in this
87
+ filename, the loader may also search in other directories, such as the `fonts/` directory on Windows
88
+ or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
89
+ font_size (int): The requested font size in points.
90
+
91
+ References:
92
+ torchvision.utils.draw_bounding_boxes
93
+ """
94
+ if isinstance(image, np.ndarray):
95
+ # PIL.Image.fromarray fails with uint16 arrays
96
+ # https://github.com/python-pillow/Pillow/issues/1514
97
+ if (image.dtype == np.uint16) and (image.ndim != 2):
98
+ image = (image / 256).astype(np.uint8)
99
+ pil_image = Image.fromarray(image)
100
+ elif isinstance(image, PIL.Image.Image):
101
+ pil_image = image
102
+ else:
103
+ raise TypeError('Unsupported image type!')
104
+ pil_image = pil_image.convert('RGB')
105
+
106
+ if font is None:
107
+ font_object = ImageFont.load_default()
108
+ else:
109
+ font_object = ImageFont.truetype(font, size=font_size)
110
+
111
+ if fill:
112
+ draw = ImageDraw.Draw(pil_image, "RGBA")
113
+ else:
114
+ draw = ImageDraw.Draw(pil_image)
115
+
116
+ for i, bbox in enumerate(boxes):
117
+ if colors is None:
118
+ color = None
119
+ else:
120
+ color = colors[i]
121
+
122
+ assert _is_legal_color(color)
123
+ color = _normalize_color(color, pil_image.mode, isinstance(image, np.ndarray))
124
+
125
+ if fill:
126
+ if color is None:
127
+ fill_color = (255, 255, 255, 100)
128
+ elif isinstance(color, str):
129
+ # This will automatically raise Error if rgb cannot be parsed.
130
+ fill_color = ImageColor.getrgb(color) + (100,)
131
+ elif isinstance(color, tuple):
132
+ fill_color = color + (100,)
133
+ # the first argument of ImageDraw.rectangle:
134
+ # in old version only supports [(x0, y0), (x1, y1)]
135
+ # in new version supports either [(x0, y0), (x1, y1)] or [x0, y0, x1, y1]
136
+ draw.rectangle([(bbox[0], bbox[1]), (bbox[2], bbox[3])], width=width, outline=color, fill=fill_color)
137
+ else:
138
+ draw.rectangle([(bbox[0], bbox[1]), (bbox[2], bbox[3])], width=width, outline=color)
139
+
140
+ if labels is not None:
141
+ margin = width + 1
142
+ draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=font_object)
143
+
144
+ if isinstance(image, np.ndarray):
145
+ return np.asarray(pil_image)
146
+ return pil_image
147
+
148
+
khandy/feature_utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import khandy
4
+ import numpy as np
5
+
6
+
7
+ def convert_feature_dict_to_array(feature_dict):
8
+ one_feature = khandy.get_dict_first_item(feature_dict)[1]
9
+ num_features = sum([len(item) for item in feature_dict.values()])
10
+
11
+ key_list = []
12
+ start_index = 0
13
+ feature_array = np.empty((num_features, one_feature.shape[-1]), one_feature.dtype)
14
+ for key, value in feature_dict.items():
15
+ feature_array[start_index: start_index + len(value)]= value
16
+ key_list += [key] * len(value)
17
+ start_index += len(value)
18
+ return key_list, feature_array
19
+
20
+
21
+ def convert_feature_array_to_dict(key_list, feature_array):
22
+ assert len(key_list) == len(feature_array)
23
+ feature_dict = OrderedDict()
24
+ for key, feat in zip(key_list, feature_array):
25
+ feature_dict.setdefault(key, []).append(feat)
26
+ for label in feature_dict.keys():
27
+ feature_dict[label] = np.vstack(feature_dict[label])
28
+ return feature_dict
29
+
30
+
31
+ def pairwise_distances(x, y, squared=True):
32
+ """Compute pairwise (squared) Euclidean distances.
33
+
34
+ References:
35
+ [2016 CVPR] Deep Metric Learning via Lifted Structured Feature Embedding
36
+ `euclidean_distances` from sklearn
37
+ """
38
+ assert isinstance(x, np.ndarray) and x.ndim == 2
39
+ assert isinstance(y, np.ndarray) and y.ndim == 2
40
+ assert x.shape[1] == y.shape[1]
41
+
42
+ x_square = np.expand_dims(np.einsum('ij,ij->i', x, x), axis=1)
43
+ if x is y:
44
+ y_square = x_square.T
45
+ else:
46
+ y_square = np.expand_dims(np.einsum('ij,ij->i', y, y), axis=0)
47
+ distances = np.dot(x, y.T)
48
+ # use inplace operation to accelerate
49
+ distances *= -2
50
+ distances += x_square
51
+ distances += y_square
52
+ # result maybe less than 0 due to floating point rounding errors.
53
+ np.maximum(distances, 0, distances)
54
+ if x is y:
55
+ # Ensure that distances between vectors and themselves are set to 0.0.
56
+ # This may not be the case due to floating point rounding errors.
57
+ distances.flat[::distances.shape[0] + 1] = 0.0
58
+ if not squared:
59
+ np.sqrt(distances, distances)
60
+ return distances
61
+
62
+
khandy/file_io_utils.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import numbers
4
+ import pickle
5
+ import warnings
6
+ from collections import OrderedDict
7
+
8
+
9
+ def load_list(filename, encoding='utf-8', start=0, stop=None):
10
+ assert isinstance(start, numbers.Integral) and start >= 0
11
+ assert (stop is None) or (isinstance(stop, numbers.Integral) and stop > start)
12
+
13
+ lines = []
14
+ with open(filename, 'r', encoding=encoding) as f:
15
+ for _ in range(start):
16
+ f.readline()
17
+ for k, line in enumerate(f):
18
+ if (stop is not None) and (k + start > stop):
19
+ break
20
+ lines.append(line.rstrip('\n'))
21
+ return lines
22
+
23
+
24
+ def save_list(filename, list_obj, encoding='utf-8', append_break=True):
25
+ with open(filename, 'w', encoding=encoding) as f:
26
+ if append_break:
27
+ for item in list_obj:
28
+ f.write(str(item) + '\n')
29
+ else:
30
+ for item in list_obj:
31
+ f.write(str(item))
32
+
33
+
34
+ def load_json(filename, encoding='utf-8'):
35
+ with open(filename, 'r', encoding=encoding) as f:
36
+ data = json.load(f, object_pairs_hook=OrderedDict)
37
+ return data
38
+
39
+
40
+ def save_json(filename, data, encoding='utf-8', indent=4, cls=None, sort_keys=False):
41
+ if not filename.endswith('.json'):
42
+ filename = filename + '.json'
43
+ with open(filename, 'w', encoding=encoding) as f:
44
+ json.dump(data, f, indent=indent, separators=(',',': '),
45
+ ensure_ascii=False, cls=cls, sort_keys=sort_keys)
46
+
47
+
48
+ def load_bytes(filename, use_base64: bool = False) -> bytes:
49
+ """Open the file in bytes mode, read it, and close the file.
50
+
51
+ References:
52
+ pathlib.Path.read_bytes
53
+ """
54
+ with open(filename, 'rb') as f:
55
+ data = f.read()
56
+ if use_base64:
57
+ data = base64.b64encode(data)
58
+ return data
59
+
60
+
61
+ def save_bytes(filename, data: bytes, use_base64: bool = False) -> int:
62
+ """Open the file in bytes mode, write to it, and close the file.
63
+
64
+ References:
65
+ pathlib.Path.write_bytes
66
+ """
67
+ if use_base64:
68
+ data = base64.b64decode(data)
69
+ with open(filename, 'wb') as f:
70
+ ret = f.write(data)
71
+ return ret
72
+
73
+
74
+ def load_as_base64(filename) -> bytes:
75
+ warnings.warn('khandy.load_as_base64 will be deprecated, use khandy.load_bytes instead!')
76
+ return load_bytes(filename, True)
77
+
78
+
79
+ def load_object(filename):
80
+ with open(filename, 'rb') as f:
81
+ return pickle.load(f)
82
+
83
+
84
+ def save_object(filename, obj):
85
+ with open(filename, 'wb') as f:
86
+ pickle.dump(obj, f)
87
+
khandy/fs_utils.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import shutil
4
+ import warnings
5
+
6
+
7
+ def get_path_stem(path):
8
+ """
9
+ References:
10
+ `std::filesystem::path::stem` since C++17
11
+ """
12
+ return os.path.splitext(os.path.basename(path))[0]
13
+
14
+
15
+ def replace_path_stem(path, new_stem):
16
+ dirname, basename = os.path.split(path)
17
+ stem, extension = os.path.splitext(basename)
18
+ if isinstance(new_stem, str):
19
+ return os.path.join(dirname, new_stem + extension)
20
+ elif hasattr(new_stem, '__call__'):
21
+ return os.path.join(dirname, new_stem(stem) + extension)
22
+ else:
23
+ raise TypeError('Unsupported Type!')
24
+
25
+
26
+ def get_path_extension(path):
27
+ """
28
+ References:
29
+ `std::filesystem::path::extension` since C++17
30
+
31
+ Notes:
32
+ Not fully consistent with `std::filesystem::path::extension`
33
+ """
34
+ return os.path.splitext(os.path.basename(path))[1]
35
+
36
+
37
+ def replace_path_extension(path, new_extension=None):
38
+ """Replaces the extension with new_extension or removes it when the default value is used.
39
+ Firstly, if this path has an extension, it is removed. Then, a dot character is appended
40
+ to the pathname, if new_extension is not empty or does not begin with a dot character.
41
+
42
+ References:
43
+ `std::filesystem::path::replace_extension` since C++17
44
+ """
45
+ filename_wo_ext = os.path.splitext(path)[0]
46
+ if new_extension == '' or new_extension is None:
47
+ return filename_wo_ext
48
+ elif new_extension.startswith('.'):
49
+ return ''.join([filename_wo_ext, new_extension])
50
+ else:
51
+ return '.'.join([filename_wo_ext, new_extension])
52
+
53
+
54
+ def normalize_extension(extension):
55
+ if extension.startswith('.'):
56
+ new_extension = extension.lower()
57
+ else:
58
+ new_extension = '.' + extension.lower()
59
+ return new_extension
60
+
61
+
62
+ def is_path_in_extensions(path, extensions):
63
+ if isinstance(extensions, str):
64
+ extensions = [extensions]
65
+ extensions = [normalize_extension(item) for item in extensions]
66
+ extension = get_path_extension(path)
67
+ return extension.lower() in extensions
68
+
69
+
70
+ def normalize_path(path, norm_case=True):
71
+ """
72
+ References:
73
+ https://en.cppreference.com/w/cpp/filesystem/canonical
74
+ """
75
+ # On Unix and Windows, return the argument with an initial
76
+ # component of ~ or ~user replaced by that user's home directory.
77
+ path = os.path.expanduser(path)
78
+ # Return a normalized absolutized version of the pathname path.
79
+ # On most platforms, this is equivalent to calling the function
80
+ # normpath() as follows: normpath(join(os.getcwd(), path)).
81
+ path = os.path.abspath(path)
82
+ if norm_case:
83
+ # Normalize the case of a pathname. On Windows,
84
+ # convert all characters in the pathname to lowercase,
85
+ # and also convert forward slashes to backward slashes.
86
+ # On other operating systems, return the path unchanged.
87
+ path = os.path.normcase(path)
88
+ return path
89
+
90
+
91
+ def makedirs(name, mode=0o755):
92
+ """
93
+ References:
94
+ mmcv.mkdir_or_exist
95
+ """
96
+ warnings.warn('`makedirs` will be deprecated!')
97
+ if name == '':
98
+ return
99
+ name = os.path.expanduser(name)
100
+ os.makedirs(name, mode=mode, exist_ok=True)
101
+
102
+
103
+ def listdirs(paths, path_sep=None, full_path=True):
104
+ """Enhancement on `os.listdir`
105
+ """
106
+ warnings.warn('`listdirs` will be deprecated!')
107
+ assert isinstance(paths, (str, tuple, list))
108
+ if isinstance(paths, str):
109
+ path_sep = path_sep or os.path.pathsep
110
+ paths = paths.split(path_sep)
111
+
112
+ all_filenames = []
113
+ for path in paths:
114
+ path_ex = os.path.expanduser(path)
115
+ filenames = os.listdir(path_ex)
116
+ if full_path:
117
+ filenames = [os.path.join(path_ex, filename) for filename in filenames]
118
+ all_filenames.extend(filenames)
119
+ return all_filenames
120
+
121
+
122
+ def get_all_filenames(path, extensions=None, is_valid_file=None):
123
+ warnings.warn('`get_all_filenames` will be deprecated, use `list_files_in_dir` with `recursive=True` instead!')
124
+ if (extensions is not None) and (is_valid_file is not None):
125
+ raise ValueError("Both extensions and is_valid_file cannot "
126
+ "be not None at the same time")
127
+ if is_valid_file is None:
128
+ if extensions is not None:
129
+ def is_valid_file(filename):
130
+ return is_path_in_extensions(filename, extensions)
131
+ else:
132
+ def is_valid_file(filename):
133
+ return True
134
+
135
+ all_filenames = []
136
+ path_ex = os.path.expanduser(path)
137
+ for root, _, filenames in sorted(os.walk(path_ex, followlinks=True)):
138
+ for filename in sorted(filenames):
139
+ fullname = os.path.join(root, filename)
140
+ if is_valid_file(fullname):
141
+ all_filenames.append(fullname)
142
+ return all_filenames
143
+
144
+
145
+ def get_top_level_dirs(path, full_path=True):
146
+ warnings.warn('`get_top_level_dirs` will be deprecated, use `list_dirs_in_dir` instead!')
147
+ if path is None:
148
+ path = os.getcwd()
149
+ path_ex = os.path.expanduser(path)
150
+ filenames = os.listdir(path_ex)
151
+ if full_path:
152
+ return [os.path.join(path_ex, item) for item in filenames
153
+ if os.path.isdir(os.path.join(path_ex, item))]
154
+ else:
155
+ return [item for item in filenames
156
+ if os.path.isdir(os.path.join(path_ex, item))]
157
+
158
+
159
+ def get_top_level_files(path, full_path=True):
160
+ warnings.warn('`get_top_level_files` will be deprecated, use `list_files_in_dir` instead!')
161
+ if path is None:
162
+ path = os.getcwd()
163
+ path_ex = os.path.expanduser(path)
164
+ filenames = os.listdir(path_ex)
165
+ if full_path:
166
+ return [os.path.join(path_ex, item) for item in filenames
167
+ if os.path.isfile(os.path.join(path_ex, item))]
168
+ else:
169
+ return [item for item in filenames
170
+ if os.path.isfile(os.path.join(path_ex, item))]
171
+
172
+
173
+ def list_items_in_dir(path=None, recursive=False, full_path=True):
174
+ """List all entries in directory
175
+ """
176
+ if path is None:
177
+ path = os.getcwd()
178
+ path_ex = os.path.expanduser(path)
179
+
180
+ if not recursive:
181
+ names = os.listdir(path_ex)
182
+ if full_path:
183
+ return [os.path.join(path_ex, name) for name in sorted(names)]
184
+ else:
185
+ return sorted(names)
186
+ else:
187
+ all_names = []
188
+ for root, dirnames, filenames in sorted(os.walk(path_ex, followlinks=True)):
189
+ all_names += [os.path.join(root, name) for name in sorted(dirnames)]
190
+ all_names += [os.path.join(root, name) for name in sorted(filenames)]
191
+ return all_names
192
+
193
+
194
+ def list_dirs_in_dir(path=None, recursive=False, full_path=True):
195
+ """List all dirs in directory
196
+ """
197
+ if path is None:
198
+ path = os.getcwd()
199
+ path_ex = os.path.expanduser(path)
200
+
201
+ if not recursive:
202
+ names = os.listdir(path_ex)
203
+ if full_path:
204
+ return [os.path.join(path_ex, name) for name in sorted(names)
205
+ if os.path.isdir(os.path.join(path_ex, name))]
206
+ else:
207
+ return [name for name in sorted(names)
208
+ if os.path.isdir(os.path.join(path_ex, name))]
209
+ else:
210
+ all_names = []
211
+ for root, dirnames, _ in sorted(os.walk(path_ex, followlinks=True)):
212
+ all_names += [os.path.join(root, name) for name in sorted(dirnames)]
213
+ return all_names
214
+
215
+
216
+ def list_files_in_dir(path=None, recursive=False, full_path=True):
217
+ """List all files in directory
218
+ """
219
+ if path is None:
220
+ path = os.getcwd()
221
+ path_ex = os.path.expanduser(path)
222
+
223
+ if not recursive:
224
+ names = os.listdir(path_ex)
225
+ if full_path:
226
+ return [os.path.join(path_ex, name) for name in sorted(names)
227
+ if os.path.isfile(os.path.join(path_ex, name))]
228
+ else:
229
+ return [name for name in sorted(names)
230
+ if os.path.isfile(os.path.join(path_ex, name))]
231
+ else:
232
+ all_names = []
233
+ for root, _, filenames in sorted(os.walk(path_ex, followlinks=True)):
234
+ all_names += [os.path.join(root, name) for name in sorted(filenames)]
235
+ return all_names
236
+
237
+
238
+ def get_folder_size(dirname):
239
+ if not os.path.exists(dirname):
240
+ raise ValueError("Incorrect path: {}".format(dirname))
241
+ total_size = 0
242
+ for root, _, filenames in os.walk(dirname):
243
+ for name in filenames:
244
+ total_size += os.path.getsize(os.path.join(root, name))
245
+ return total_size
246
+
247
+
248
+ def escape_filename(filename, new_char='_'):
249
+ assert isinstance(new_char, str)
250
+ control_chars = ''.join((map(chr, range(0x00, 0x20))))
251
+ pattern = r'[\\/*?:"<>|{}]'.format(control_chars)
252
+ return re.sub(pattern, new_char, filename)
253
+
254
+
255
+ def replace_invalid_filename_char(filename, new_char='_'):
256
+ warnings.warn('`replace_invalid_filename_char` will be deprecated, use `escape_filename` instead!')
257
+ return escape_filename(filename, new_char)
258
+
259
+
260
+ def copy_file(src, dst_dir, action_if_exist='rename'):
261
+ """
262
+ Args:
263
+ src: source file path
264
+ dst_dir: dest dir
265
+ action_if_exist:
266
+ None: same as shutil.copy
267
+ ignore: when dest file exists, don't copy and return None
268
+ rename: when dest file exists, copy after rename
269
+
270
+ Returns:
271
+ dest filename
272
+ """
273
+ dst = os.path.join(dst_dir, os.path.basename(src))
274
+
275
+ if action_if_exist is None:
276
+ os.makedirs(dst_dir, exist_ok=True)
277
+ shutil.copy(src, dst)
278
+ elif action_if_exist.lower() == 'ignore':
279
+ if os.path.exists(dst):
280
+ warnings.warn(f'{dst} already exists, do not copy!')
281
+ return dst
282
+ os.makedirs(dst_dir, exist_ok=True)
283
+ shutil.copy(src, dst)
284
+ elif action_if_exist.lower() == 'rename':
285
+ suffix = 2
286
+ stem, extension = os.path.splitext(os.path.basename(src))
287
+ while os.path.exists(dst):
288
+ dst = os.path.join(dst_dir, f'{stem} ({suffix}){extension}')
289
+ suffix += 1
290
+ os.makedirs(dst_dir, exist_ok=True)
291
+ shutil.copy(src, dst)
292
+ else:
293
+ raise ValueError('Invalid action_if_exist, got {}.'.format(action_if_exist))
294
+
295
+ return dst
296
+
297
+
298
+ def move_file(src, dst_dir, action_if_exist='rename'):
299
+ """
300
+ Args:
301
+ src: source file path
302
+ dst_dir: dest dir
303
+ action_if_exist:
304
+ None: same as shutil.move
305
+ ignore: when dest file exists, don't move and return None
306
+ rename: when dest file exists, move after rename
307
+
308
+ Returns:
309
+ dest filename
310
+ """
311
+ dst = os.path.join(dst_dir, os.path.basename(src))
312
+
313
+ if action_if_exist is None:
314
+ os.makedirs(dst_dir, exist_ok=True)
315
+ shutil.move(src, dst)
316
+ elif action_if_exist.lower() == 'ignore':
317
+ if os.path.exists(dst):
318
+ warnings.warn(f'{dst} already exists, do not move!')
319
+ return dst
320
+ os.makedirs(dst_dir, exist_ok=True)
321
+ shutil.move(src, dst)
322
+ elif action_if_exist.lower() == 'rename':
323
+ suffix = 2
324
+ stem, extension = os.path.splitext(os.path.basename(src))
325
+ while os.path.exists(dst):
326
+ dst = os.path.join(dst_dir, f'{stem} ({suffix}){extension}')
327
+ suffix += 1
328
+ os.makedirs(dst_dir, exist_ok=True)
329
+ shutil.move(src, dst)
330
+ else:
331
+ raise ValueError('Invalid action_if_exist, got {}.'.format(action_if_exist))
332
+
333
+ return dst
334
+
335
+
336
+ def rename_file(src, dst, action_if_exist='rename'):
337
+ """
338
+ Args:
339
+ src: source file path
340
+ dst: dest file path
341
+ action_if_exist:
342
+ None: same as os.rename
343
+ ignore: when dest file exists, don't rename and return None
344
+ rename: when dest file exists, rename it
345
+
346
+ Returns:
347
+ dest filename
348
+ """
349
+ if dst == src:
350
+ return dst
351
+ dst_dir = os.path.dirname(os.path.abspath(dst))
352
+
353
+ if action_if_exist is None:
354
+ os.makedirs(dst_dir, exist_ok=True)
355
+ os.rename(src, dst)
356
+ elif action_if_exist.lower() == 'ignore':
357
+ if os.path.exists(dst):
358
+ warnings.warn(f'{dst} already exists, do not rename!')
359
+ return dst
360
+ os.makedirs(dst_dir, exist_ok=True)
361
+ os.rename(src, dst)
362
+ elif action_if_exist.lower() == 'rename':
363
+ suffix = 2
364
+ stem, extension = os.path.splitext(os.path.basename(dst))
365
+ while os.path.exists(dst):
366
+ dst = os.path.join(dst_dir, f'{stem} ({suffix}){extension}')
367
+ suffix += 1
368
+ os.makedirs(dst_dir, exist_ok=True)
369
+ os.rename(src, dst)
370
+ else:
371
+ raise ValueError('Invalid action_if_exist, got {}.'.format(action_if_exist))
372
+
373
+ return dst
374
+
375
+
khandy/hash_utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+
3
+
4
+ def calc_hash(content, hash_object=None):
5
+ hash_object = hash_object or hashlib.md5()
6
+ if isinstance(hash_object, str):
7
+ hash_object = hashlib.new(hash_object)
8
+ hash_object.update(content)
9
+ return hash_object.hexdigest()
10
+
11
+
12
+ def calc_file_hash(filename, hash_object=None, chunk_size=1024 * 1024):
13
+ hash_object = hash_object or hashlib.md5()
14
+ if isinstance(hash_object, str):
15
+ hash_object = hashlib.new(hash_object)
16
+
17
+ with open(filename, "rb") as f:
18
+ while True:
19
+ chunk = f.read(chunk_size)
20
+ if not chunk:
21
+ break
22
+ hash_object.update(chunk)
23
+ return hash_object.hexdigest()
24
+
25
+
khandy/image/align_and_crop.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ def get_similarity_transform(src_pts, dst_pts):
6
+ """Get similarity transform matrix from src_pts to dst_pts
7
+
8
+ Args:
9
+ src_pts: Kx2 np.array
10
+ source points matrix, each row is a pair of coordinates (x, y)
11
+ dst_pts: Kx2 np.array
12
+ destination points matrix, each row is a pair of coordinates (x, y)
13
+
14
+ Returns:
15
+ xform_matrix: 3x3 np.array
16
+ transform matrix from src_pts to dst_pts
17
+ """
18
+ src_pts = np.asarray(src_pts)
19
+ dst_pts = np.asarray(dst_pts)
20
+ assert src_pts.shape == dst_pts.shape
21
+ assert (src_pts.ndim == 2) and (src_pts.shape[-1] == 2)
22
+
23
+ npts = src_pts.shape[0]
24
+ src_x = src_pts[:, 0].reshape((-1, 1))
25
+ src_y = src_pts[:, 1].reshape((-1, 1))
26
+ tmp1 = np.hstack((src_x, -src_y, np.ones((npts, 1)), np.zeros((npts, 1))))
27
+ tmp2 = np.hstack((src_y, src_x, np.zeros((npts, 1)), np.ones((npts, 1))))
28
+ A = np.vstack((tmp1, tmp2))
29
+
30
+ dst_x = dst_pts[:, 0].reshape((-1, 1))
31
+ dst_y = dst_pts[:, 1].reshape((-1, 1))
32
+ b = np.vstack((dst_x, dst_y))
33
+
34
+ x = np.linalg.lstsq(A, b, rcond=-1)[0]
35
+ x = np.squeeze(x)
36
+ sc, ss, tx, ty = x[0], x[1], x[2], x[3]
37
+ xform_matrix = np.array([
38
+ [sc, -ss, tx],
39
+ [ss, sc, ty],
40
+ [ 0, 0, 1]
41
+ ])
42
+ return xform_matrix
43
+
44
+
45
+ def align_and_crop(image, landmarks, std_landmarks, align_size,
46
+ border_value=0, return_transform_matrix=False):
47
+ landmarks = np.asarray(landmarks)
48
+ std_landmarks = np.asarray(std_landmarks)
49
+ xform_matrix = get_similarity_transform(landmarks, std_landmarks)
50
+
51
+ landmarks_ex = np.pad(landmarks, ((0,0),(0,1)), mode='constant', constant_values=1)
52
+ dst_landmarks = np.dot(landmarks_ex, xform_matrix[:2,:].T)
53
+ dst_image = cv2.warpAffine(image, xform_matrix[:2,:], dsize=align_size,
54
+ borderValue=border_value)
55
+ if return_transform_matrix:
56
+ return dst_image, dst_landmarks, xform_matrix
57
+ else:
58
+ return dst_image, dst_landmarks
59
+
60
+
khandy/image/crop_or_pad.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+ import warnings
3
+
4
+ import khandy
5
+ import numpy as np
6
+
7
+
8
+ def crop(image, x_min, y_min, x_max, y_max, border_value=0):
9
+ """Crop the given image at specified rectangular area.
10
+
11
+ See Also:
12
+ translate_image
13
+
14
+ References:
15
+ PIL.Image.crop
16
+ tf.image.resize_image_with_crop_or_pad
17
+ """
18
+ assert khandy.is_numpy_image(image)
19
+ assert isinstance(x_min, numbers.Integral) and isinstance(y_min, numbers.Integral)
20
+ assert isinstance(x_max, numbers.Integral) and isinstance(y_max, numbers.Integral)
21
+ assert (x_min <= x_max) and (y_min <= y_max)
22
+
23
+ src_height, src_width = image.shape[:2]
24
+ dst_height, dst_width = y_max - y_min + 1, x_max - x_min + 1
25
+ channels = 1 if image.ndim == 2 else image.shape[2]
26
+
27
+ if isinstance(border_value, (tuple, list)):
28
+ assert len(border_value) == channels, \
29
+ 'Expected the num of elements in tuple equals the channels ' \
30
+ 'of input image. Found {} vs {}'.format(
31
+ len(border_value), channels)
32
+ else:
33
+ border_value = (border_value,) * channels
34
+ dst_image = khandy.create_solid_color_image(
35
+ dst_width, dst_height, border_value, dtype=image.dtype)
36
+
37
+ src_x_begin = max(x_min, 0)
38
+ src_x_end = min(x_max + 1, src_width)
39
+ dst_x_begin = src_x_begin - x_min
40
+ dst_x_end = src_x_end - x_min
41
+
42
+ src_y_begin = max(y_min, 0)
43
+ src_y_end = min(y_max + 1, src_height)
44
+ dst_y_begin = src_y_begin - y_min
45
+ dst_y_end = src_y_end - y_min
46
+
47
+ if (src_x_begin >= src_x_end) or (src_y_begin >= src_y_end):
48
+ return dst_image
49
+ dst_image[dst_y_begin: dst_y_end, dst_x_begin: dst_x_end, ...] = \
50
+ image[src_y_begin: src_y_end, src_x_begin: src_x_end, ...]
51
+ return dst_image
52
+
53
+
54
+ def crop_or_pad(image, x_min, y_min, x_max, y_max, border_value=0):
55
+ warnings.warn('crop_or_pad will be deprecated, use crop instead!')
56
+ return crop(image, x_min, y_min, x_max, y_max, border_value)
57
+
58
+
59
+ def crop_coords(boxes, image_width, image_height):
60
+ """
61
+ References:
62
+ `mmcv.impad`
63
+ `pad` in https://github.com/kpzhang93/MTCNN_face_detection_alignment
64
+ `MtcnnDetector.pad` in https://github.com/AITTSMD/MTCNN-Tensorflow
65
+ """
66
+ x_mins = boxes[:, 0]
67
+ y_mins = boxes[:, 1]
68
+ x_maxs = boxes[:, 2]
69
+ y_maxs = boxes[:, 3]
70
+ dst_widths = x_maxs - x_mins + 1
71
+ dst_heights = y_maxs - y_mins + 1
72
+
73
+ src_x_begin = np.maximum(x_mins, 0)
74
+ src_x_end = np.minimum(x_maxs + 1, image_width)
75
+ dst_x_begin = src_x_begin - x_mins
76
+ dst_x_end = src_x_end - x_mins
77
+
78
+ src_y_begin = np.maximum(y_mins, 0)
79
+ src_y_end = np.minimum(y_maxs + 1, image_height)
80
+ dst_y_begin = src_y_begin - y_mins
81
+ dst_y_end = src_y_end - y_mins
82
+
83
+ coords = np.stack([dst_y_begin, dst_y_end, dst_x_begin, dst_x_end,
84
+ src_y_begin, src_y_end, src_x_begin, src_x_end,
85
+ dst_heights, dst_widths], axis=0)
86
+ return coords
87
+
88
+
89
+ def crop_or_pad_coords(boxes, image_width, image_height):
90
+ warnings.warn('crop_or_pad_coords will be deprecated, use crop_coords instead!')
91
+ return crop_coords(boxes, image_width, image_height)
92
+
93
+
94
+ def center_crop(image, dst_width, dst_height, strict=True):
95
+ """
96
+ strict:
97
+ when True, raise error if src size is less than dst size.
98
+ when False, remain unchanged if src size is less than dst size, otherwise center crop.
99
+ """
100
+ assert khandy.is_numpy_image(image)
101
+ assert isinstance(dst_width, numbers.Integral) and isinstance(dst_height, numbers.Integral)
102
+ src_height, src_width = image.shape[:2]
103
+ if strict:
104
+ assert (src_height >= dst_height) and (src_width >= dst_width)
105
+
106
+ crop_top = max((src_height - dst_height) // 2, 0)
107
+ crop_left = max((src_width - dst_width) // 2, 0)
108
+ cropped = image[crop_top: dst_height + crop_top,
109
+ crop_left: dst_width + crop_left, ...]
110
+ return cropped
111
+
112
+
113
+ def center_pad(image, dst_width, dst_height, strict=True):
114
+ """
115
+ strict:
116
+ when True, raise error if src size is greater than dst size.
117
+ when False, remain unchanged if src size is greater than dst size, otherwise center pad.
118
+ """
119
+ assert khandy.is_numpy_image(image)
120
+ assert isinstance(dst_width, numbers.Integral) and isinstance(dst_height, numbers.Integral)
121
+
122
+ src_height, src_width = image.shape[:2]
123
+ if strict:
124
+ assert (src_height <= dst_height) and (src_width <= dst_width)
125
+
126
+ padding_x = max(dst_width - src_width, 0)
127
+ padding_y = max(dst_height - src_height, 0)
128
+ padding_top = padding_y // 2
129
+ padding_left = padding_x // 2
130
+ if image.ndim == 2:
131
+ padding = ((padding_top, padding_y - padding_top),
132
+ (padding_left, padding_x - padding_left))
133
+ else:
134
+ padding = ((padding_top, padding_y - padding_top),
135
+ (padding_left, padding_x - padding_left), (0, 0))
136
+ return np.pad(image, padding, 'constant')
137
+
138
+
khandy/image/flip.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import khandy
2
+ import numpy as np
3
+
4
+
5
+ def flip_image(image, direction='h', copy=True):
6
+ """
7
+ References:
8
+ np.flipud, np.fliplr, np.flip
9
+ cv2.flip
10
+ tf.image.flip_up_down
11
+ tf.image.flip_left_right
12
+ """
13
+ assert khandy.is_numpy_image(image)
14
+ assert direction in ['x', 'h', 'horizontal',
15
+ 'y', 'v', 'vertical',
16
+ 'o', 'b', 'both']
17
+ if copy:
18
+ image = image.copy()
19
+ if direction in ['o', 'b', 'both', 'x', 'h', 'horizontal']:
20
+ image = np.fliplr(image)
21
+ if direction in ['o', 'b', 'both', 'y', 'v', 'vertical']:
22
+ image = np.flipud(image)
23
+ return image
24
+
25
+
26
+ def transpose_image(image, copy=True):
27
+ """Transpose image.
28
+
29
+ References:
30
+ np.transpose
31
+ cv2.transpose
32
+ tf.image.transpose
33
+ """
34
+ assert khandy.is_numpy_image(image)
35
+ if copy:
36
+ image = image.copy()
37
+ if image.ndim == 2:
38
+ transpose_axes = (1, 0)
39
+ else:
40
+ transpose_axes = (1, 0, 2)
41
+ image = np.transpose(image, transpose_axes)
42
+ return image
43
+
44
+
45
+ def rot90_image(image, n=1, copy=True):
46
+ """Rotate image counter-clockwise by 90 degrees.
47
+
48
+ References:
49
+ np.rot90
50
+ cv2.rotate
51
+ tf.image.rot90
52
+ """
53
+ assert khandy.is_numpy_image(image)
54
+ if copy:
55
+ image = image.copy()
56
+ if image.ndim == 2:
57
+ transpose_axes = (1, 0)
58
+ else:
59
+ transpose_axes = (1, 0, 2)
60
+
61
+ n = n % 4
62
+ if n == 0:
63
+ return image[:]
64
+ elif n == 1:
65
+ image = np.transpose(image, transpose_axes)
66
+ image = np.flipud(image)
67
+ elif n == 2:
68
+ image = np.fliplr(np.flipud(image))
69
+ else:
70
+ image = np.transpose(image, transpose_axes)
71
+ image = np.fliplr(image)
72
+ return image
khandy/image/image_hash.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import khandy
3
+ import numpy as np
4
+
5
+
6
+ def _convert_bool_matrix_to_int(bool_mat):
7
+ hash_val = int(0)
8
+ for item in bool_mat.flatten():
9
+ hash_val <<= 1
10
+ hash_val |= int(item)
11
+ return hash_val
12
+
13
+
14
+ def calc_image_ahash(image):
15
+ """Average Hashing
16
+
17
+ References:
18
+ http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
19
+ """
20
+ assert khandy.is_numpy_image(image)
21
+ if image.ndim == 3:
22
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
23
+ resized = cv2.resize(image, (8, 8))
24
+
25
+ mean_val = np.mean(resized)
26
+ hash_mat = resized >= mean_val
27
+ hash_val = _convert_bool_matrix_to_int(hash_mat)
28
+ return f'{hash_val:016x}'
29
+
30
+
31
+ def calc_image_dhash(image):
32
+ """Difference Hashing
33
+
34
+ References:
35
+ http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
36
+ """
37
+ assert khandy.is_numpy_image(image)
38
+ if image.ndim == 3:
39
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
40
+ resized = cv2.resize(image, (9, 8))
41
+
42
+ hash_mat = resized[:,:-1] >= resized[:,1:]
43
+ hash_val = _convert_bool_matrix_to_int(hash_mat)
44
+ return f'{hash_val:016x}'
45
+
46
+
47
+ def calc_image_phash(image):
48
+ """Perceptual Hashing
49
+
50
+ References:
51
+ http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
52
+ """
53
+ assert khandy.is_numpy_image(image)
54
+ if image.ndim == 3:
55
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
56
+ resized = cv2.resize(image, (32, 32))
57
+
58
+ dct_coeff = cv2.dct(resized.astype(np.float32))
59
+ reduced_dct_coeff = dct_coeff[:8, :8]
60
+
61
+ # # mean of coefficients excluding the DC term (0th term)
62
+ # mean_val = np.mean(reduced_dct_coeff.flatten()[1:])
63
+ # median of coefficients
64
+ median_val = np.median(reduced_dct_coeff)
65
+
66
+ hash_mat = reduced_dct_coeff >= median_val
67
+ hash_val = _convert_bool_matrix_to_int(hash_mat)
68
+ return f'{hash_val:016x}'
69
+
khandy/image/misc.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imghdr
3
+ import numbers
4
+ import warnings
5
+ from io import BytesIO
6
+
7
+ import cv2
8
+ import khandy
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+
13
+ def imread(file_or_buffer, flags=-1):
14
+ """Improvement on cv2.imread, make it support filename including chinese character.
15
+ """
16
+ try:
17
+ if isinstance(file_or_buffer, bytes):
18
+ return cv2.imdecode(np.frombuffer(file_or_buffer, dtype=np.uint8), flags)
19
+ else:
20
+ # support type: file or str or Path
21
+ return cv2.imdecode(np.fromfile(file_or_buffer, dtype=np.uint8), flags)
22
+ except Exception as e:
23
+ print(e)
24
+ return None
25
+
26
+
27
+ def imread_cv(file_or_buffer, flags=-1):
28
+ warnings.warn('khandy.imread_cv will be deprecated, use khandy.imread instead!')
29
+ return imread(file_or_buffer, flags)
30
+
31
+
32
+ def imwrite(filename, image, params=None):
33
+ """Improvement on cv2.imwrite, make it support filename including chinese character.
34
+ """
35
+ cv2.imencode(os.path.splitext(filename)[-1], image, params)[1].tofile(filename)
36
+
37
+
38
+ def imwrite_cv(filename, image, params=None):
39
+ warnings.warn('khandy.imwrite_cv will be deprecated, use khandy.imwrite instead!')
40
+ return imwrite(filename, image, params)
41
+
42
+
43
+ def imread_pil(file_or_buffer, to_mode=None):
44
+ """Improvement on Image.open to avoid ResourceWarning.
45
+ """
46
+ try:
47
+ if isinstance(file_or_buffer, bytes):
48
+ buffer = BytesIO()
49
+ buffer.write(file_or_buffer)
50
+ buffer.seek(0)
51
+ file_or_buffer = buffer
52
+
53
+ if hasattr(file_or_buffer, 'read'):
54
+ image = Image.open(file_or_buffer)
55
+ if to_mode is not None:
56
+ image = image.convert(to_mode)
57
+ else:
58
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
59
+ with open(file_or_buffer, 'rb') as f:
60
+ image = Image.open(f)
61
+ # If convert outside with statement, will raise "seek of closed file" as
62
+ # https://github.com/microsoft/Swin-Transformer/issues/66
63
+ if to_mode is not None:
64
+ image = image.convert(to_mode)
65
+ return image
66
+ except Exception as e:
67
+ print(e)
68
+ return None
69
+
70
+
71
+ def imwrite_bytes(filename, image_bytes: bytes, update_extension: bool = True):
72
+ """Write image bytes to file.
73
+
74
+ Args:
75
+ filename: str
76
+ filename which image_bytes is written into.
77
+ image_bytes: bytes
78
+ image content to be written.
79
+ update_extension: bool
80
+ whether update extension according to image_bytes or not.
81
+ the cost of update extension is smaller than update image format.
82
+ """
83
+ extension = imghdr.what('', image_bytes)
84
+ file_extension = khandy.get_path_extension(filename)
85
+ # imghdr.what fails to determine image format sometimes!
86
+ # so when its return value is None, never update extension.
87
+ if extension is None:
88
+ image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
89
+ image_bytes = cv2.imencode(file_extension, image)[1]
90
+ elif (extension.lower() != file_extension.lower()[1:]):
91
+ if update_extension:
92
+ filename = khandy.replace_path_extension(filename, extension)
93
+ else:
94
+ image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
95
+ image_bytes = cv2.imencode(file_extension, image)[1]
96
+
97
+ with open(filename, "wb") as f:
98
+ f.write(image_bytes)
99
+ return filename
100
+
101
+
102
+ def rescale_image(image: np.ndarray, rescale_factor='auto', dst_dtype=np.float32):
103
+ """Rescale image by rescale_factor.
104
+
105
+ Args:
106
+ img (ndarray): Image to be rescaled.
107
+ rescale_factor (str, int or float, *optional*, defaults to `'auto'`):
108
+ rescale the image by the specified scale factor. When is `'auto'`,
109
+ rescale the image to [0, 1).
110
+ dtype (np.dtype, *optional*, defaults to `np.float32`):
111
+ The dtype of the output image. Defaults to `np.float32`.
112
+
113
+ Returns:
114
+ ndarray: The rescaled image.
115
+ """
116
+ if rescale_factor == 'auto':
117
+ if np.issubdtype(image.dtype, np.unsignedinteger):
118
+ rescale_factor = 1. / np.iinfo(image.dtype).max
119
+ else:
120
+ raise TypeError(f'Only support uint dtype ndarray when `rescale_factor` is `auto`, got {image.dtype}')
121
+ elif issubclass(rescale_factor, (int, float)):
122
+ pass
123
+ else:
124
+ raise TypeError('rescale_factor must be "auto", int or float')
125
+ image = image.astype(dst_dtype, copy=True)
126
+ image *= rescale_factor
127
+ image = image.astype(dst_dtype)
128
+ return image
129
+
130
+
131
+ def normalize_image_value(image: np.ndarray, mean, std, rescale_factor=None):
132
+ """Normalize an image with mean and std, rescale optionally.
133
+
134
+ Args:
135
+ image (ndarray): Image to be normalized.
136
+ mean (int, float, Sequence[int], Sequence[float], ndarray): The mean to be used for normalize.
137
+ std (int, float, Sequence[int], Sequence[float], ndarray): The std to be used for normalize.
138
+ rescale_factor (None, 'auto', int or float, *optional*, defaults to `None`):
139
+ rescale the image by the specified scale factor. When is `'auto'`,
140
+ rescale the image to [0, 1); When is `None`, do not rescale.
141
+
142
+ Returns:
143
+ ndarray: The normalized image which dtype is np.float32.
144
+ """
145
+ dst_dtype = np.float32
146
+ mean = np.array(mean, dtype=dst_dtype).flatten()
147
+ std = np.array(std, dtype=dst_dtype).flatten()
148
+ if rescale_factor == 'auto':
149
+ if np.issubdtype(image.dtype, np.unsignedinteger):
150
+ mean *= np.iinfo(image.dtype).max
151
+ std *= np.iinfo(image.dtype).max
152
+ else:
153
+ raise TypeError(f'Only support uint dtype ndarray when `rescale_factor` is `auto`, got {image.dtype}')
154
+ elif isinstance(rescale_factor, (int, float)):
155
+ mean *= rescale_factor
156
+ std *= rescale_factor
157
+ image = image.astype(dst_dtype, copy=True)
158
+ image -= mean
159
+ image /= std
160
+ return image
161
+
162
+
163
+ def normalize_image_dtype(image, keep_num_channels=False):
164
+ """Normalize image dtype to uint8 (usually for visualization).
165
+
166
+ Args:
167
+ image : ndarray
168
+ Input image.
169
+ keep_num_channels : bool, optional
170
+ If this is set to True, the result is an array which has
171
+ the same shape as input image, otherwise the result is
172
+ an array whose channels number is 3.
173
+
174
+ Returns:
175
+ out: ndarray
176
+ Image whose dtype is np.uint8.
177
+ """
178
+ assert (image.ndim == 3 and image.shape[-1] in [1, 3]) or (image.ndim == 2)
179
+
180
+ image = image.astype(np.float32)
181
+ image = khandy.minmax_normalize(image, axis=None, copy=False)
182
+ image = np.array(image * 255, dtype=np.uint8)
183
+
184
+ if not keep_num_channels:
185
+ if image.ndim == 2:
186
+ image = np.expand_dims(image, -1)
187
+ if image.shape[-1] == 1:
188
+ image = np.tile(image, (1,1,3))
189
+ return image
190
+
191
+
192
+ def normalize_image_channel(image, swap_rb=False):
193
+ """Normalize image channel number and order to RGB or BGR.
194
+
195
+ Args:
196
+ image : ndarray
197
+ Input image.
198
+ swap_rb : bool, optional
199
+ whether swap red and blue channel or not
200
+
201
+ Returns:
202
+ out: ndarray
203
+ Image whose shape is (..., 3).
204
+ """
205
+ if image.ndim == 2:
206
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
207
+ elif image.ndim == 3:
208
+ num_channels = image.shape[-1]
209
+ if num_channels == 1:
210
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
211
+ elif num_channels == 3:
212
+ if swap_rb:
213
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
214
+ elif num_channels == 4:
215
+ if swap_rb:
216
+ image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
217
+ else:
218
+ image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
219
+ else:
220
+ raise ValueError(f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!')
221
+ else:
222
+ raise ValueError(f'Unsupported image ndarray ndim, only support 2 and 3, got {image.ndim}!')
223
+ return image
224
+
225
+
226
+ def normalize_image_shape(image, swap_rb=False):
227
+ warnings.warn('khandy.normalize_image_shape will be deprecated, use khandy.normalize_image_channel instead!')
228
+ return normalize_image_channel(image, swap_rb)
229
+
230
+
231
+ def stack_image_list(image_list, dtype=np.float32):
232
+ """Join a sequence of image along a new axis before first axis.
233
+
234
+ References:
235
+ `im_list_to_blob` in `py-faster-rcnn-master/lib/utils/blob.py`
236
+ """
237
+ assert isinstance(image_list, (tuple, list))
238
+
239
+ max_dimension = np.array([image.ndim for image in image_list]).max()
240
+ assert max_dimension in [2, 3]
241
+ max_shape = np.array([image.shape[:2] for image in image_list]).max(axis=0)
242
+
243
+ num_channels = []
244
+ for image in image_list:
245
+ if image.ndim == 2:
246
+ num_channels.append(1)
247
+ else:
248
+ num_channels.append(image.shape[-1])
249
+ assert len(set(num_channels) - set([1])) in [0, 1]
250
+ max_num_channels = np.max(num_channels)
251
+
252
+ blob = np.empty((len(image_list), max_shape[0], max_shape[1], max_num_channels), dtype=dtype)
253
+ for k, image in enumerate(image_list):
254
+ blob[k, :image.shape[0], :image.shape[1], :] = np.atleast_3d(image).astype(dtype, copy=False)
255
+ if max_dimension == 2:
256
+ blob = np.squeeze(blob, axis=-1)
257
+ return blob
258
+
259
+
260
+ def is_numpy_image(image):
261
+ return isinstance(image, np.ndarray) and image.ndim in {2, 3}
262
+
263
+
264
+ def is_gray_image(image, tol=3):
265
+ assert is_numpy_image(image)
266
+ if image.ndim == 2:
267
+ return True
268
+ elif image.ndim == 3:
269
+ num_channels = image.shape[-1]
270
+ if num_channels == 1:
271
+ return True
272
+ elif num_channels == 3:
273
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
274
+ gray3 = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
275
+ mae = np.mean(cv2.absdiff(image, gray3))
276
+ return mae <= tol
277
+ elif num_channels == 4:
278
+ rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
279
+ gray = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)
280
+ gray3 = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
281
+ mae = np.mean(cv2.absdiff(rgb, gray3))
282
+ return mae <= tol
283
+ else:
284
+ return False
285
+ else:
286
+ return False
287
+
288
+
289
+ def is_solid_color_image(image, tol=4):
290
+ assert is_numpy_image(image)
291
+ mean = np.array(cv2.mean(image)[:-1], dtype=np.float32)
292
+
293
+ if image.ndim == 2:
294
+ mae = np.mean(np.abs(image - mean[0]))
295
+ return mae <= tol
296
+ elif image.ndim == 3:
297
+ num_channels = image.shape[-1]
298
+ if num_channels == 1:
299
+ mae = np.mean(np.abs(image - mean[0]))
300
+ return mae <= tol
301
+ elif num_channels == 3:
302
+ mae = np.mean(np.abs(image - mean))
303
+ return mae <= tol
304
+ elif num_channels == 4:
305
+ mae = np.mean(np.abs(image[:,:,:-1] - mean))
306
+ return mae <= tol
307
+ else:
308
+ return False
309
+ else:
310
+ return False
311
+
312
+
313
+ def create_solid_color_image(image_width, image_height, color, dtype=None):
314
+ if isinstance(color, numbers.Real):
315
+ image = np.full((image_height, image_width), color, dtype=dtype)
316
+ elif isinstance(color, (tuple, list)):
317
+ if len(color) == 1:
318
+ image = np.full((image_height, image_width), color[0], dtype=dtype)
319
+ elif len(color) in (3, 4):
320
+ image = np.full((1, 1, len(color)), color, dtype=dtype)
321
+ image = cv2.copyMakeBorder(image, 0, image_height-1, 0, image_width-1,
322
+ cv2.BORDER_CONSTANT, value=color)
323
+ else:
324
+ color = np.asarray(color, dtype=dtype)
325
+ image = np.empty((image_height, image_width, len(color)), dtype=dtype)
326
+ image[:] = color
327
+ else:
328
+ raise TypeError(f'Invalid type {type(color)} for `color`.')
329
+ return image
khandy/image/resize.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import cv2
4
+ import khandy
5
+ import numpy as np
6
+
7
+
8
+ interp_codes = {
9
+ 'nearest': cv2.INTER_NEAREST,
10
+ 'bilinear': cv2.INTER_LINEAR,
11
+ 'bicubic': cv2.INTER_CUBIC,
12
+ 'area': cv2.INTER_AREA,
13
+ 'lanczos': cv2.INTER_LANCZOS4
14
+ }
15
+
16
+
17
+ def scale_image(image, x_scale, y_scale, interpolation='bilinear'):
18
+ """Scale image.
19
+
20
+ Reference:
21
+ mmcv.imrescale
22
+ """
23
+ assert khandy.is_numpy_image(image)
24
+ src_height, src_width = image.shape[:2]
25
+ dst_width = int(round(x_scale * src_width))
26
+ dst_height = int(round(y_scale * src_height))
27
+
28
+ resized_image = cv2.resize(image, (dst_width, dst_height),
29
+ interpolation=interp_codes[interpolation])
30
+ return resized_image
31
+
32
+
33
+ def resize_image(image, dst_width, dst_height, return_scale=False, interpolation='bilinear'):
34
+ """Resize image to a given size.
35
+
36
+ Args:
37
+ image (ndarray): The input image.
38
+ dst_width (int): Target width.
39
+ dst_height (int): Target height.
40
+ return_scale (bool): Whether to return `x_scale` and `y_scale`.
41
+ interpolation (str): Interpolation method, accepted values are
42
+ "nearest", "bilinear", "bicubic", "area", "lanczos".
43
+
44
+ Returns:
45
+ tuple or ndarray: (`resized_image`, `x_scale`, `y_scale`) or `resized_image`.
46
+
47
+ Reference:
48
+ mmcv.imresize
49
+ """
50
+ assert khandy.is_numpy_image(image)
51
+ resized_image = cv2.resize(image, (dst_width, dst_height),
52
+ interpolation=interp_codes[interpolation])
53
+ if not return_scale:
54
+ return resized_image
55
+ else:
56
+ src_height, src_width = image.shape[:2]
57
+ x_scale = dst_width / src_width
58
+ y_scale = dst_height / src_height
59
+ return resized_image, x_scale, y_scale
60
+
61
+
62
+ def resize_image_short(image, dst_size, return_scale=False, interpolation='bilinear'):
63
+ """Resize an image so that the length of shorter side is dst_size while
64
+ preserving the original aspect ratio.
65
+
66
+ References:
67
+ `resize_min` in `https://github.com/pjreddie/darknet/blob/master/src/image.c`
68
+ """
69
+ assert khandy.is_numpy_image(image)
70
+ src_height, src_width = image.shape[:2]
71
+ scale = max(dst_size / src_width, dst_size / src_height)
72
+ dst_width = int(round(scale * src_width))
73
+ dst_height = int(round(scale * src_height))
74
+
75
+ resized_image = cv2.resize(image, (dst_width, dst_height),
76
+ interpolation=interp_codes[interpolation])
77
+ if not return_scale:
78
+ return resized_image
79
+ else:
80
+ return resized_image, scale
81
+
82
+
83
+ def resize_image_long(image, dst_size, return_scale=False, interpolation='bilinear'):
84
+ """Resize an image so that the length of longer side is dst_size while
85
+ preserving the original aspect ratio.
86
+
87
+ References:
88
+ `resize_max` in `https://github.com/pjreddie/darknet/blob/master/src/image.c`
89
+ """
90
+ assert khandy.is_numpy_image(image)
91
+ src_height, src_width = image.shape[:2]
92
+ scale = min(dst_size / src_width, dst_size / src_height)
93
+ dst_width = int(round(scale * src_width))
94
+ dst_height = int(round(scale * src_height))
95
+
96
+ resized_image = cv2.resize(image, (dst_width, dst_height),
97
+ interpolation=interp_codes[interpolation])
98
+ if not return_scale:
99
+ return resized_image
100
+ else:
101
+ return resized_image, scale
102
+
103
+
104
+ def resize_image_to_range(image, min_length, max_length, return_scale=False, interpolation='bilinear'):
105
+ """Resizes an image so its dimensions are within the provided value.
106
+
107
+ Rescale the shortest side of the image up to `min_length` pixels
108
+ while keeping the largest side below `max_length` pixels without
109
+ changing the aspect ratio. Often used in object detection (e.g. RCNN and SSH.)
110
+
111
+ The output size can be described by two cases:
112
+ 1. If the image can be rescaled so its shortest side is equal to the
113
+ `min_length` without the other side exceeding `max_length`, then do so.
114
+ 2. Otherwise, resize so the longest side is equal to `max_length`.
115
+
116
+ Returns:
117
+ resized_image: resized image so that
118
+ min(dst_height, dst_width) == min_length or
119
+ max(dst_height, dst_width) == max_length.
120
+
121
+ References:
122
+ `resize_to_range` in `models-master/research/object_detection/core/preprocessor.py`
123
+ `prep_im_for_blob` in `py-faster-rcnn-master/lib/utils/blob.py`
124
+ mmcv.imrescale
125
+ """
126
+ assert khandy.is_numpy_image(image)
127
+ assert min_length < max_length
128
+ src_height, src_width = image.shape[:2]
129
+
130
+ min_side_length = min(src_width, src_height)
131
+ max_side_length = max(src_width, src_height)
132
+ scale = min_length / min_side_length
133
+ if round(scale * max_side_length) > max_length:
134
+ scale = max_length / max_side_length
135
+ dst_width = int(round(scale * src_width))
136
+ dst_height = int(round(scale * src_height))
137
+
138
+ resized_image = cv2.resize(image, (dst_width, dst_height),
139
+ interpolation=interp_codes[interpolation])
140
+ if not return_scale:
141
+ return resized_image
142
+ else:
143
+ return resized_image, scale
144
+
145
+
146
+ def letterbox_image(image, dst_width, dst_height, border_value=0,
147
+ return_scale=False, interpolation='bilinear'):
148
+ """Resize an image preserving the original aspect ratio using padding.
149
+
150
+ References:
151
+ `letterbox_image` in `https://github.com/pjreddie/darknet/blob/master/src/image.c`
152
+ """
153
+ assert khandy.is_numpy_image(image)
154
+ src_height, src_width = image.shape[:2]
155
+ scale = min(dst_width / src_width, dst_height / src_height)
156
+ resize_w = int(round(scale * src_width))
157
+ resize_h = int(round(scale * src_height))
158
+
159
+ resized_image = cv2.resize(image, (resize_w, resize_h),
160
+ interpolation=interp_codes[interpolation])
161
+ pad_top = (dst_height - resize_h) // 2
162
+ pad_bottom = (dst_height - resize_h) - pad_top
163
+ pad_left = (dst_width - resize_w) // 2
164
+ pad_right = (dst_width - resize_w) - pad_left
165
+ padded_image = cv2.copyMakeBorder(resized_image, pad_top, pad_bottom, pad_left, pad_right,
166
+ cv2.BORDER_CONSTANT, value=border_value)
167
+ if not return_scale:
168
+ return padded_image
169
+ else:
170
+ return padded_image, scale, pad_left, pad_top
171
+
172
+
173
+ def letterbox_resize_image(image, dst_width, dst_height, border_value=0,
174
+ return_scale=False, interpolation='bilinear'):
175
+ warnings.warn('letterbox_resize_image will be deprecated, use letterbox_image instead!')
176
+ return letterbox_image(image, dst_width, dst_height, border_value,
177
+ return_scale, interpolation)
khandy/image/rotate.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import khandy
3
+ import numpy as np
4
+
5
+
6
+ def get_2d_rotation_matrix(angle, cx=0, cy=0, scale=1,
7
+ degrees=True, dtype=np.float32):
8
+ """
9
+ References:
10
+ `cv2.getRotationMatrix2D` in OpenCV
11
+ """
12
+ if degrees:
13
+ angle = np.deg2rad(angle)
14
+ c = scale * np.cos(angle)
15
+ s = scale * np.sin(angle)
16
+
17
+ tx = cx - cx * c + cy * s
18
+ ty = cy - cx * s - cy * c
19
+ return np.array([[ c, -s, tx],
20
+ [ s, c, ty],
21
+ [ 0, 0, 1]], dtype=dtype)
22
+
23
+
24
+ def rotate_image(image, angle, scale=1.0, center=None,
25
+ degrees=True, border_value=0, auto_bound=False):
26
+ """Rotate an image.
27
+
28
+ Args:
29
+ image : ndarray
30
+ Image to be rotated.
31
+ angle : float
32
+ Rotation angle in degrees, positive values mean clockwise rotation.
33
+ center : tuple
34
+ Center of the rotation in the source image, by default
35
+ it is the center of the image.
36
+ scale : float
37
+ Isotropic scale factor.
38
+ degrees : bool
39
+ border_value : int
40
+ Border value.
41
+ auto_bound : bool
42
+ Whether to adjust the image size to cover the whole rotated image.
43
+
44
+ Returns:
45
+ ndarray: The rotated image.
46
+
47
+ References:
48
+ mmcv.imrotate
49
+ """
50
+ assert khandy.is_numpy_image(image)
51
+ image_height, image_width = image.shape[:2]
52
+ if auto_bound:
53
+ center = None
54
+ if center is None:
55
+ center = ((image_width - 1) * 0.5, (image_height - 1) * 0.5)
56
+ assert isinstance(center, tuple)
57
+
58
+ rotation_matrix = get_2d_rotation_matrix(angle, center[0], center[1], scale, degrees)
59
+ if auto_bound:
60
+ scale_cos = np.abs(rotation_matrix[0, 0])
61
+ scale_sin = np.abs(rotation_matrix[0, 1])
62
+ new_width = image_width * scale_cos + image_height * scale_sin
63
+ new_height = image_width * scale_sin + image_height * scale_cos
64
+
65
+ rotation_matrix[0, 2] += (new_width - image_width) * 0.5
66
+ rotation_matrix[1, 2] += (new_height - image_height) * 0.5
67
+
68
+ image_width = int(np.round(new_width))
69
+ image_height = int(np.round(new_height))
70
+ rotated = cv2.warpAffine(image, rotation_matrix[:2,:], (image_width, image_height),
71
+ borderValue=border_value)
72
+ return rotated
khandy/image/translate.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+
3
+ import khandy
4
+
5
+
6
+ def translate_image(image, x_shift, y_shift, border_value=0):
7
+ """Translate an image.
8
+
9
+ Args:
10
+ image (ndarray): Image to be translated with format (h, w) or (h, w, c).
11
+ x_shift (int): The offset used for translate in horizontal
12
+ direction. right is the positive direction.
13
+ y_shift (int): The offset used for translate in vertical
14
+ direction. down is the positive direction.
15
+ border_value (int | tuple[int]): Value used in case of a
16
+ constant border.
17
+
18
+ Returns:
19
+ ndarray: The translated image.
20
+
21
+ See Also:
22
+ crop_or_pad
23
+ """
24
+ assert khandy.is_numpy_image(image)
25
+ assert isinstance(x_shift, numbers.Integral)
26
+ assert isinstance(y_shift, numbers.Integral)
27
+ image_height, image_width = image.shape[:2]
28
+ channels = 1 if image.ndim == 2 else image.shape[2]
29
+
30
+ if isinstance(border_value, (tuple, list)):
31
+ assert len(border_value) == channels, \
32
+ 'Expected the num of elements in tuple equals the channels ' \
33
+ 'of input image. Found {} vs {}'.format(
34
+ len(border_value), channels)
35
+ else:
36
+ border_value = (border_value,) * channels
37
+ dst_image = khandy.create_solid_color_image(
38
+ image_height, image_width, border_value, dtype=image.dtype)
39
+
40
+ if (abs(x_shift) >= image_width) or (abs(y_shift) >= image_height):
41
+ return dst_image
42
+
43
+ src_x_begin = max(-x_shift, 0)
44
+ src_x_end = min(image_width - x_shift, image_width)
45
+ dst_x_begin = max(x_shift, 0)
46
+ dst_x_end = min(image_width + x_shift, image_width)
47
+
48
+ src_y_begin = max(-y_shift, 0)
49
+ src_y_end = min(image_height - y_shift, image_height)
50
+ dst_y_begin = max(y_shift, 0)
51
+ dst_y_end = min(image_height + y_shift, image_height)
52
+
53
+ dst_image[dst_y_begin:dst_y_end, dst_x_begin:dst_x_end] = \
54
+ image[src_y_begin:src_y_end, src_x_begin:src_x_end]
55
+ return dst_image
56
+
57
+
khandy/label/detect.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import json
4
+ import dataclasses
5
+ from dataclasses import dataclass, field
6
+ from collections import OrderedDict
7
+ from typing import Optional, List
8
+ import xml.etree.ElementTree as ET
9
+
10
+ import khandy
11
+ import lxml
12
+ import lxml.builder
13
+ import numpy as np
14
+
15
+
16
+ __all__ = ['DetectIrObject', 'DetectIrRecord', 'load_detect',
17
+ 'save_detect', 'convert_detect', 'replace_detect_label',
18
+ 'load_coco_class_names']
19
+
20
+
21
+ @dataclass
22
+ class DetectIrObject:
23
+ """Intermediate Representation Format of Object
24
+ """
25
+ label: str
26
+ x_min: float
27
+ y_min: float
28
+ x_max: float
29
+ y_max: float
30
+
31
+
32
+ @dataclass
33
+ class DetectIrRecord:
34
+ """Intermediate Representation Format of Record
35
+ """
36
+ filename: str
37
+ width: int
38
+ height: int
39
+ objects: List[DetectIrObject] = field(default_factory=list)
40
+
41
+
42
+ @dataclass
43
+ class PascalVocSource:
44
+ database: str = ''
45
+ annotation: str = ''
46
+ image: str = ''
47
+
48
+
49
+ @dataclass
50
+ class PascalVocSize:
51
+ height: int
52
+ width: int
53
+ depth: int
54
+
55
+
56
+ @dataclass
57
+ class PascalVocBndbox:
58
+ xmin: float
59
+ ymin: float
60
+ xmax: float
61
+ ymax: float
62
+
63
+
64
+ @dataclass
65
+ class PascalVocObject:
66
+ name: str
67
+ pose: str = 'Unspecified'
68
+ truncated: int = 0
69
+ difficult: int = 0
70
+ bndbox: Optional[PascalVocBndbox] = None
71
+
72
+
73
+ @dataclass
74
+ class PascalVocRecord:
75
+ folder: str = ''
76
+ filename: str = ''
77
+ path: str = ''
78
+ source: PascalVocSource = PascalVocSource()
79
+ size: Optional[PascalVocSize] = None
80
+ segmented: int = 0
81
+ objects: List[PascalVocObject] = field(default_factory=list)
82
+
83
+
84
+ class PascalVocHandler:
85
+ @staticmethod
86
+ def load(filename, **kwargs) -> PascalVocRecord:
87
+ pascal_voc_record = PascalVocRecord()
88
+
89
+ xml_tree = ET.parse(filename)
90
+ pascal_voc_record.folder = xml_tree.find('folder').text
91
+ pascal_voc_record.filename = xml_tree.find('filename').text
92
+ pascal_voc_record.path = xml_tree.find('path').text
93
+ pascal_voc_record.segmented = xml_tree.find('segmented').text
94
+
95
+ source_tag = xml_tree.find('source')
96
+ pascal_voc_record.source = PascalVocSource(
97
+ database=source_tag.find('database').text,
98
+ # annotation=source_tag.find('annotation').text,
99
+ # image=source_tag.find('image').text
100
+ )
101
+
102
+ size_tag = xml_tree.find('size')
103
+ pascal_voc_record.size = PascalVocSize(
104
+ width=int(size_tag.find('width').text),
105
+ height=int(size_tag.find('height').text),
106
+ depth=int(size_tag.find('depth').text)
107
+ )
108
+
109
+ object_tags = xml_tree.findall('object')
110
+ for index, object_tag in enumerate(object_tags):
111
+ bndbox_tag = object_tag.find('bndbox')
112
+ bndbox = PascalVocBndbox(
113
+ xmin=float(bndbox_tag.find('xmin').text) - 1,
114
+ ymin=float(bndbox_tag.find('ymin').text) - 1,
115
+ xmax=float(bndbox_tag.find('xmax').text) - 1,
116
+ ymax=float(bndbox_tag.find('ymax').text) - 1
117
+ )
118
+ pascal_voc_object = PascalVocObject(
119
+ name=object_tag.find('name').text,
120
+ pose=object_tag.find('pose').text,
121
+ truncated=object_tag.find('truncated').text,
122
+ difficult=object_tag.find('difficult').text,
123
+ bndbox=bndbox
124
+ )
125
+ pascal_voc_record.objects.append(pascal_voc_object)
126
+ return pascal_voc_record
127
+
128
+ @staticmethod
129
+ def save(filename, pascal_voc_record: PascalVocRecord):
130
+ maker = lxml.builder.ElementMaker()
131
+ xml = maker.annotation(
132
+ maker.folder(pascal_voc_record.folder),
133
+ maker.filename(pascal_voc_record.filename),
134
+ maker.path(pascal_voc_record.path),
135
+ maker.source(
136
+ maker.database(pascal_voc_record.source.database),
137
+ ),
138
+ maker.size(
139
+ maker.width(str(pascal_voc_record.size.width)),
140
+ maker.height(str(pascal_voc_record.size.height)),
141
+ maker.depth(str(pascal_voc_record.size.depth)),
142
+ ),
143
+ maker.segmented(str(pascal_voc_record.segmented)),
144
+ )
145
+
146
+ for pascal_voc_object in pascal_voc_record.objects:
147
+ object_tag = maker.object(
148
+ maker.name(pascal_voc_object.name),
149
+ maker.pose(pascal_voc_object.pose),
150
+ maker.truncated(str(pascal_voc_object.truncated)),
151
+ maker.difficult(str(pascal_voc_object.difficult)),
152
+ maker.bndbox(
153
+ maker.xmin(str(float(pascal_voc_object.bndbox.xmin))),
154
+ maker.ymin(str(float(pascal_voc_object.bndbox.ymin))),
155
+ maker.xmax(str(float(pascal_voc_object.bndbox.xmax))),
156
+ maker.ymax(str(float(pascal_voc_object.bndbox.ymax))),
157
+ ),
158
+ )
159
+ xml.append(object_tag)
160
+
161
+ if not filename.endswith('.xml'):
162
+ filename = filename + '.xml'
163
+ with open(filename, 'wb') as f:
164
+ f.write(lxml.etree.tostring(xml, pretty_print=True, encoding='utf-8'))
165
+
166
+ @staticmethod
167
+ def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
168
+ ir_record = DetectIrRecord(
169
+ filename=pascal_voc_record.filename,
170
+ width=pascal_voc_record.size.width,
171
+ height=pascal_voc_record.size.height
172
+ )
173
+ for pascal_voc_object in pascal_voc_record.objects:
174
+ ir_object = DetectIrObject(
175
+ label=pascal_voc_object.name,
176
+ x_min=pascal_voc_object.bndbox.xmin,
177
+ y_min=pascal_voc_object.bndbox.ymin,
178
+ x_max=pascal_voc_object.bndbox.xmax,
179
+ y_max=pascal_voc_object.bndbox.ymax
180
+ )
181
+ ir_record.objects.append(ir_object)
182
+ return ir_record
183
+
184
+ @staticmethod
185
+ def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
186
+ pascal_voc_record = PascalVocRecord(
187
+ filename=ir_record.filename,
188
+ size=PascalVocSize(
189
+ width=ir_record.width,
190
+ height=ir_record.height,
191
+ depth=3
192
+ )
193
+ )
194
+ for ir_object in ir_record.objects:
195
+ pascal_voc_object = PascalVocObject(
196
+ name=ir_object.label,
197
+ bndbox=PascalVocBndbox(
198
+ xmin=ir_object.x_min,
199
+ ymin=ir_object.y_min,
200
+ xmax=ir_object.x_max,
201
+ ymax=ir_object.y_max,
202
+ )
203
+ )
204
+ pascal_voc_record.objects.append(pascal_voc_object)
205
+ return pascal_voc_record
206
+
207
+
208
+ class _NumpyEncoder(json.JSONEncoder):
209
+ """ Special json encoder for numpy types """
210
+ def default(self, obj):
211
+ if isinstance(obj, (np.bool_,)):
212
+ return bool(obj)
213
+ elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
214
+ np.int16, np.int32, np.int64, np.uint8,
215
+ np.uint16, np.uint32, np.uint64)):
216
+ return int(obj)
217
+ elif isinstance(obj, (np.float_, np.float16, np.float32,
218
+ np.float64)):
219
+ return float(obj)
220
+ elif isinstance(obj, (np.ndarray,)):
221
+ return obj.tolist()
222
+ return json.JSONEncoder.default(self, obj)
223
+
224
+
225
+ @dataclass
226
+ class LabelmeShape:
227
+ label: str
228
+ points: np.ndarray
229
+ shape_type: str
230
+ flags: dict = field(default_factory=dict)
231
+ group_id: Optional[int] = None
232
+
233
+ def __post_init__(self):
234
+ self.points = np.asarray(self.points)
235
+
236
+
237
+ @dataclass
238
+ class LabelmeRecord:
239
+ version: str = '4.5.6'
240
+ flags: dict = field(default_factory=dict)
241
+ shapes: List[LabelmeShape] = field(default_factory=list)
242
+ imagePath: Optional[str] = None
243
+ imageData: Optional[str] = None
244
+ imageHeight: Optional[int] = None
245
+ imageWidth: Optional[int] = None
246
+
247
+ def __post_init__(self):
248
+ for k, shape in enumerate(self.shapes):
249
+ self.shapes[k] = LabelmeShape(**shape)
250
+
251
+
252
+ class LabelmeHandler:
253
+ @staticmethod
254
+ def load(filename, **kwargs) -> LabelmeRecord:
255
+ json_content = khandy.load_json(filename)
256
+ return LabelmeRecord(**json_content)
257
+
258
+ @staticmethod
259
+ def save(filename, labelme_record: LabelmeRecord):
260
+ json_content = dataclasses.asdict(labelme_record)
261
+ khandy.save_json(filename, json_content, cls=_NumpyEncoder)
262
+
263
+ @staticmethod
264
+ def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
265
+ ir_record = DetectIrRecord(
266
+ filename=labelme_record.imagePath,
267
+ width=labelme_record.imageWidth,
268
+ height=labelme_record.imageHeight
269
+ )
270
+ for labelme_shape in labelme_record.shapes:
271
+ if labelme_shape.shape_type != 'rectangle':
272
+ continue
273
+ ir_object = DetectIrObject(
274
+ label=labelme_shape.label,
275
+ x_min=labelme_shape.points[0][0],
276
+ y_min=labelme_shape.points[0][1],
277
+ x_max=labelme_shape.points[1][0],
278
+ y_max=labelme_shape.points[1][1],
279
+ )
280
+ ir_record.objects.append(ir_object)
281
+ return ir_record
282
+
283
+ @staticmethod
284
+ def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
285
+ labelme_record = LabelmeRecord(
286
+ imagePath=ir_record.filename,
287
+ imageWidth=ir_record.width,
288
+ imageHeight=ir_record.height
289
+ )
290
+ for ir_object in ir_record.objects:
291
+ labelme_shape = LabelmeShape(
292
+ label=ir_object.label,
293
+ shape_type='rectangle',
294
+ points=[[ir_object.x_min, ir_object.y_min],
295
+ [ir_object.x_max, ir_object.y_max]]
296
+ )
297
+ labelme_record.shapes.append(labelme_shape)
298
+ return labelme_record
299
+
300
+
301
+ @dataclass
302
+ class YoloObject:
303
+ label: str
304
+ x_center: float
305
+ y_center: float
306
+ width: float
307
+ height: float
308
+
309
+
310
+ @dataclass
311
+ class YoloRecord:
312
+ filename: Optional[str] = None
313
+ width: Optional[int] = None
314
+ height: Optional[int] = None
315
+ objects: List[YoloObject] = field(default_factory=list)
316
+
317
+
318
+ class YoloHandler:
319
+ @staticmethod
320
+ def load(filename, **kwargs) -> YoloRecord:
321
+ assert 'image_filename' in kwargs
322
+ assert 'width' in kwargs and 'height' in kwargs
323
+
324
+ records = khandy.load_list(filename)
325
+ yolo_record = YoloRecord(
326
+ filename=kwargs.get('image_filename'),
327
+ width=kwargs.get('width'),
328
+ height=kwargs.get('height'))
329
+ for record in records:
330
+ record_parts = record.split()
331
+ yolo_record.objects.append(YoloObject(
332
+ label=record_parts[0],
333
+ x_center=float(record_parts[1]),
334
+ y_center=float(record_parts[2]),
335
+ width=float(record_parts[3]),
336
+ height=float(record_parts[4]),
337
+ ))
338
+ return yolo_record
339
+
340
+ @staticmethod
341
+ def save(filename, yolo_record: YoloRecord):
342
+ records = []
343
+ for object in yolo_record.objects:
344
+ records.append(f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
345
+ if not filename.endswith('.txt'):
346
+ filename = filename + '.txt'
347
+ khandy.save_list(filename, records)
348
+
349
+ @staticmethod
350
+ def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
351
+ ir_record = DetectIrRecord(
352
+ filename=yolo_record.filename,
353
+ width=yolo_record.width,
354
+ height=yolo_record.height
355
+ )
356
+ for yolo_object in yolo_record.objects:
357
+ x_min = (yolo_object.x_center - 0.5 * yolo_object.width) * yolo_record.width
358
+ y_min = (yolo_object.y_center - 0.5 * yolo_object.height) * yolo_record.height
359
+ x_max = (yolo_object.x_center + 0.5 * yolo_object.width) * yolo_record.width
360
+ y_max = (yolo_object.y_center + 0.5 * yolo_object.height) * yolo_record.height
361
+ ir_object = DetectIrObject(
362
+ label=yolo_object.label,
363
+ x_min=x_min,
364
+ y_min=y_min,
365
+ x_max=x_max,
366
+ y_max=y_max
367
+ )
368
+ ir_record.objects.append(ir_object)
369
+ return ir_record
370
+
371
+ @staticmethod
372
+ def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
373
+ yolo_record = YoloRecord(
374
+ filename=ir_record.filename,
375
+ width=ir_record.width,
376
+ height=ir_record.height
377
+ )
378
+ for ir_object in ir_record.objects:
379
+ x_center = (ir_object.x_max + ir_object.x_min) / (2 * ir_record.width)
380
+ y_center = (ir_object.y_max + ir_object.y_min) / (2 * ir_record.height)
381
+ width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
382
+ height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
383
+ yolo_object = YoloObject(
384
+ label=ir_object.label,
385
+ x_center=x_center,
386
+ y_center=y_center,
387
+ width=width,
388
+ height=height,
389
+ )
390
+ yolo_record.objects.append(yolo_object)
391
+ return yolo_record
392
+
393
+
394
+ @dataclass
395
+ class CocoObject:
396
+ label: str
397
+ x_min: float
398
+ y_min: float
399
+ width: float
400
+ height: float
401
+
402
+
403
+ @dataclass
404
+ class CocoRecord:
405
+ filename: str
406
+ width: int
407
+ height: int
408
+ objects: List[CocoObject] = field(default_factory=list)
409
+
410
+
411
+ class CocoHandler:
412
+ @staticmethod
413
+ def load(filename, **kwargs) -> List[CocoRecord]:
414
+ json_data = khandy.load_json(filename)
415
+
416
+ images = json_data['images']
417
+ annotations = json_data['annotations']
418
+ categories = json_data['categories']
419
+
420
+ label_map = {}
421
+ for cat_item in categories:
422
+ label_map[cat_item['id']] = cat_item['name']
423
+
424
+ coco_records = OrderedDict()
425
+ for image_item in images:
426
+ coco_records[image_item['id']] = CocoRecord(
427
+ filename=image_item['file_name'],
428
+ width=image_item['width'],
429
+ height=image_item['height'],
430
+ objects=[])
431
+
432
+ for annotation_item in annotations:
433
+ coco_object = CocoObject(
434
+ label=label_map[annotation_item['category_id']],
435
+ x_min=annotation_item['bbox'][0],
436
+ y_min=annotation_item['bbox'][1],
437
+ width=annotation_item['bbox'][2],
438
+ height=annotation_item['bbox'][3])
439
+ coco_records[annotation_item['image_id']].objects.append(coco_object)
440
+ return list(coco_records.values())
441
+
442
+ @staticmethod
443
+ def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
444
+ ir_record = DetectIrRecord(
445
+ filename=coco_record.filename,
446
+ width=coco_record.width,
447
+ height=coco_record.height,
448
+ )
449
+ for coco_object in coco_record.objects:
450
+ ir_object = DetectIrObject(
451
+ label=coco_object.label,
452
+ x_min=coco_object.x_min,
453
+ y_min=coco_object.y_min,
454
+ x_max=coco_object.x_min + coco_object.width,
455
+ y_max=coco_object.y_min + coco_object.height
456
+ )
457
+ ir_record.objects.append(ir_object)
458
+ return ir_record
459
+
460
+ @staticmethod
461
+ def from_ir(ir_record: DetectIrRecord) -> CocoRecord:
462
+ coco_record = CocoRecord(
463
+ filename=ir_record.filename,
464
+ width=ir_record.width,
465
+ height=ir_record.height
466
+ )
467
+ for ir_object in ir_record.objects:
468
+ coco_object = CocoObject(
469
+ label=ir_object.label,
470
+ x_min=ir_object.x_min,
471
+ y_min=ir_object.y_min,
472
+ width=ir_object.x_max - ir_object.x_min,
473
+ height=ir_object.y_max - ir_object.y_min
474
+ )
475
+ coco_record.objects.append(coco_object)
476
+ return coco_record
477
+
478
+
479
+ def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
480
+ if fmt == 'labelme':
481
+ labelme_record = LabelmeHandler.load(filename, **kwargs)
482
+ ir_record = LabelmeHandler.to_ir(labelme_record)
483
+ elif fmt == 'yolo':
484
+ yolo_record = YoloHandler.load(filename, **kwargs)
485
+ ir_record = YoloHandler.to_ir(yolo_record)
486
+ elif fmt in ('voc', 'pascal', 'pascal_voc'):
487
+ pascal_voc_record = PascalVocHandler.load(filename, **kwargs)
488
+ ir_record = PascalVocHandler.to_ir(pascal_voc_record)
489
+ elif fmt == 'coco':
490
+ coco_records = CocoHandler.load(filename, **kwargs)
491
+ ir_record = [CocoHandler.to_ir(coco_record) for coco_record in coco_records]
492
+ else:
493
+ raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
494
+ return ir_record
495
+
496
+
497
+ def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
498
+ os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
499
+ if out_fmt == 'labelme':
500
+ labelme_record = LabelmeHandler.from_ir(ir_record)
501
+ LabelmeHandler.save(filename, labelme_record)
502
+ elif out_fmt == 'yolo':
503
+ yolo_record = YoloHandler.from_ir(ir_record)
504
+ YoloHandler.save(filename, yolo_record)
505
+ elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
506
+ pascal_voc_record = PascalVocHandler.from_ir(ir_record)
507
+ PascalVocHandler.save(filename, pascal_voc_record)
508
+ elif out_fmt == 'coco':
509
+ raise ValueError("Unsupported for `coco` now!")
510
+ else:
511
+ raise ValueError(f"Unsupported detect label fmt. Got {out_fmt}")
512
+
513
+
514
+ def _get_format(record):
515
+ if isinstance(record, LabelmeRecord):
516
+ return ('labelme',)
517
+ elif isinstance(record, YoloRecord):
518
+ return ('yolo',)
519
+ elif isinstance(record, PascalVocRecord):
520
+ return ('voc', 'pascal', 'pascal_voc')
521
+ elif isinstance(record, CocoRecord):
522
+ return ('coco',)
523
+ elif isinstance(record, DetectIrRecord):
524
+ return ('ir', 'detect_ir')
525
+ else:
526
+ return ()
527
+
528
+
529
+ def convert_detect(record, out_fmt):
530
+ allowed_fmts = ('labelme', 'yolo', 'voc', 'coco', 'pascal', 'pascal_voc', 'ir', 'detect_ir')
531
+ if out_fmt not in allowed_fmts:
532
+ raise ValueError("Unsupported label format conversions for given out_fmt")
533
+ if out_fmt in _get_format(record):
534
+ return record
535
+
536
+ if isinstance(record, LabelmeRecord):
537
+ ir_record = LabelmeHandler.to_ir(record)
538
+ elif isinstance(record, YoloRecord):
539
+ ir_record = YoloHandler.to_ir(record)
540
+ elif isinstance(record, PascalVocRecord):
541
+ ir_record = PascalVocHandler.to_ir(record)
542
+ elif isinstance(record, CocoRecord):
543
+ ir_record = CocoHandler.to_ir(record)
544
+ elif isinstance(record, DetectIrRecord):
545
+ ir_record = record
546
+ else:
547
+ raise TypeError('Unsupported type for record')
548
+
549
+ if out_fmt in ('ir', 'detect_ir'):
550
+ dst_record = ir_record
551
+ elif out_fmt == 'labelme':
552
+ dst_record = LabelmeHandler.from_ir(ir_record)
553
+ elif out_fmt == 'yolo':
554
+ dst_record = YoloHandler.from_ir(ir_record)
555
+ elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
556
+ dst_record = PascalVocHandler.from_ir(ir_record)
557
+ elif out_fmt == 'coco':
558
+ dst_record = CocoHandler.from_ir(ir_record)
559
+ return dst_record
560
+
561
+
562
+ def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
563
+ dst_record = copy.deepcopy(record)
564
+ dst_objects = []
565
+ for ir_object in dst_record.objects:
566
+ if not ignore:
567
+ if ir_object.label in label_map:
568
+ ir_object.label = label_map[ir_object.label]
569
+ dst_objects.append(ir_object)
570
+ else:
571
+ if ir_object.label in label_map:
572
+ ir_object.label = label_map[ir_object.label]
573
+ dst_objects.append(ir_object)
574
+ dst_record.objects = dst_objects
575
+ return dst_record
576
+
577
+
578
+ def load_coco_class_names(filename):
579
+ json_data = khandy.load_json(filename)
580
+ categories = json_data['categories']
581
+ return [cat_item['name'] for cat_item in categories]
582
+
khandy/list_utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import itertools
3
+
4
+
5
+ def to_list(obj):
6
+ if obj is None:
7
+ return None
8
+ elif hasattr(obj, '__iter__') and not isinstance(obj, str):
9
+ try:
10
+ return list(obj)
11
+ except:
12
+ return [obj]
13
+ else:
14
+ return [obj]
15
+
16
+
17
+ def convert_lists_to_record(*list_objs, delimiter=None):
18
+ assert len(list_objs) >= 1, 'list_objs length must >= 1.'
19
+ delimiter = delimiter or ','
20
+
21
+ assert isinstance(list_objs[0], (tuple, list))
22
+ number = len(list_objs[0])
23
+ for item in list_objs[1:]:
24
+ assert isinstance(item, (tuple, list))
25
+ assert len(item) == number, '{} != {}'.format(len(item), number)
26
+
27
+ records = []
28
+ record_list = zip(*list_objs)
29
+ for record in record_list:
30
+ record_str = [str(item) for item in record]
31
+ records.append(delimiter.join(record_str))
32
+ return records
33
+
34
+
35
+ def shuffle_table(*table):
36
+ """
37
+ Notes:
38
+ table can be seen as list of list which have equal items.
39
+ """
40
+ shuffled_list = list(zip(*table))
41
+ random.shuffle(shuffled_list)
42
+ tuple_list = zip(*shuffled_list)
43
+ return [list(item) for item in tuple_list]
44
+
45
+
46
+ def transpose_table(table):
47
+ """
48
+ Notes:
49
+ table can be seen as list of list which have equal items.
50
+ """
51
+ m, n = len(table), len(table[0])
52
+ return [[table[i][j] for i in range(m)] for j in range(n)]
53
+
54
+
55
+ def concat_list(in_list):
56
+ """Concatenate a list of list into a single list.
57
+
58
+ Args:
59
+ in_list (list): The list of list to be merged.
60
+
61
+ Returns:
62
+ list: The concatenated flat list.
63
+
64
+ References:
65
+ mmcv.concat_list
66
+ """
67
+ return list(itertools.chain(*in_list))
68
+
khandy/misc.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import socket
3
+ import logging
4
+ import argparse
5
+ import warnings
6
+ from enum import Enum
7
+
8
+ import requests
9
+
10
+
11
+ def all_of(iterable, pred):
12
+ """Returns whether all elements in the iterable satisfy the predicate.
13
+
14
+ Args:
15
+ iterable (Iterable): An iterable to check.
16
+ pred (callable): A predicate to apply to each element.
17
+
18
+ Returns:
19
+ bool: True if all elements satisfy the predicate, False otherwise.
20
+
21
+ References:
22
+ https://en.cppreference.com/w/cpp/algorithm/all_any_none_of
23
+ """
24
+ return all(pred(element) for element in iterable)
25
+
26
+
27
+ def any_of(iterable, pred):
28
+ """Returns whether any element in the iterable satisfies the predicate.
29
+
30
+ Args:
31
+ iterable (Iterable): An iterable to check.
32
+ pred (callable): A predicate to apply to each element.
33
+
34
+ Returns:
35
+ bool: True if any element satisfies the predicate, False otherwise.
36
+
37
+ References:
38
+ https://en.cppreference.com/w/cpp/algorithm/all_any_none_of
39
+ """
40
+ return any(pred(element) for element in iterable)
41
+
42
+
43
+ def none_of(iterable, pred):
44
+ """Returns whether no elements in the iterable satisfy the predicate.
45
+
46
+ Args:
47
+ iterable (Iterable): An iterable to check.
48
+ pred (callable): A predicate to apply to each element.
49
+
50
+ Returns:
51
+ bool: True if no elements satisfy the predicate, False otherwise.
52
+
53
+ References:
54
+ https://en.cppreference.com/w/cpp/algorithm/all_any_none_of
55
+ """
56
+ return not any(pred(element) for element in iterable)
57
+
58
+
59
+ def print_with_no(obj):
60
+ if hasattr(obj, '__len__'):
61
+ for k, item in enumerate(obj):
62
+ print('[{}/{}] {}'.format(k+1, len(obj), item))
63
+ elif hasattr(obj, '__iter__'):
64
+ for k, item in enumerate(obj):
65
+ print('[{}] {}'.format(k+1, item))
66
+ else:
67
+ print('[1] {}'.format(obj))
68
+
69
+
70
+ def get_file_line_count(filename, encoding='utf-8'):
71
+ line_count = 0
72
+ buffer_size = 1024 * 1024 * 8
73
+ with open(filename, 'r', encoding=encoding) as f:
74
+ while True:
75
+ data = f.read(buffer_size)
76
+ if not data:
77
+ break
78
+ line_count += data.count('\n')
79
+ return line_count
80
+
81
+
82
+ def get_host_ip():
83
+ try:
84
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
85
+ s.connect(('8.8.8.8', 80))
86
+ ip = s.getsockname()[0]
87
+ finally:
88
+ s.close()
89
+ return ip
90
+
91
+
92
+ def set_logger(filename, level=logging.INFO, logger_name=None, formatter=None, with_print=True):
93
+ logger = logging.getLogger(logger_name)
94
+ logger.setLevel(level)
95
+
96
+ if formatter is None:
97
+ formatter = logging.Formatter('%(message)s')
98
+
99
+ # Never mutate (insert/remove elements) the list you're currently iterating on.
100
+ # If you need, make a copy.
101
+ for handler in logger.handlers[:]:
102
+ if isinstance(handler, logging.FileHandler):
103
+ logger.removeHandler(handler)
104
+ # FileHandler is subclass of StreamHandler, so isinstance(handler,
105
+ # logging.StreamHandler) is True even if handler is FileHandler.
106
+ # if (type(handler) == logging.StreamHandler) and (handler.stream == sys.stderr):
107
+ elif type(handler) == logging.StreamHandler:
108
+ logger.removeHandler(handler)
109
+
110
+ file_handler = logging.FileHandler(filename, encoding='utf-8')
111
+ file_handler.setFormatter(formatter)
112
+ logger.addHandler(file_handler)
113
+ if with_print:
114
+ console_handler = logging.StreamHandler()
115
+ console_handler.setFormatter(formatter)
116
+ logger.addHandler(console_handler)
117
+ return logger
118
+
119
+
120
+ def print_arguments(args):
121
+ assert isinstance(args, argparse.Namespace)
122
+ arg_list = sorted(vars(args).items())
123
+ for key, value in arg_list:
124
+ print('{}: {}'.format(key, value))
125
+
126
+
127
+ def save_arguments(filename, args, sort=True):
128
+ assert isinstance(args, argparse.Namespace)
129
+ args = vars(args)
130
+ with open(filename, 'w') as f:
131
+ json.dump(args, f, indent=4, sort_keys=sort)
132
+
133
+
134
+ class DownloadStatusCode(Enum):
135
+ FILE_SIZE_TOO_LARGE = (-100, 'the size of file from url is too large')
136
+ FILE_SIZE_TOO_SMALL = (-101, 'the size of file from url is too small')
137
+ FILE_SIZE_IS_ZERO = (-102, 'the size of file from url is zero')
138
+ URL_IS_NOT_IMAGE = (-103, 'URL is not an image')
139
+
140
+ @property
141
+ def code(self):
142
+ return self.value[0]
143
+
144
+ @property
145
+ def message(self):
146
+ return self.value[1]
147
+
148
+
149
+ class DownloadError(Exception):
150
+ def __init__(self, status_code: DownloadStatusCode, extra_str: str=None):
151
+ self.name = status_code.name
152
+ self.code = status_code.code
153
+ if extra_str is None:
154
+ self.message = status_code.message
155
+ else:
156
+ self.message = f'{status_code.message}: {extra_str}'
157
+ Exception.__init__(self)
158
+
159
+ def __repr__(self):
160
+ return f'[{self.__class__.__name__} {self.code}] {self.message}'
161
+
162
+ __str__ = __repr__
163
+
164
+
165
+ def download_image(image_url, min_filesize=0, max_filesize=100*1024*1024,
166
+ params=None, **kwargs) -> bytes:
167
+ """
168
+ References:
169
+ https://httpwg.org/specs/rfc9110.html#field.content-length
170
+ https://requests.readthedocs.io/en/latest/user/advanced/#body-content-workflow
171
+ """
172
+ stream = kwargs.pop('stream', True)
173
+
174
+ with requests.get(image_url, stream=stream, params=params, **kwargs) as response:
175
+ response.raise_for_status()
176
+
177
+ content_type = response.headers.get('content-type')
178
+ if content_type is None:
179
+ warnings.warn('No Content-Type!')
180
+ else:
181
+ if not content_type.startswith(('image/', 'application/octet-stream')):
182
+ raise DownloadError(DownloadStatusCode.URL_IS_NOT_IMAGE)
183
+
184
+ # when Transfer-Encoding == chunked, Content-Length does not exist.
185
+ content_length = response.headers.get('content-length')
186
+ if content_length is None:
187
+ warnings.warn('No Content-Length!')
188
+ else:
189
+ content_length = int(content_length)
190
+ if content_length > max_filesize:
191
+ raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
192
+ if content_length < min_filesize:
193
+ raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
194
+
195
+ filesize = 0
196
+ chunks = []
197
+ for chunk in response.iter_content(chunk_size=10*1024):
198
+ chunks.append(chunk)
199
+ filesize += len(chunk)
200
+ if filesize > max_filesize:
201
+ raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
202
+ if filesize < min_filesize:
203
+ raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
204
+ image_bytes = b''.join(chunks)
205
+
206
+ return image_bytes
207
+
208
+
209
+ def download_file(url, min_filesize=0, max_filesize=100*1024*1024,
210
+ params=None, **kwargs) -> bytes:
211
+ """
212
+ References:
213
+ https://httpwg.org/specs/rfc9110.html#field.content-length
214
+ https://requests.readthedocs.io/en/latest/user/advanced/#body-content-workflow
215
+ """
216
+ stream = kwargs.pop('stream', True)
217
+
218
+ with requests.get(url, stream=stream, params=params, **kwargs) as response:
219
+ response.raise_for_status()
220
+
221
+ # when Transfer-Encoding == chunked, Content-Length does not exist.
222
+ content_length = response.headers.get('content-length')
223
+ if content_length is None:
224
+ warnings.warn('No Content-Length!')
225
+ else:
226
+ content_length = int(content_length)
227
+ if content_length > max_filesize:
228
+ raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
229
+ if content_length < min_filesize:
230
+ raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
231
+
232
+ filesize = 0
233
+ chunks = []
234
+ for chunk in response.iter_content(chunk_size=10*1024):
235
+ chunks.append(chunk)
236
+ filesize += len(chunk)
237
+ if filesize > max_filesize:
238
+ raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
239
+ if filesize < min_filesize:
240
+ raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
241
+ file_bytes = b''.join(chunks)
242
+
243
+ return file_bytes
244
+
245
+
khandy/numpy_utils.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def sigmoid(x):
5
+ return 1. / (1 + np.exp(-x))
6
+
7
+
8
+ def softmax(x, axis=-1, copy=True):
9
+ """
10
+ Args:
11
+ copy: Copy x or not.
12
+
13
+ Referneces:
14
+ `from sklearn.utils.extmath import softmax`
15
+ """
16
+ if copy:
17
+ x = np.copy(x)
18
+ max_val = np.max(x, axis=axis, keepdims=True)
19
+ x -= max_val
20
+ np.exp(x, x)
21
+ sum_exp = np.sum(x, axis=axis, keepdims=True)
22
+ x /= sum_exp
23
+ return x
24
+
25
+
26
+ def log_sum_exp(x, axis=-1, keepdims=False):
27
+ """
28
+ References:
29
+ numpy.logaddexp
30
+ numpy.logaddexp2
31
+ scipy.misc.logsumexp
32
+ """
33
+ max_val = np.max(x, axis=axis, keepdims=True)
34
+ x -= max_val
35
+ np.exp(x, x)
36
+ sum_exp = np.sum(x, axis=axis, keepdims=keepdims)
37
+ lse = np.log(sum_exp, sum_exp)
38
+ if not keepdims:
39
+ max_val = np.squeeze(max_val, axis=axis)
40
+ return max_val + lse
41
+
42
+
43
+ def l2_normalize(x, axis=None, epsilon=1e-12, copy=True):
44
+ """L2 normalize an array along an axis.
45
+
46
+ Args:
47
+ x : array_like of floats
48
+ Input data.
49
+ axis : None or int or tuple of ints, optional
50
+ Axis or axes along which to operate.
51
+ epsilon: float, optional
52
+ A small value such as to avoid division by zero.
53
+ copy : bool, optional
54
+ Copy x or not.
55
+ """
56
+ if copy:
57
+ x = np.copy(x)
58
+ x /= np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), epsilon)
59
+ return x
60
+
61
+
62
+ def minmax_normalize(x, axis=None, epsilon=1e-12, copy=True):
63
+ """minmax normalize an array along a given axis.
64
+
65
+ Args:
66
+ x : array_like of floats
67
+ Input data.
68
+ axis : None or int or tuple of ints, optional
69
+ Axis or axes along which to operate.
70
+ epsilon: float, optional
71
+ A small value such as to avoid division by zero.
72
+ copy : bool, optional
73
+ Copy x or not.
74
+ """
75
+ if copy:
76
+ x = np.copy(x)
77
+
78
+ minval = np.min(x, axis=axis, keepdims=True)
79
+ maxval = np.max(x, axis=axis, keepdims=True)
80
+ maxval -= minval
81
+ maxval = np.maximum(maxval, epsilon)
82
+
83
+ x -= minval
84
+ x /= maxval
85
+ return x
86
+
87
+
88
+ def zscore_normalize(x, mean=None, std=None, axis=None, epsilon=1e-12, copy=True):
89
+ """z-score normalize an array along a given axis.
90
+
91
+ Args:
92
+ x : array_like of floats
93
+ Input data.
94
+ mean: array_like of floats, optional
95
+ mean for z-score
96
+ std: array_like of floats, optional
97
+ std for z-score
98
+ axis : None or int or tuple of ints, optional
99
+ Axis or axes along which to operate.
100
+ epsilon: float, optional
101
+ A small value such as to avoid division by zero.
102
+ copy : bool, optional
103
+ Copy x or not.
104
+ """
105
+ if copy:
106
+ x = np.copy(x)
107
+ if mean is None:
108
+ mean = np.mean(x, axis=axis, keepdims=True)
109
+ if std is None:
110
+ std = np.std(x, axis=axis, keepdims=True)
111
+ mean = np.asarray(mean, dtype=x.dtype)
112
+ std = np.asarray(std, dtype=x.dtype)
113
+ std = np.maximum(std, epsilon)
114
+
115
+ x -= mean
116
+ x /= std
117
+ return x
118
+
119
+
120
+ def get_order_of_magnitude(number):
121
+ number = np.where(number == 0, 1, number)
122
+ oom = np.floor(np.log10(np.abs(number)))
123
+ return oom.astype(np.int32)
124
+
125
+
126
+ def top_k(x, k, axis=-1, largest=True, sorted=True):
127
+ """Finds values and indices of the k largest/smallest
128
+ elements along a given axis.
129
+
130
+ Args:
131
+ x: numpy ndarray
132
+ 1-D or higher with given axis at least k.
133
+ k: int
134
+ Number of top elements to look for along the given axis.
135
+ axis: int
136
+ The axis to sort along.
137
+ largest: bool
138
+ Controls whether to return largest or smallest elements
139
+ sorted: bool
140
+ If true the resulting k elements will be sorted by the values.
141
+
142
+ Returns:
143
+ topk_values:
144
+ The k largest/smallest elements along the given axis.
145
+ topk_indices:
146
+ The indices of the k largest/smallest elements along the given axis.
147
+ """
148
+ if axis is None:
149
+ axis_size = x.size
150
+ else:
151
+ axis_size = x.shape[axis]
152
+ assert 1 <= k <= axis_size
153
+
154
+ x = np.asanyarray(x)
155
+ if largest:
156
+ index_array = np.argpartition(x, axis_size-k, axis=axis)
157
+ topk_indices = np.take(index_array, -np.arange(k)-1, axis=axis)
158
+ else:
159
+ index_array = np.argpartition(x, k-1, axis=axis)
160
+ topk_indices = np.take(index_array, np.arange(k), axis=axis)
161
+ topk_values = np.take_along_axis(x, topk_indices, axis=axis)
162
+ if sorted:
163
+ sorted_indices_in_topk = np.argsort(topk_values, axis=axis)
164
+ if largest:
165
+ sorted_indices_in_topk = np.flip(sorted_indices_in_topk, axis=axis)
166
+ sorted_topk_values = np.take_along_axis(
167
+ topk_values, sorted_indices_in_topk, axis=axis)
168
+ sorted_topk_indices = np.take_along_axis(
169
+ topk_indices, sorted_indices_in_topk, axis=axis)
170
+ return sorted_topk_values, sorted_topk_indices
171
+ return topk_values, topk_indices
172
+
173
+
khandy/points/pts_letterbox.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = ['letterbox_2d_points', 'unletterbox_2d_points']
2
+
3
+
4
+ def letterbox_2d_points(points, scale=1.0, pad_left=0, pad_top=0, copy=True):
5
+ if copy:
6
+ points = points.copy()
7
+ points[..., 0::2] = points[..., 0::2] * scale + pad_left
8
+ points[..., 1::2] = points[..., 1::2] * scale + pad_top
9
+ return points
10
+
11
+
12
+ def unletterbox_2d_points(points, scale=1.0, pad_left=0, pad_top=0, copy=True):
13
+ if copy:
14
+ points = points.copy()
15
+
16
+ points[..., 0::2] = (points[..., 0::2] - pad_left) / scale
17
+ points[..., 1::2] = (points[..., 1::2] - pad_top) / scale
18
+ return points
19
+
khandy/points/pts_transform_scale.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ __all__ = ['scale_2d_points']
4
+
5
+
6
+ def scale_2d_points(points, x_scale=1, y_scale=1, x_center=0, y_center=0, copy=True):
7
+ """Scale 2d points.
8
+
9
+ Args:
10
+ points: (..., 2N)
11
+ x_scale: scale factor in x dimension
12
+ y_scale: scale factor in y dimension
13
+ x_center: scale center in x dimension
14
+ y_center: scale center in y dimension
15
+ """
16
+ points = np.array(points, dtype=np.float32, copy=copy)
17
+ x_scale = np.asarray(x_scale, np.float32)
18
+ y_scale = np.asarray(y_scale, np.float32)
19
+ x_center = np.asarray(x_center, np.float32)
20
+ y_center = np.asarray(y_center, np.float32)
21
+
22
+ x_shift = 1 - x_scale
23
+ y_shift = 1 - y_scale
24
+ x_shift *= x_center
25
+ y_shift *= y_center
26
+
27
+ points[..., 0::2] *= x_scale
28
+ points[..., 1::2] *= y_scale
29
+ points[..., 0::2] += x_shift
30
+ points[..., 1::2] += y_shift
31
+ return points
32
+
33
+
khandy/split_utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+ from collections import Sequence
3
+
4
+ import numpy as np
5
+
6
+
7
+ def split_by_num(x, num_splits, strict=True):
8
+ """
9
+ Args:
10
+ num_splits: an integer indicating the number of splits
11
+
12
+ References:
13
+ numpy.split and numpy.array_split
14
+ """
15
+ # NB: np.ndarray is not Sequence
16
+ assert isinstance(x, (Sequence, np.ndarray))
17
+ assert isinstance(num_splits, numbers.Integral)
18
+
19
+ if strict:
20
+ assert len(x) % num_splits == 0
21
+ split_size = (len(x) + num_splits - 1) // num_splits
22
+ out_list = []
23
+ for i in range(0, len(x), split_size):
24
+ out_list.append(x[i: i + split_size])
25
+ return out_list
26
+
27
+
28
+ def split_by_size(x, sizes):
29
+ """
30
+ References:
31
+ tf.split
32
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/misc.py
33
+ """
34
+ # NB: np.ndarray is not Sequence
35
+ assert isinstance(x, (Sequence, np.ndarray))
36
+ assert isinstance(sizes, (list, tuple))
37
+
38
+ assert sum(sizes) == len(x)
39
+ out_list = []
40
+ start_index = 0
41
+ for size in sizes:
42
+ out_list.append(x[start_index: start_index + size])
43
+ start_index += size
44
+ return out_list
45
+
46
+
47
+ def split_by_slice(x, slices):
48
+ """
49
+ References:
50
+ SliceLayer in Caffe, and numpy.split
51
+ """
52
+ # NB: np.ndarray is not Sequence
53
+ assert isinstance(x, (Sequence, np.ndarray))
54
+ assert isinstance(slices, (list, tuple))
55
+
56
+ out_list = []
57
+ indices = [0] + list(slices) + [len(x)]
58
+ for i in range(len(slices) + 1):
59
+ out_list.append(x[indices[i]: indices[i + 1]])
60
+ return out_list
61
+
62
+
63
+ def split_by_ratio(x, ratios):
64
+ # NB: np.ndarray is not Sequence
65
+ assert isinstance(x, (Sequence, np.ndarray))
66
+ assert isinstance(ratios, (list, tuple))
67
+
68
+ pdf = [k / sum(ratios) for k in ratios]
69
+ cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
70
+ indices = [int(round(len(x) * k)) for k in cdf]
71
+ return [x[indices[i]: indices[i + 1]] for i in range(len(ratios))]
72
+
73
+
khandy/text_utils.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ def strip_content_in_paren(string):
5
+ """
6
+ Notes:
7
+ strip_content_in_paren cannot process nested paren correctly
8
+ """
9
+ return re.sub(r"\([^)]*\)|([^)]*)", "", string)
10
+
11
+
12
+ def is_chinese_char(uchar: str) -> bool:
13
+ """Whether the input char is a Chinese character.
14
+
15
+ Args:
16
+ uchar: input char in unicode
17
+
18
+ References:
19
+ `is_chinese_char` in https://github.com/thunlp/OpenNRE/
20
+ """
21
+ codepoint = ord(uchar)
22
+ if ((0x4E00 <= codepoint <= 0x9FFF) or # CJK Unified Ideographs
23
+ (0x3400 <= codepoint <= 0x4DBF) or # CJK Unified Ideographs Extension A
24
+ (0xF900 <= codepoint <= 0xFAFF) or # CJK Compatibility Ideographs
25
+ (0x20000 <= codepoint <= 0x2A6DF) or # CJK Unified Ideographs Extension B
26
+ (0x2A700 <= codepoint <= 0x2B73F) or
27
+ (0x2B740 <= codepoint <= 0x2B81F) or
28
+ (0x2B820 <= codepoint <= 0x2CEAF) or
29
+ (0x2F800 <= codepoint <= 0x2FA1F)): # CJK Compatibility Supplement
30
+ return True
31
+ return False
32
+
33
+
khandy/time_utils.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import logging
3
+ import numbers
4
+ import datetime
5
+
6
+
7
+ def _to_timestamp(val, multiplier=1, rounded=False):
8
+ if val is None:
9
+ timestamp = time.time()
10
+ elif isinstance(val, numbers.Real):
11
+ timestamp = float(val)
12
+ elif isinstance(val, time.struct_time):
13
+ timestamp = time.mktime(val)
14
+ elif isinstance(val, datetime.datetime):
15
+ timestamp = val.timestamp()
16
+ elif isinstance(val, datetime.date):
17
+ dt = datetime.datetime.combine(val, datetime.time())
18
+ timestamp = dt.timestamp()
19
+ elif isinstance(val, str):
20
+ try:
21
+ # The full format looks like 'YYYY-MM-DD HH:MM:SS.mmmmmm'.
22
+ dt = datetime.datetime.fromisoformat(val)
23
+ timestamp = dt.timestamp()
24
+ except:
25
+ raise TypeError('when argument is str, it should conform to isoformat')
26
+ else:
27
+ raise TypeError('unsupported type!')
28
+ timestamp = timestamp * multiplier
29
+ if rounded:
30
+ # The return value is an integer if ndigits is omitted or None.
31
+ timestamp = round(timestamp)
32
+ return timestamp
33
+
34
+
35
+ def get_timestamp(time_val=None, rounded=True):
36
+ """timestamp in seconds.
37
+ """
38
+ return _to_timestamp(time_val, multiplier=1, rounded=rounded)
39
+
40
+
41
+ def get_timestamp_ms(time_val=None, rounded=True):
42
+ """timestamp in milliseconds.
43
+ """
44
+ return _to_timestamp(time_val, multiplier=1000, rounded=rounded)
45
+
46
+
47
+ def get_timestamp_us(time_val=None, rounded=True):
48
+ """timestamp in microseconds.
49
+ """
50
+ return _to_timestamp(time_val, multiplier=1000000, rounded=rounded)
51
+
52
+
53
+ def get_utc8now() -> datetime.datetime:
54
+ """get current UTC-8 time or Beijing time
55
+ """
56
+ tz = datetime.timezone(datetime.timedelta(hours=8))
57
+ utc8now = datetime.datetime.now(tz)
58
+ return utc8now
59
+
60
+
61
+ class ContextTimer(object):
62
+ """
63
+ References:
64
+ WithTimer in https://github.com/uber/ludwig/blob/master/ludwig/utils/time_utils.py
65
+ """
66
+ def __init__(self, name=None, use_log=False, quiet=False):
67
+ self.use_log = use_log
68
+ self.quiet = quiet
69
+ if name is None:
70
+ self.name = ''
71
+ else:
72
+ self.name = '{}, '.format(name.rstrip())
73
+
74
+ def __enter__(self):
75
+ self.start_time = time.time()
76
+ if not self.quiet:
77
+ self._print_or_log('{}{} starts'.format(self.name, self._now_time_str))
78
+ return self
79
+
80
+ def __exit__(self, exc_type, exc_val, exc_tb):
81
+ if not self.quiet:
82
+ self._print_or_log('{}elapsed_time = {:.5}s'.format(self.name, self.get_eplased_time()))
83
+ self._print_or_log('{}{} ends'.format(self.name, self._now_time_str))
84
+
85
+ @property
86
+ def _now_time_str(self):
87
+ return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
88
+
89
+ def _print_or_log(self, output_str):
90
+ if self.use_log:
91
+ logging.info(output_str)
92
+ else:
93
+ print(output_str)
94
+
95
+ def get_eplased_time(self):
96
+ return time.time() - self.start_time
97
+
98
+ def enter(self):
99
+ """Manually trigger enter"""
100
+ self.__enter__()
101
+
khandy/version.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __version__ = '0.1.8'
2
+
3
+ __all__ = ['__version__']
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ opencv-python>=4.5
2
+ numpy>=1.11.1
3
+ pillow==6.2.1
4
+ lxml
5
+ requests
6
+ onnxruntime