Spaces:
Running
Running
admin
commited on
Commit
•
67a9b5d
1
Parent(s):
89fbbde
sync
Browse files- .gitattributes +2 -0
- .gitignore +7 -0
- README.md +8 -7
- app.py +86 -0
- insectid/__init__.py +2 -0
- insectid/base.py +51 -0
- insectid/detector.py +58 -0
- insectid/identifier.py +76 -0
- khandy/__init__.py +18 -0
- khandy/boxes/__init__.py +13 -0
- khandy/boxes/boxes_and_indices.py +68 -0
- khandy/boxes/boxes_clip.py +34 -0
- khandy/boxes/boxes_coder.py +69 -0
- khandy/boxes/boxes_convert.py +101 -0
- khandy/boxes/boxes_filter.py +113 -0
- khandy/boxes/boxes_overlap.py +166 -0
- khandy/boxes/boxes_transform_flip.py +135 -0
- khandy/boxes/boxes_transform_rotate.py +140 -0
- khandy/boxes/boxes_transform_scale.py +86 -0
- khandy/boxes/boxes_transform_translate.py +136 -0
- khandy/boxes/boxes_utils.py +28 -0
- khandy/dict_utils.py +168 -0
- khandy/draw_utils.py +148 -0
- khandy/feature_utils.py +62 -0
- khandy/file_io_utils.py +87 -0
- khandy/fs_utils.py +375 -0
- khandy/hash_utils.py +25 -0
- khandy/image/__init__.py +10 -0
- khandy/image/align_and_crop.py +60 -0
- khandy/image/crop_or_pad.py +138 -0
- khandy/image/flip.py +72 -0
- khandy/image/image_hash.py +69 -0
- khandy/image/misc.py +329 -0
- khandy/image/resize.py +177 -0
- khandy/image/rotate.py +72 -0
- khandy/image/translate.py +57 -0
- khandy/label/__init__.py +2 -0
- khandy/label/detect.py +594 -0
- khandy/list_utils.py +68 -0
- khandy/misc.py +245 -0
- khandy/numpy_utils.py +173 -0
- khandy/points/__init__.py +2 -0
- khandy/points/pts_letterbox.py +19 -0
- khandy/points/pts_transform_scale.py +33 -0
- khandy/split_utils.py +71 -0
- khandy/text_utils.py +33 -0
- khandy/time_utils.py +101 -0
- khandy/version.py +3 -0
- requirements.txt +7 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
images/Coccinella_septempunctata.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
simsun.ttc filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
_local/
|
3 |
+
*.pyc
|
4 |
+
local_models_*/
|
5 |
+
rename.sh
|
6 |
+
*.onnx
|
7 |
+
simsun.ttc
|
README.md
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license: mit
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
1 |
---
|
2 |
+
title: insecta
|
3 |
+
emoji: 🐞
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.39.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
+
# 特性
|
13 |
+
- 支持 2037 类 (可能是目, 科, 属或种等) 昆虫或其他节肢动物
|
14 |
+
- 模型开源, 持续更新.
|
app.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import khandy
|
3 |
+
import numpy as np
|
4 |
+
import gradio as gr
|
5 |
+
from PIL import Image
|
6 |
+
from modelscope import snapshot_download
|
7 |
+
from insectid import InsectDetector, InsectIdentifier
|
8 |
+
|
9 |
+
MODEL_DIR = snapshot_download("MuGeminorum/insecta", cache_dir="./insectid/__pycache__")
|
10 |
+
|
11 |
+
|
12 |
+
def infer(filename: str):
|
13 |
+
if not filename:
|
14 |
+
None, "请上传图片 Please upload a picture"
|
15 |
+
|
16 |
+
detector = InsectDetector()
|
17 |
+
identifier = InsectIdentifier()
|
18 |
+
image = khandy.imread(filename)
|
19 |
+
if image is None:
|
20 |
+
return None
|
21 |
+
|
22 |
+
if max(image.shape[:2]) > 1280:
|
23 |
+
image = khandy.resize_image_long(image, 1280)
|
24 |
+
|
25 |
+
image_for_draw = image.copy()
|
26 |
+
image_height, image_width = image.shape[:2]
|
27 |
+
boxes, confs, classes = detector.detect(image)
|
28 |
+
text = "未知"
|
29 |
+
for box, _, _ in zip(boxes, confs, classes):
|
30 |
+
box = box.astype(np.int32)
|
31 |
+
box_width = box[2] - box[0] + 1
|
32 |
+
box_height = box[3] - box[1] + 1
|
33 |
+
if box_width < 30 or box_height < 30:
|
34 |
+
continue
|
35 |
+
|
36 |
+
cropped = khandy.crop_or_pad(image, box[0], box[1], box[2], box[3])
|
37 |
+
results = identifier.identify(cropped)
|
38 |
+
print(results[0])
|
39 |
+
prob = results[0]["probability"]
|
40 |
+
if prob >= 0.10:
|
41 |
+
text = "{} {}: {:.2f}%".format(
|
42 |
+
results[0]["chinese_name"],
|
43 |
+
results[0]["latin_name"],
|
44 |
+
100.0 * results[0]["probability"],
|
45 |
+
)
|
46 |
+
|
47 |
+
position = [box[0] + 2, box[1] - 20]
|
48 |
+
position[0] = min(max(position[0], 0), image_width)
|
49 |
+
position[1] = min(max(position[1], 0), image_height)
|
50 |
+
cv2.rectangle(
|
51 |
+
image_for_draw,
|
52 |
+
(box[0], box[1]),
|
53 |
+
(box[2], box[3]),
|
54 |
+
(0, 255, 0),
|
55 |
+
2,
|
56 |
+
)
|
57 |
+
image_for_draw = khandy.draw_text(
|
58 |
+
image_for_draw,
|
59 |
+
text,
|
60 |
+
position,
|
61 |
+
font=f"{MODEL_DIR}/simsun.ttc",
|
62 |
+
font_size=15,
|
63 |
+
)
|
64 |
+
|
65 |
+
outxt = text.split(":")[0] if ":" in text else text
|
66 |
+
return Image.fromarray(image_for_draw[:, :, ::-1], mode="RGB"), outxt
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
iface = gr.Interface(
|
71 |
+
fn=infer,
|
72 |
+
inputs=gr.Image(label="上传昆虫照片 Upload insect picture", type="filepath"),
|
73 |
+
outputs=[
|
74 |
+
gr.Image(label="识别结果 Recognition result"),
|
75 |
+
gr.Textbox(label="最可能的物种 Best match", show_copy_button=True),
|
76 |
+
],
|
77 |
+
title="图像文件格式支持 PNG, JPG, JPEG 和 BMP, 且文件大小不超过 10M<br>Image file format support PNG, JPG, JPEG and BMP, and the file size does not exceed 10M.",
|
78 |
+
examples=[
|
79 |
+
f"{MODEL_DIR}/examples/butterfly.jpg",
|
80 |
+
f"{MODEL_DIR}/examples/beetle.jpg",
|
81 |
+
],
|
82 |
+
allow_flagging="never",
|
83 |
+
cache_examples=False,
|
84 |
+
)
|
85 |
+
|
86 |
+
iface.launch()
|
insectid/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .detector import *
|
2 |
+
from .identifier import *
|
insectid/base.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import onnxruntime
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
class OnnxModel(object):
|
6 |
+
def __init__(self, model_path):
|
7 |
+
sess_options = onnxruntime.SessionOptions()
|
8 |
+
# # Set graph optimization level to ORT_ENABLE_EXTENDED to enable bert optimization.
|
9 |
+
# sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
|
10 |
+
# # Use OpenMP optimizations. Only useful for CPU, has little impact for GPUs.
|
11 |
+
# sess_options.intra_op_num_threads = multiprocessing.cpu_count()
|
12 |
+
onnx_gpu = (onnxruntime.get_device() == 'GPU')
|
13 |
+
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if onnx_gpu else ['CPUExecutionProvider']
|
14 |
+
self.sess = onnxruntime.InferenceSession(model_path, sess_options, providers=providers)
|
15 |
+
self._input_names = [item.name for item in self.sess.get_inputs()]
|
16 |
+
self._output_names = [item.name for item in self.sess.get_outputs()]
|
17 |
+
|
18 |
+
@property
|
19 |
+
def input_names(self):
|
20 |
+
return self._input_names
|
21 |
+
|
22 |
+
@property
|
23 |
+
def output_names(self):
|
24 |
+
return self._output_names
|
25 |
+
|
26 |
+
def forward(self, inputs):
|
27 |
+
to_list_flag = False
|
28 |
+
if not isinstance(inputs, (tuple, list)):
|
29 |
+
inputs = [inputs]
|
30 |
+
to_list_flag = True
|
31 |
+
input_feed = {name: input for name, input in zip(self.input_names, inputs)}
|
32 |
+
outputs = self.sess.run(self.output_names, input_feed)
|
33 |
+
if (len(self.output_names) == 1) and to_list_flag:
|
34 |
+
return outputs[0]
|
35 |
+
else:
|
36 |
+
return outputs
|
37 |
+
|
38 |
+
|
39 |
+
def check_image_dtype_and_shape(image):
|
40 |
+
if not isinstance(image, np.ndarray):
|
41 |
+
raise Exception(f'image is not np.ndarray!')
|
42 |
+
|
43 |
+
if isinstance(image.dtype, (np.uint8, np.uint16)):
|
44 |
+
raise Exception(f'Unsupported image dtype, only support uint8 and uint16, got {image.dtype}!')
|
45 |
+
if image.ndim not in {2, 3}:
|
46 |
+
raise Exception(f'Unsupported image dimension number, only support 2 and 3, got {image.ndim}!')
|
47 |
+
if image.ndim == 3:
|
48 |
+
num_channels = image.shape[-1]
|
49 |
+
if num_channels not in {1, 3, 4}:
|
50 |
+
raise Exception(f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!')
|
51 |
+
|
insectid/detector.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import khandy
|
3 |
+
import numpy as np
|
4 |
+
from .base import OnnxModel
|
5 |
+
from .base import check_image_dtype_and_shape
|
6 |
+
|
7 |
+
|
8 |
+
class InsectDetector(OnnxModel):
|
9 |
+
def __init__(self):
|
10 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
11 |
+
model_path = os.path.join(
|
12 |
+
current_dir,
|
13 |
+
"__pycache__/MuGeminorum/insecta/quarrying_insect_detector.onnx",
|
14 |
+
)
|
15 |
+
self.input_width = 640
|
16 |
+
self.input_height = 640
|
17 |
+
super(InsectDetector, self).__init__(model_path)
|
18 |
+
|
19 |
+
def _preprocess(self, image):
|
20 |
+
check_image_dtype_and_shape(image)
|
21 |
+
# image size normalization
|
22 |
+
image, scale, pad_left, pad_top = khandy.letterbox_image(
|
23 |
+
image, self.input_width, self.input_height, 0, return_scale=True
|
24 |
+
)
|
25 |
+
# image channel normalization
|
26 |
+
image = khandy.normalize_image_channel(image, swap_rb=True)
|
27 |
+
# image dtype normalization
|
28 |
+
image = khandy.rescale_image(image, "auto", np.float32)
|
29 |
+
# to tensor
|
30 |
+
image = np.transpose(image, (2, 0, 1))
|
31 |
+
image = np.expand_dims(image, axis=0)
|
32 |
+
return image, scale, pad_left, pad_top
|
33 |
+
|
34 |
+
def _post_process(
|
35 |
+
self, outputs_list, scale, pad_left, pad_top, conf_thresh, iou_thresh
|
36 |
+
):
|
37 |
+
pred = outputs_list[0][0]
|
38 |
+
pass_t = pred[:, 4] > conf_thresh
|
39 |
+
pred = pred[pass_t]
|
40 |
+
boxes = khandy.convert_boxes_format(pred[:, :4], "cxcywh", "xyxy")
|
41 |
+
boxes = khandy.unletterbox_2d_points(boxes, scale, pad_left, pad_top, False)
|
42 |
+
confs = np.max(pred[:, 5:] * pred[:, 4:5], axis=-1)
|
43 |
+
classes = np.argmax(pred[:, 5:] * pred[:, 4:5], axis=-1)
|
44 |
+
keep = khandy.non_max_suppression(boxes, confs, iou_thresh)
|
45 |
+
return boxes[keep], confs[keep], classes[keep]
|
46 |
+
|
47 |
+
def detect(self, image, conf_thresh=0.5, iou_thresh=0.5):
|
48 |
+
image, scale, pad_left, pad_top = self._preprocess(image)
|
49 |
+
outputs_list = self.forward(image)
|
50 |
+
boxes, confs, classes = self._post_process(
|
51 |
+
outputs_list,
|
52 |
+
scale=scale,
|
53 |
+
pad_left=pad_left,
|
54 |
+
pad_top=pad_top,
|
55 |
+
conf_thresh=conf_thresh,
|
56 |
+
iou_thresh=iou_thresh,
|
57 |
+
)
|
58 |
+
return boxes, confs, classes
|
insectid/identifier.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import copy
|
3 |
+
import khandy
|
4 |
+
import numpy as np
|
5 |
+
from .base import OnnxModel
|
6 |
+
from collections import OrderedDict
|
7 |
+
from .base import check_image_dtype_and_shape
|
8 |
+
|
9 |
+
|
10 |
+
class InsectIdentifier(OnnxModel):
|
11 |
+
def __init__(self):
|
12 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
13 |
+
model_path = os.path.join(
|
14 |
+
current_dir,
|
15 |
+
"__pycache__/MuGeminorum/insecta/quarrying_insect_identifier.onnx",
|
16 |
+
)
|
17 |
+
label_map_path = os.path.join(
|
18 |
+
current_dir,
|
19 |
+
"__pycache__/MuGeminorum/insecta/quarrying_insectid_label_map.txt",
|
20 |
+
)
|
21 |
+
super(InsectIdentifier, self).__init__(model_path)
|
22 |
+
self.label_name_dict = self._get_label_name_dict(label_map_path)
|
23 |
+
self.names = [
|
24 |
+
self.label_name_dict[i]["chinese_name"]
|
25 |
+
for i in range(len(self.label_name_dict))
|
26 |
+
]
|
27 |
+
self.num_classes = len(self.label_name_dict)
|
28 |
+
|
29 |
+
@staticmethod
|
30 |
+
def _get_label_name_dict(filename):
|
31 |
+
records = khandy.load_list(filename)
|
32 |
+
label_name_dict = {}
|
33 |
+
for record in records:
|
34 |
+
label, chinese_name, latin_name = record.split(",")
|
35 |
+
label_name_dict[int(label)] = OrderedDict(
|
36 |
+
[("chinese_name", chinese_name), ("latin_name", latin_name)]
|
37 |
+
)
|
38 |
+
|
39 |
+
return label_name_dict
|
40 |
+
|
41 |
+
@staticmethod
|
42 |
+
def _preprocess(image):
|
43 |
+
check_image_dtype_and_shape(image)
|
44 |
+
# image size normalization
|
45 |
+
image = khandy.letterbox_image(image, 224, 224)
|
46 |
+
# image channel normalization
|
47 |
+
image = khandy.normalize_image_channel(image, swap_rb=True)
|
48 |
+
# image dtype normalization
|
49 |
+
# image dtype and value range normalization
|
50 |
+
mean, stddev = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
51 |
+
image = khandy.normalize_image_value(image, mean, stddev, "auto")
|
52 |
+
# to tensor
|
53 |
+
image = np.transpose(image, (2, 0, 1))
|
54 |
+
image = np.expand_dims(image, axis=0)
|
55 |
+
return image
|
56 |
+
|
57 |
+
def predict(self, image):
|
58 |
+
inputs = self._preprocess(image)
|
59 |
+
logits = self.forward(inputs)
|
60 |
+
probs = khandy.softmax(logits)
|
61 |
+
return probs
|
62 |
+
|
63 |
+
def identify(self, image, topk=5):
|
64 |
+
assert isinstance(topk, int)
|
65 |
+
if topk <= 0 or topk > self.num_classes:
|
66 |
+
topk = self.num_classes
|
67 |
+
|
68 |
+
probs = self.predict(image)
|
69 |
+
topk_probs, topk_indices = khandy.top_k(probs, topk)
|
70 |
+
results = []
|
71 |
+
for ind, prob in zip(topk_indices[0], topk_probs[0]):
|
72 |
+
one_result = copy.deepcopy(self.label_name_dict[ind])
|
73 |
+
one_result["probability"] = prob
|
74 |
+
results.append(one_result)
|
75 |
+
|
76 |
+
return results
|
khandy/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .dict_utils import *
|
2 |
+
from .draw_utils import *
|
3 |
+
from .feature_utils import *
|
4 |
+
from .file_io_utils import *
|
5 |
+
from .fs_utils import *
|
6 |
+
from .hash_utils import *
|
7 |
+
from .list_utils import *
|
8 |
+
from .misc import *
|
9 |
+
from .numpy_utils import *
|
10 |
+
from .split_utils import *
|
11 |
+
from .text_utils import *
|
12 |
+
from .time_utils import *
|
13 |
+
from .version import *
|
14 |
+
|
15 |
+
from .boxes import *
|
16 |
+
from .image import *
|
17 |
+
from .points import *
|
18 |
+
from . import label
|
khandy/boxes/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .boxes_clip import *
|
2 |
+
from .boxes_overlap import *
|
3 |
+
from .boxes_filter import *
|
4 |
+
from .boxes_convert import *
|
5 |
+
from .boxes_coder import *
|
6 |
+
|
7 |
+
from .boxes_transform_flip import *
|
8 |
+
from .boxes_transform_rotate import *
|
9 |
+
from .boxes_transform_scale import *
|
10 |
+
from .boxes_transform_translate import *
|
11 |
+
from .boxes_utils import *
|
12 |
+
|
13 |
+
from .boxes_and_indices import *
|
khandy/boxes/boxes_and_indices.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def _concat(arr_list, axis=0):
|
5 |
+
"""Avoids a copy if there is only a single element in a list.
|
6 |
+
"""
|
7 |
+
if len(arr_list) == 1:
|
8 |
+
return arr_list[0]
|
9 |
+
return np.concatenate(arr_list, axis)
|
10 |
+
|
11 |
+
|
12 |
+
def convert_boxes_list_to_boxes_and_indices(boxes_list):
|
13 |
+
"""
|
14 |
+
Args:
|
15 |
+
boxes_list (np.ndarray): list or tuple of ndarray with shape (N_i, 4+K)
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
boxes (ndarray): shape (M, 4+K) where M is sum of N_i.
|
19 |
+
indices (ndarray): shape (M, 1) where M is sum of N_i.
|
20 |
+
|
21 |
+
References:
|
22 |
+
`mmdet.core.bbox.bbox2roi` in mmdetection
|
23 |
+
`convert_boxes_to_roi_format` in TorchVision
|
24 |
+
`modeling.poolers.convert_boxes_to_pooler_format` in detectron2
|
25 |
+
"""
|
26 |
+
assert isinstance(boxes_list, (list, tuple))
|
27 |
+
boxes = _concat(boxes_list, axis=0)
|
28 |
+
|
29 |
+
indices_list = [np.full((len(b), 1), i, boxes.dtype)
|
30 |
+
for i, b in enumerate(boxes_list)]
|
31 |
+
indices = _concat(indices_list, axis=0)
|
32 |
+
return boxes, indices
|
33 |
+
|
34 |
+
|
35 |
+
def convert_boxes_and_indices_to_boxes_list(boxes, indices, num_indices):
|
36 |
+
"""
|
37 |
+
Args:
|
38 |
+
boxes (np.ndarray): shape (N, 4+K)
|
39 |
+
indices (np.ndarray): shape (N,) or (N, 1), maybe batch index
|
40 |
+
in mini-batch or class label index.
|
41 |
+
num_indices (int): number of index.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
list (ndarray): boxes list of each index
|
45 |
+
|
46 |
+
References:
|
47 |
+
`mmdet.core.bbox2result` in mmdetection
|
48 |
+
`mmdet.core.bbox.roi2bbox` in mmdetection
|
49 |
+
`convert_boxes_to_roi_format` in TorchVision
|
50 |
+
`modeling.poolers.convert_boxes_to_pooler_format` in detectron2
|
51 |
+
"""
|
52 |
+
boxes = np.asarray(boxes)
|
53 |
+
indices = np.asarray(indices)
|
54 |
+
assert boxes.ndim == 2, "boxes ndim must be 2, got {}".format(boxes.ndim)
|
55 |
+
assert (indices.ndim == 1) or (indices.ndim == 2 and indices.shape[-1] == 1), \
|
56 |
+
"indices ndim must be 1 or 2 if last dimension size is 1, got shape {}".format(indices.shape)
|
57 |
+
assert boxes.shape[0] == indices.shape[0], "the 1st dimension size of boxes and indices "\
|
58 |
+
"must be the same, got {} != {}".format(boxes.shape[0], indices.shape[0])
|
59 |
+
|
60 |
+
if boxes.shape[0] == 0:
|
61 |
+
return [np.zeros((0, boxes.shape[1]), dtype=np.float32)
|
62 |
+
for i in range(num_indices)]
|
63 |
+
else:
|
64 |
+
if indices.ndim == 2:
|
65 |
+
indices = np.squeeze(indices, axis=-1)
|
66 |
+
return [boxes[indices == i, :] for i in range(num_indices)]
|
67 |
+
|
68 |
+
|
khandy/boxes/boxes_clip.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def clip_boxes(boxes, reference_box, copy=True):
|
5 |
+
"""Clip boxes to reference box.
|
6 |
+
|
7 |
+
References:
|
8 |
+
`clip_to_window` in TensorFlow object detection API.
|
9 |
+
"""
|
10 |
+
if copy:
|
11 |
+
boxes = boxes.copy()
|
12 |
+
ref_x_min, ref_y_min, ref_x_max, ref_y_max = reference_box[:4]
|
13 |
+
lower = np.array([ref_x_min, ref_y_min, ref_x_min, ref_y_min])
|
14 |
+
upper = np.array([ref_x_max, ref_y_max, ref_x_max, ref_y_max])
|
15 |
+
np.clip(boxes[..., :4], lower, upper, boxes[..., :4])
|
16 |
+
return boxes
|
17 |
+
|
18 |
+
|
19 |
+
def clip_boxes_to_image(boxes, image_width, image_height, subpixel=True, copy=True):
|
20 |
+
"""Clip boxes to image boundaries.
|
21 |
+
|
22 |
+
References:
|
23 |
+
`clip_boxes` in py-faster-rcnn
|
24 |
+
`core.boxes_op_list.clip_to_window` in TensorFlow object detection API.
|
25 |
+
`structures.Boxes.clip` in detectron2
|
26 |
+
|
27 |
+
Notes:
|
28 |
+
Equivalent to `clip_boxes(boxes, [0,0,image_width-1,image_height-1], copy)`
|
29 |
+
"""
|
30 |
+
if not subpixel:
|
31 |
+
image_width -= 1
|
32 |
+
image_height -= 1
|
33 |
+
reference_box = [0, 0, image_width, image_height]
|
34 |
+
return clip_boxes(boxes, reference_box, copy)
|
khandy/boxes/boxes_coder.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class FasterRcnnBoxCoder:
|
5 |
+
"""Faster RCNN box coder.
|
6 |
+
|
7 |
+
Notes:
|
8 |
+
boxes should be in cxcywh format.
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(self, stddevs=None):
|
12 |
+
"""Constructor for FasterRcnnBoxCoder.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
stddevs: List of 4 positive scalars to scale ty, tx, th and tw.
|
16 |
+
If set to None, does not perform scaling. For Faster RCNN,
|
17 |
+
the open-source implementation recommends using [0.1, 0.1, 0.2, 0.2].
|
18 |
+
"""
|
19 |
+
if stddevs:
|
20 |
+
assert len(stddevs) == 4
|
21 |
+
for scalar in stddevs:
|
22 |
+
assert scalar > 0
|
23 |
+
self.stddevs = stddevs
|
24 |
+
|
25 |
+
def encode(self, boxes, reference_boxes, copy=True):
|
26 |
+
"""Encode boxes with respect to reference boxes.
|
27 |
+
"""
|
28 |
+
if copy:
|
29 |
+
boxes = boxes.copy()
|
30 |
+
|
31 |
+
boxes[..., 2:4] += 1e-8
|
32 |
+
reference_boxes[..., 2:4] += 1e-8
|
33 |
+
|
34 |
+
boxes[..., 0:2] -= reference_boxes[..., 0:2]
|
35 |
+
boxes[..., 0:2] /= reference_boxes[..., 2:4]
|
36 |
+
boxes[..., 2:4] /= reference_boxes[..., 2:4]
|
37 |
+
boxes[..., 2:4] = np.log(boxes[..., 2:4], boxes[..., 2:4])
|
38 |
+
if self.stddevs:
|
39 |
+
boxes[..., 0:4] /= self.stddevs
|
40 |
+
return boxes
|
41 |
+
|
42 |
+
def decode(self, rel_boxes, reference_boxes, copy=True):
|
43 |
+
"""Decode relative codes to boxes.
|
44 |
+
"""
|
45 |
+
if copy:
|
46 |
+
rel_boxes = rel_boxes.copy()
|
47 |
+
|
48 |
+
if self.stddevs:
|
49 |
+
rel_boxes[..., 0:4] *= self.stddevs
|
50 |
+
|
51 |
+
rel_boxes[..., 0:2] *= reference_boxes[..., 2:4]
|
52 |
+
rel_boxes[..., 0:2] += reference_boxes[..., 0:2]
|
53 |
+
rel_boxes[..., 2:4] = np.exp(rel_boxes[..., 2:4], rel_boxes[..., 2:4])
|
54 |
+
rel_boxes[..., 2:4] *= reference_boxes[..., 2:4]
|
55 |
+
return rel_boxes
|
56 |
+
|
57 |
+
def decode_points(self, rel_points, reference_boxes, copy=True):
|
58 |
+
"""Decode relative codes to points.
|
59 |
+
"""
|
60 |
+
if copy:
|
61 |
+
rel_points = rel_points.copy()
|
62 |
+
if self.stddevs:
|
63 |
+
rel_points[..., 0::2] *= self.stddevs[0]
|
64 |
+
rel_points[..., 1::2] *= self.stddevs[1]
|
65 |
+
rel_points[..., 0::2] *= reference_boxes[..., 2:3]
|
66 |
+
rel_points[..., 1::2] *= reference_boxes[..., 3:4]
|
67 |
+
rel_points[..., 0::2] += reference_boxes[..., 0:1]
|
68 |
+
rel_points[..., 1::2] += reference_boxes[..., 1:2]
|
69 |
+
return rel_points
|
khandy/boxes/boxes_convert.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def convert_xyxy_to_xywh(boxes, copy=True):
|
5 |
+
"""Convert [x_min, y_min, x_max, y_max] format to [x_min, y_min, width, height] format.
|
6 |
+
"""
|
7 |
+
if copy:
|
8 |
+
boxes = boxes.copy()
|
9 |
+
boxes[..., 2:4] -= boxes[..., 0:2]
|
10 |
+
return boxes
|
11 |
+
|
12 |
+
|
13 |
+
def convert_xywh_to_xyxy(boxes, copy=True):
|
14 |
+
"""Convert [x_min, y_min, width, height] format to [x_min, y_min, x_max, y_max] format.
|
15 |
+
"""
|
16 |
+
if copy:
|
17 |
+
boxes = boxes.copy()
|
18 |
+
boxes[..., 2:4] += boxes[..., 0:2]
|
19 |
+
return boxes
|
20 |
+
|
21 |
+
|
22 |
+
def convert_xywh_to_cxcywh(boxes, copy=True):
|
23 |
+
"""Convert [x_min, y_min, width, height] format to [cx, cy, width, height] format.
|
24 |
+
"""
|
25 |
+
if copy:
|
26 |
+
boxes = boxes.copy()
|
27 |
+
boxes[..., 0:2] += boxes[..., 2:4] * 0.5
|
28 |
+
return boxes
|
29 |
+
|
30 |
+
|
31 |
+
def convert_cxcywh_to_xywh(boxes, copy=True):
|
32 |
+
"""Convert [cx, cy, width, height] format to [x_min, y_min, width, height] format.
|
33 |
+
"""
|
34 |
+
if copy:
|
35 |
+
boxes = boxes.copy()
|
36 |
+
boxes[..., 0:2] -= boxes[..., 2:4] * 0.5
|
37 |
+
return boxes
|
38 |
+
|
39 |
+
|
40 |
+
def convert_xyxy_to_cxcywh(boxes, copy=True):
|
41 |
+
"""Convert [x_min, y_min, x_max, y_max] format to [cx, cy, width, height] format.
|
42 |
+
"""
|
43 |
+
if copy:
|
44 |
+
boxes = boxes.copy()
|
45 |
+
boxes[..., 2:4] -= boxes[..., 0:2]
|
46 |
+
boxes[..., 0:2] += boxes[..., 2:4] * 0.5
|
47 |
+
return boxes
|
48 |
+
|
49 |
+
|
50 |
+
def convert_cxcywh_to_xyxy(boxes, copy=True):
|
51 |
+
"""Convert [cx, cy, width, height] format to [x_min, y_min, x_max, y_max] format.
|
52 |
+
"""
|
53 |
+
if copy:
|
54 |
+
boxes = boxes.copy()
|
55 |
+
boxes[..., 0:2] -= boxes[..., 2:4] * 0.5
|
56 |
+
boxes[..., 2:4] += boxes[..., 0:2]
|
57 |
+
return boxes
|
58 |
+
|
59 |
+
|
60 |
+
def convert_boxes_format(boxes, in_fmt, out_fmt, copy=True):
|
61 |
+
"""Converts boxes from given in_fmt to out_fmt.
|
62 |
+
|
63 |
+
Supported in_fmt and out_fmt are:
|
64 |
+
'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.
|
65 |
+
'xywh' : boxes are represented via corner, width and height, x1, y2 being top left, w, h being width and height.
|
66 |
+
'cxcywh' : boxes are represented via centre, width and height, cx, cy being center of box, w, h
|
67 |
+
being width and height.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
boxes: boxes which will be converted.
|
71 |
+
in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
|
72 |
+
out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh']
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
boxes: Boxes into converted format.
|
76 |
+
|
77 |
+
References:
|
78 |
+
torchvision.ops.box_convert
|
79 |
+
"""
|
80 |
+
allowed_fmts = ("xyxy", "xywh", "cxcywh")
|
81 |
+
if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
|
82 |
+
raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")
|
83 |
+
if copy:
|
84 |
+
boxes = boxes.copy()
|
85 |
+
if in_fmt == out_fmt:
|
86 |
+
return boxes
|
87 |
+
|
88 |
+
if (in_fmt, out_fmt) == ("xyxy", "xywh"):
|
89 |
+
boxes = convert_xyxy_to_xywh(boxes, copy=False)
|
90 |
+
elif (in_fmt, out_fmt) == ("xywh", "xyxy"):
|
91 |
+
boxes = convert_xywh_to_xyxy(boxes, copy=False)
|
92 |
+
elif (in_fmt, out_fmt) == ("xywh", "cxcywh"):
|
93 |
+
boxes = convert_xywh_to_cxcywh(boxes, copy=False)
|
94 |
+
elif (in_fmt, out_fmt) == ("cxcywh", "xywh"):
|
95 |
+
boxes = convert_cxcywh_to_xywh(boxes, copy=False)
|
96 |
+
elif (in_fmt, out_fmt) == ("xyxy", "cxcywh"):
|
97 |
+
boxes = convert_xyxy_to_cxcywh(boxes, copy=False)
|
98 |
+
elif (in_fmt, out_fmt) == ("cxcywh", "xyxy"):
|
99 |
+
boxes = convert_cxcywh_to_xyxy(boxes, copy=False)
|
100 |
+
return boxes
|
101 |
+
|
khandy/boxes/boxes_filter.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def filter_small_boxes(boxes, min_width, min_height):
|
5 |
+
"""Filters all boxes with side smaller than min size.
|
6 |
+
|
7 |
+
Args:
|
8 |
+
boxes: a numpy array with shape [N, 4] holding N boxes.
|
9 |
+
min_width (float): minimum width
|
10 |
+
min_height (float): minimum height
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
keep: indices of the boxes that have width larger than
|
14 |
+
min_width and height larger than min_height.
|
15 |
+
|
16 |
+
References:
|
17 |
+
`_filter_boxes` in py-faster-rcnn
|
18 |
+
`prune_small_boxes` in TensorFlow object detection API.
|
19 |
+
`structures.Boxes.nonempty` in detectron2
|
20 |
+
`ops.boxes.remove_small_boxes` in torchvision
|
21 |
+
"""
|
22 |
+
widths = boxes[:, 2] - boxes[:, 0]
|
23 |
+
heights = boxes[:, 3] - boxes[:, 1]
|
24 |
+
# keep represents indices to keep,
|
25 |
+
# mask represents bool ndarray, so use mask here.
|
26 |
+
mask = (widths >= min_width)
|
27 |
+
mask &= (heights >= min_height)
|
28 |
+
return np.nonzero(mask)[0]
|
29 |
+
|
30 |
+
|
31 |
+
def filter_boxes_outside(boxes, reference_box):
|
32 |
+
"""Filters bounding boxes that fall outside reference box.
|
33 |
+
|
34 |
+
References:
|
35 |
+
`prune_outside_window` in TensorFlow object detection API.
|
36 |
+
"""
|
37 |
+
x_min, y_min, x_max, y_max = reference_box[:4]
|
38 |
+
mask = ((boxes[:, 0] >= x_min) & (boxes[:, 1] >= y_min) &
|
39 |
+
(boxes[:, 2] <= x_max) & (boxes[:, 3] <= y_max))
|
40 |
+
return np.nonzero(mask)[0]
|
41 |
+
|
42 |
+
|
43 |
+
def filter_boxes_completely_outside(boxes, reference_box):
|
44 |
+
"""Filters bounding boxes that fall completely outside of reference box.
|
45 |
+
|
46 |
+
References:
|
47 |
+
`prune_completely_outside_window` in TensorFlow object detection API.
|
48 |
+
"""
|
49 |
+
x_min, y_min, x_max, y_max = reference_box[:4]
|
50 |
+
mask = ((boxes[:, 0] < x_max) & (boxes[:, 1] < y_max) &
|
51 |
+
(boxes[:, 2] > x_min) & (boxes[:, 3] > y_min))
|
52 |
+
return np.nonzero(mask)[0]
|
53 |
+
|
54 |
+
|
55 |
+
def non_max_suppression(boxes, scores, thresh, classes=None, ratio_type="iou"):
|
56 |
+
"""Greedily select boxes with high confidence
|
57 |
+
Args:
|
58 |
+
boxes: [[x_min, y_min, x_max, y_max], ...]
|
59 |
+
scores: object confidence
|
60 |
+
thresh: retain overlap_ratio <= thresh
|
61 |
+
classes: class labels
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
indices to keep
|
65 |
+
|
66 |
+
References:
|
67 |
+
`py_cpu_nms` in py-faster-rcnn
|
68 |
+
torchvision.ops.nms
|
69 |
+
torchvision.ops.batched_nms
|
70 |
+
"""
|
71 |
+
|
72 |
+
if boxes.size == 0:
|
73 |
+
return np.empty((0,), dtype=np.int64)
|
74 |
+
if classes is not None:
|
75 |
+
# strategy: in order to perform NMS independently per class,
|
76 |
+
# we add an offset to all the boxes. The offset is dependent
|
77 |
+
# only on the class idx, and is large enough so that boxes
|
78 |
+
# from different classes do not overlap
|
79 |
+
max_coordinate = np.max(boxes)
|
80 |
+
offsets = classes * (max_coordinate + 1)
|
81 |
+
boxes = boxes + offsets[:, None]
|
82 |
+
|
83 |
+
x_mins = boxes[:, 0]
|
84 |
+
y_mins = boxes[:, 1]
|
85 |
+
x_maxs = boxes[:, 2]
|
86 |
+
y_maxs = boxes[:, 3]
|
87 |
+
areas = (x_maxs - x_mins) * (y_maxs - y_mins)
|
88 |
+
order = scores.flatten().argsort()[::-1]
|
89 |
+
|
90 |
+
keep = []
|
91 |
+
while order.size > 0:
|
92 |
+
i = order[0]
|
93 |
+
keep.append(i)
|
94 |
+
|
95 |
+
max_x_mins = np.maximum(x_mins[i], x_mins[order[1:]])
|
96 |
+
max_y_mins = np.maximum(y_mins[i], y_mins[order[1:]])
|
97 |
+
min_x_maxs = np.minimum(x_maxs[i], x_maxs[order[1:]])
|
98 |
+
min_y_maxs = np.minimum(y_maxs[i], y_maxs[order[1:]])
|
99 |
+
widths = np.maximum(0, min_x_maxs - max_x_mins)
|
100 |
+
heights = np.maximum(0, min_y_maxs - max_y_mins)
|
101 |
+
intersect_areas = widths * heights
|
102 |
+
|
103 |
+
if ratio_type in ["union", 'iou']:
|
104 |
+
ratio = intersect_areas / (areas[i] + areas[order[1:]] - intersect_areas)
|
105 |
+
elif ratio_type == "min":
|
106 |
+
ratio = intersect_areas / np.minimum(areas[i], areas[order[1:]])
|
107 |
+
else:
|
108 |
+
raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))
|
109 |
+
|
110 |
+
inds = np.nonzero(ratio <= thresh)[0]
|
111 |
+
order = order[inds + 1]
|
112 |
+
return np.asarray(keep)
|
113 |
+
|
khandy/boxes/boxes_overlap.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def paired_intersection(boxes1, boxes2):
|
5 |
+
"""Compute paired intersection areas between boxes.
|
6 |
+
Args:
|
7 |
+
boxes1: a numpy array with shape [N, 4] holding N boxes
|
8 |
+
boxes2: a numpy array with shape [N, 4] holding N boxes
|
9 |
+
|
10 |
+
Returns:
|
11 |
+
a numpy array with shape [N,] representing itemwise intersection area
|
12 |
+
|
13 |
+
References:
|
14 |
+
`core.box_list_ops.matched_intersection` in Tensorflow object detection API
|
15 |
+
|
16 |
+
Notes:
|
17 |
+
can called as itemwise_intersection, matched_intersection, aligned_intersection
|
18 |
+
"""
|
19 |
+
max_x_mins = np.maximum(boxes1[:, 0], boxes2[:, 0])
|
20 |
+
max_y_mins = np.maximum(boxes1[:, 1], boxes2[:, 1])
|
21 |
+
min_x_maxs = np.minimum(boxes1[:, 2], boxes2[:, 2])
|
22 |
+
min_y_maxs = np.minimum(boxes1[:, 3], boxes2[:, 3])
|
23 |
+
intersect_widths = np.maximum(0, min_x_maxs - max_x_mins)
|
24 |
+
intersect_heights = np.maximum(0, min_y_maxs - max_y_mins)
|
25 |
+
return intersect_widths * intersect_heights
|
26 |
+
|
27 |
+
|
28 |
+
def pairwise_intersection(boxes1, boxes2):
|
29 |
+
"""Compute pairwise intersection areas between boxes.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
boxes1: a numpy array with shape [N, 4] holding N boxes.
|
33 |
+
boxes2: a numpy array with shape [M, 4] holding M boxes.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
a numpy array with shape [N, M] representing pairwise intersection area.
|
37 |
+
|
38 |
+
References:
|
39 |
+
`core.box_list_ops.intersection` in Tensorflow object detection API
|
40 |
+
`utils.box_list_ops.intersection` in Tensorflow object detection API
|
41 |
+
"""
|
42 |
+
if boxes1.shape[0] * boxes2.shape[0] == 0:
|
43 |
+
return np.zeros((boxes1.shape[0], boxes2.shape[0]), dtype=boxes1.dtype)
|
44 |
+
|
45 |
+
swap = False
|
46 |
+
if boxes1.shape[0] > boxes2.shape[0]:
|
47 |
+
boxes1, boxes2 = boxes2, boxes1
|
48 |
+
swap = True
|
49 |
+
intersect_areas = np.empty((boxes1.shape[0], boxes2.shape[0]), dtype=boxes1.dtype)
|
50 |
+
|
51 |
+
for i in range(boxes1.shape[0]):
|
52 |
+
max_x_mins = np.maximum(boxes1[i, 0], boxes2[:, 0])
|
53 |
+
max_y_mins = np.maximum(boxes1[i, 1], boxes2[:, 1])
|
54 |
+
min_x_maxs = np.minimum(boxes1[i, 2], boxes2[:, 2])
|
55 |
+
min_y_maxs = np.minimum(boxes1[i, 3], boxes2[:, 3])
|
56 |
+
intersect_widths = np.maximum(0, min_x_maxs - max_x_mins)
|
57 |
+
intersect_heights = np.maximum(0, min_y_maxs - max_y_mins)
|
58 |
+
intersect_areas[i, :] = intersect_widths * intersect_heights
|
59 |
+
if swap:
|
60 |
+
intersect_areas = intersect_areas.T
|
61 |
+
return intersect_areas
|
62 |
+
|
63 |
+
|
64 |
+
def paired_overlap_ratio(boxes1, boxes2, ratio_type='iou'):
|
65 |
+
"""Compute paired overlap ratio between boxes.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
boxes1: a numpy array with shape [N, 4] holding N boxes
|
69 |
+
boxes2: a numpy array with shape [N, 4] holding N boxes
|
70 |
+
ratio_type:
|
71 |
+
iou: Intersection-over-union (iou).
|
72 |
+
ioa: Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
|
73 |
+
their intersection area over box2's area. Note that ioa is not symmetric,
|
74 |
+
that is, IOA(box1, box2) != IOA(box2, box1).
|
75 |
+
min: Compute the ratio as the area of intersection between box1 and box2,
|
76 |
+
divided by the minimum area of the two bounding boxes.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
a numpy array with shape [N,] representing itemwise overlap ratio.
|
80 |
+
|
81 |
+
References:
|
82 |
+
`core.box_list_ops.matched_iou` in Tensorflow object detection API
|
83 |
+
`structures.boxes.matched_boxlist_iou` in detectron2
|
84 |
+
`mmdet.core.bbox.bbox_overlaps`, see https://mmdetection.readthedocs.io/en/v2.17.0/api.html#mmdet.core.bbox.bbox_overlaps
|
85 |
+
"""
|
86 |
+
intersect_areas = paired_intersection(boxes1, boxes2)
|
87 |
+
areas1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
|
88 |
+
areas2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
|
89 |
+
|
90 |
+
if ratio_type in ['union', 'iou', 'giou']:
|
91 |
+
union_areas = areas1 - intersect_areas
|
92 |
+
union_areas += areas2
|
93 |
+
intersect_areas /= union_areas
|
94 |
+
elif ratio_type == 'min':
|
95 |
+
min_areas = np.minimum(areas1, areas2)
|
96 |
+
intersect_areas /= min_areas
|
97 |
+
elif ratio_type == 'ioa':
|
98 |
+
intersect_areas /= areas2
|
99 |
+
else:
|
100 |
+
raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))
|
101 |
+
|
102 |
+
if ratio_type == 'giou':
|
103 |
+
min_xy_mins = np.minimum(boxes1[:, 0:2], boxes2[:, 0:2])
|
104 |
+
max_xy_mins = np.maximum(boxes1[:, 2:4], boxes2[:, 2:4])
|
105 |
+
# mebb = minimum enclosing bounding boxes
|
106 |
+
mebb_whs = np.maximum(0, max_xy_mins - min_xy_mins)
|
107 |
+
mebb_areas = mebb_whs[:, 0] * mebb_whs[:, 1]
|
108 |
+
union_areas -= mebb_areas
|
109 |
+
union_areas /= mebb_areas
|
110 |
+
intersect_areas += union_areas
|
111 |
+
return intersect_areas
|
112 |
+
|
113 |
+
|
114 |
+
def pairwise_overlap_ratio(boxes1, boxes2, ratio_type='iou'):
|
115 |
+
"""Compute pairwise overlap ratio between boxes.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
boxes1: a numpy array with shape [N, 4] holding N boxes
|
119 |
+
boxes2: a numpy array with shape [M, 4] holding M boxes
|
120 |
+
ratio_type:
|
121 |
+
iou: Intersection-over-union (iou).
|
122 |
+
ioa: Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
|
123 |
+
their intersection area over box2's area. Note that ioa is not symmetric,
|
124 |
+
that is, IOA(box1, box2) != IOA(box2, box1).
|
125 |
+
min: Compute the ratio as the area of intersection between box1 and box2,
|
126 |
+
divided by the minimum area of the two bounding boxes.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
a numpy array with shape [N, M] representing pairwise overlap ratio.
|
130 |
+
|
131 |
+
References:
|
132 |
+
`utils.np_box_ops.iou` in Tensorflow object detection API
|
133 |
+
`utils.np_box_ops.ioa` in Tensorflow object detection API
|
134 |
+
`utils.np_box_ops.giou` in Tensorflow object detection API
|
135 |
+
`mmdet.core.bbox.bbox_overlaps`, see https://mmdetection.readthedocs.io/en/v2.17.0/api.html#mmdet.core.bbox.bbox_overlaps
|
136 |
+
`torchvision.ops.box_iou`, see https://pytorch.org/vision/stable/ops.html#torchvision.ops.box_iou
|
137 |
+
`torchvision.ops.generalized_box_iou`, see https://pytorch.org/vision/stable/ops.html#torchvision.ops.generalized_box_iou
|
138 |
+
http://ww2.mathworks.cn/help/vision/ref/bboxoverlapratio.html
|
139 |
+
"""
|
140 |
+
intersect_areas = pairwise_intersection(boxes1, boxes2)
|
141 |
+
areas1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
|
142 |
+
areas2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
|
143 |
+
|
144 |
+
if ratio_type in ['union', 'iou', 'giou']:
|
145 |
+
union_areas = np.expand_dims(areas1, axis=1) - intersect_areas
|
146 |
+
union_areas += np.expand_dims(areas2, axis=0)
|
147 |
+
intersect_areas /= union_areas
|
148 |
+
elif ratio_type == 'min':
|
149 |
+
min_areas = np.minimum(np.expand_dims(areas1, axis=1), np.expand_dims(areas2, axis=0))
|
150 |
+
intersect_areas /= min_areas
|
151 |
+
elif ratio_type == 'ioa':
|
152 |
+
intersect_areas /= np.expand_dims(areas2, axis=0)
|
153 |
+
else:
|
154 |
+
raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))
|
155 |
+
|
156 |
+
if ratio_type == 'giou':
|
157 |
+
min_xy_mins = np.minimum(boxes1[:, None, 0:2], boxes2[:, 0:2])
|
158 |
+
max_xy_mins = np.maximum(boxes1[:, None, 2:4], boxes2[:, 2:4])
|
159 |
+
# mebb = minimum enclosing bounding boxes
|
160 |
+
mebb_whs = np.maximum(0, max_xy_mins - min_xy_mins)
|
161 |
+
mebb_areas = mebb_whs[:, :, 0] * mebb_whs[:, :, 1]
|
162 |
+
union_areas -= mebb_areas
|
163 |
+
union_areas /= mebb_areas
|
164 |
+
intersect_areas += union_areas
|
165 |
+
return intersect_areas
|
166 |
+
|
khandy/boxes/boxes_transform_flip.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from .boxes_utils import assert_and_normalize_shape
|
3 |
+
|
4 |
+
|
5 |
+
def flip_boxes(boxes, x_center=0, y_center=0, direction='h'):
|
6 |
+
"""
|
7 |
+
Args:
|
8 |
+
boxes: (N, 4+K)
|
9 |
+
x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
10 |
+
y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
11 |
+
direction: str
|
12 |
+
"""
|
13 |
+
assert direction in ['x', 'h', 'horizontal',
|
14 |
+
'y', 'v', 'vertical',
|
15 |
+
'o', 'b', 'both']
|
16 |
+
boxes = np.asarray(boxes, np.float32)
|
17 |
+
ret_boxes = boxes.copy()
|
18 |
+
|
19 |
+
x_center = np.asarray(x_center, np.float32)
|
20 |
+
y_center = np.asarray(y_center, np.float32)
|
21 |
+
x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
|
22 |
+
y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
|
23 |
+
|
24 |
+
if direction in ['o', 'b', 'both', 'x', 'h', 'horizontal']:
|
25 |
+
ret_boxes[:, 0] = 2 * x_center - boxes[:, 2]
|
26 |
+
ret_boxes[:, 2] = 2 * x_center - boxes[:, 0]
|
27 |
+
if direction in ['o', 'b', 'both', 'y', 'v', 'vertical']:
|
28 |
+
ret_boxes[:, 1] = 2 * y_center - boxes[:, 3]
|
29 |
+
ret_boxes[:, 3] = 2 * y_center - boxes[:, 1]
|
30 |
+
return ret_boxes
|
31 |
+
|
32 |
+
|
33 |
+
def fliplr_boxes(boxes, x_center=0, y_center=0):
|
34 |
+
"""
|
35 |
+
Args:
|
36 |
+
boxes: (N, 4+K)
|
37 |
+
x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
38 |
+
y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
39 |
+
"""
|
40 |
+
boxes = np.asarray(boxes, np.float32)
|
41 |
+
ret_boxes = boxes.copy()
|
42 |
+
|
43 |
+
x_center = np.asarray(x_center, np.float32)
|
44 |
+
y_center = np.asarray(y_center, np.float32)
|
45 |
+
x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
|
46 |
+
y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
|
47 |
+
|
48 |
+
ret_boxes[:, 0] = 2 * x_center - boxes[:, 2]
|
49 |
+
ret_boxes[:, 2] = 2 * x_center - boxes[:, 0]
|
50 |
+
return ret_boxes
|
51 |
+
|
52 |
+
|
53 |
+
def flipud_boxes(boxes, x_center=0, y_center=0):
|
54 |
+
"""
|
55 |
+
Args:
|
56 |
+
boxes: (N, 4+K)
|
57 |
+
x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
58 |
+
y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
59 |
+
"""
|
60 |
+
boxes = np.asarray(boxes, np.float32)
|
61 |
+
ret_boxes = boxes.copy()
|
62 |
+
|
63 |
+
x_center = np.asarray(x_center, np.float32)
|
64 |
+
y_center = np.asarray(y_center, np.float32)
|
65 |
+
x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
|
66 |
+
y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
|
67 |
+
|
68 |
+
ret_boxes[:, 1] = 2 * y_center - boxes[:, 3]
|
69 |
+
ret_boxes[:, 3] = 2 * y_center - boxes[:, 1]
|
70 |
+
return ret_boxes
|
71 |
+
|
72 |
+
|
73 |
+
def transpose_boxes(boxes, x_center=0, y_center=0):
|
74 |
+
"""
|
75 |
+
Args:
|
76 |
+
boxes: (N, 4+K)
|
77 |
+
x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
78 |
+
y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
79 |
+
"""
|
80 |
+
boxes = np.asarray(boxes, np.float32)
|
81 |
+
ret_boxes = boxes.copy()
|
82 |
+
|
83 |
+
x_center = np.asarray(x_center, np.float32)
|
84 |
+
y_center = np.asarray(y_center, np.float32)
|
85 |
+
x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
|
86 |
+
y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
|
87 |
+
|
88 |
+
shift = x_center - y_center
|
89 |
+
ret_boxes[:, 0] = boxes[:, 1] + shift
|
90 |
+
ret_boxes[:, 1] = boxes[:, 0] - shift
|
91 |
+
ret_boxes[:, 2] = boxes[:, 3] + shift
|
92 |
+
ret_boxes[:, 3] = boxes[:, 2] - shift
|
93 |
+
return ret_boxes
|
94 |
+
|
95 |
+
|
96 |
+
def flip_boxes_in_image(boxes, image_width, image_height, direction='h'):
|
97 |
+
"""
|
98 |
+
Args:
|
99 |
+
boxes: (N, 4+K)
|
100 |
+
image_width: int
|
101 |
+
image_width: int
|
102 |
+
direction: str
|
103 |
+
|
104 |
+
References:
|
105 |
+
`core.bbox.bbox_flip` in mmdetection
|
106 |
+
`datasets.pipelines.RandomFlip.bbox_flip` in mmdetection
|
107 |
+
"""
|
108 |
+
x_center = (image_width - 1) * 0.5
|
109 |
+
y_center = (image_height - 1) * 0.5
|
110 |
+
ret_boxes = flip_boxes(boxes, x_center, y_center, direction=direction)
|
111 |
+
return ret_boxes
|
112 |
+
|
113 |
+
|
114 |
+
def rot90_boxes_in_image(boxes, image_width, image_height, n=1):
|
115 |
+
"""Rotate boxes counter-clockwise by 90 degrees.
|
116 |
+
|
117 |
+
References:
|
118 |
+
np.rot90
|
119 |
+
cv2.rotate
|
120 |
+
tf.image.rot90
|
121 |
+
"""
|
122 |
+
n = n % 4
|
123 |
+
if n == 0:
|
124 |
+
ret_boxes = boxes.copy()
|
125 |
+
elif n == 1:
|
126 |
+
ret_boxes = transpose_boxes(boxes)
|
127 |
+
ret_boxes = flip_boxes_in_image(ret_boxes, image_width, image_height, 'v')
|
128 |
+
elif n == 2:
|
129 |
+
ret_boxes = flip_boxes_in_image(boxes, image_width, image_height, 'o')
|
130 |
+
else:
|
131 |
+
ret_boxes = transpose_boxes(boxes)
|
132 |
+
ret_boxes = flip_boxes_in_image(ret_boxes, image_width, image_height, 'h');
|
133 |
+
return ret_boxes
|
134 |
+
|
135 |
+
|
khandy/boxes/boxes_transform_rotate.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from .boxes_utils import assert_and_normalize_shape
|
3 |
+
|
4 |
+
|
5 |
+
def rotate_boxes(boxes, angle, x_center=0, y_center=0, scale=1,
|
6 |
+
degrees=True, return_rotated_boxes=False):
|
7 |
+
"""
|
8 |
+
Args:
|
9 |
+
boxes: (N, 4+K)
|
10 |
+
angle: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
11 |
+
x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
12 |
+
y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
13 |
+
scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
14 |
+
scale factor in x and y dimension
|
15 |
+
degrees: bool
|
16 |
+
return_rotated_boxes: bool
|
17 |
+
"""
|
18 |
+
boxes = np.asarray(boxes, np.float32)
|
19 |
+
|
20 |
+
angle = np.asarray(angle, np.float32)
|
21 |
+
x_center = np.asarray(x_center, np.float32)
|
22 |
+
y_center = np.asarray(y_center, np.float32)
|
23 |
+
scale = np.asarray(scale, np.float32)
|
24 |
+
|
25 |
+
angle = assert_and_normalize_shape(angle, boxes.shape[0])
|
26 |
+
x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
|
27 |
+
y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
|
28 |
+
scale = assert_and_normalize_shape(scale, boxes.shape[0])
|
29 |
+
|
30 |
+
if degrees:
|
31 |
+
angle = np.deg2rad(angle)
|
32 |
+
cos_val = scale * np.cos(angle)
|
33 |
+
sin_val = scale * np.sin(angle)
|
34 |
+
x_shift = x_center - x_center * cos_val + y_center * sin_val
|
35 |
+
y_shift = y_center - x_center * sin_val - y_center * cos_val
|
36 |
+
|
37 |
+
x_mins, y_mins = boxes[:,0], boxes[:,1]
|
38 |
+
x_maxs, y_maxs = boxes[:,2], boxes[:,3]
|
39 |
+
x00 = x_mins * cos_val - y_mins * sin_val + x_shift
|
40 |
+
x10 = x_maxs * cos_val - y_mins * sin_val + x_shift
|
41 |
+
x11 = x_maxs * cos_val - y_maxs * sin_val + x_shift
|
42 |
+
x01 = x_mins * cos_val - y_maxs * sin_val + x_shift
|
43 |
+
|
44 |
+
y00 = x_mins * sin_val + y_mins * cos_val + y_shift
|
45 |
+
y10 = x_maxs * sin_val + y_mins * cos_val + y_shift
|
46 |
+
y11 = x_maxs * sin_val + y_maxs * cos_val + y_shift
|
47 |
+
y01 = x_mins * sin_val + y_maxs * cos_val + y_shift
|
48 |
+
|
49 |
+
rotated_boxes = np.stack([x00, y00, x10, y10, x11, y11, x01, y01], axis=-1)
|
50 |
+
ret_x_mins = np.min(rotated_boxes[:,0::2], axis=1)
|
51 |
+
ret_y_mins = np.min(rotated_boxes[:,1::2], axis=1)
|
52 |
+
ret_x_maxs = np.max(rotated_boxes[:,0::2], axis=1)
|
53 |
+
ret_y_maxs = np.max(rotated_boxes[:,1::2], axis=1)
|
54 |
+
|
55 |
+
if boxes.ndim == 4:
|
56 |
+
ret_boxes = np.stack([ret_x_mins, ret_y_mins, ret_x_maxs, ret_y_maxs], axis=-1)
|
57 |
+
else:
|
58 |
+
ret_boxes = boxes.copy()
|
59 |
+
ret_boxes[:, :4] = np.stack([ret_x_mins, ret_y_mins, ret_x_maxs, ret_y_maxs], axis=-1)
|
60 |
+
|
61 |
+
if not return_rotated_boxes:
|
62 |
+
return ret_boxes
|
63 |
+
else:
|
64 |
+
return ret_boxes, rotated_boxes
|
65 |
+
|
66 |
+
|
67 |
+
def rotate_boxes_wrt_centers(boxes, angle, scale=1, degrees=True,
|
68 |
+
return_rotated_boxes=False):
|
69 |
+
"""
|
70 |
+
Args:
|
71 |
+
boxes: (N, 4+K)
|
72 |
+
angle: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
73 |
+
scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
74 |
+
scale factor in x and y dimension
|
75 |
+
degrees: bool
|
76 |
+
return_rotated_boxes: bool
|
77 |
+
"""
|
78 |
+
boxes = np.asarray(boxes, np.float32)
|
79 |
+
|
80 |
+
angle = np.asarray(angle, np.float32)
|
81 |
+
scale = np.asarray(scale, np.float32)
|
82 |
+
angle = assert_and_normalize_shape(angle, boxes.shape[0])
|
83 |
+
scale = assert_and_normalize_shape(scale, boxes.shape[0])
|
84 |
+
|
85 |
+
if degrees:
|
86 |
+
angle = np.deg2rad(angle)
|
87 |
+
cos_val = scale * np.cos(angle)
|
88 |
+
sin_val = scale * np.sin(angle)
|
89 |
+
|
90 |
+
x_centers = boxes[:, 2] + boxes[:, 0]
|
91 |
+
y_centers = boxes[:, 3] + boxes[:, 1]
|
92 |
+
x_centers *= 0.5
|
93 |
+
y_centers *= 0.5
|
94 |
+
|
95 |
+
half_widths = boxes[:, 2] - boxes[:, 0]
|
96 |
+
half_heights = boxes[:, 3] - boxes[:, 1]
|
97 |
+
half_widths *= 0.5
|
98 |
+
half_heights *= 0.5
|
99 |
+
|
100 |
+
half_widths_cos = half_widths * cos_val
|
101 |
+
half_widths_sin = half_widths * sin_val
|
102 |
+
half_heights_cos = half_heights * cos_val
|
103 |
+
half_heights_sin = half_heights * sin_val
|
104 |
+
|
105 |
+
x00 = -half_widths_cos + half_heights_sin
|
106 |
+
x10 = half_widths_cos + half_heights_sin
|
107 |
+
x11 = half_widths_cos - half_heights_sin
|
108 |
+
x01 = -half_widths_cos - half_heights_sin
|
109 |
+
x00 += x_centers
|
110 |
+
x10 += x_centers
|
111 |
+
x11 += x_centers
|
112 |
+
x01 += x_centers
|
113 |
+
|
114 |
+
y00 = -half_widths_sin - half_heights_cos
|
115 |
+
y10 = half_widths_sin - half_heights_cos
|
116 |
+
y11 = half_widths_sin + half_heights_cos
|
117 |
+
y01 = -half_widths_sin + half_heights_cos
|
118 |
+
y00 += y_centers
|
119 |
+
y10 += y_centers
|
120 |
+
y11 += y_centers
|
121 |
+
y01 += y_centers
|
122 |
+
|
123 |
+
rotated_boxes = np.stack([x00, y00, x10, y10, x11, y11, x01, y01], axis=-1)
|
124 |
+
ret_x_mins = np.min(rotated_boxes[:,0::2], axis=1)
|
125 |
+
ret_y_mins = np.min(rotated_boxes[:,1::2], axis=1)
|
126 |
+
ret_x_maxs = np.max(rotated_boxes[:,0::2], axis=1)
|
127 |
+
ret_y_maxs = np.max(rotated_boxes[:,1::2], axis=1)
|
128 |
+
|
129 |
+
if boxes.ndim == 4:
|
130 |
+
ret_boxes = np.stack([ret_x_mins, ret_y_mins, ret_x_maxs, ret_y_maxs], axis=-1)
|
131 |
+
else:
|
132 |
+
ret_boxes = boxes.copy()
|
133 |
+
ret_boxes[:, :4] = np.stack([ret_x_mins, ret_y_mins, ret_x_maxs, ret_y_maxs], axis=-1)
|
134 |
+
|
135 |
+
if not return_rotated_boxes:
|
136 |
+
return ret_boxes
|
137 |
+
else:
|
138 |
+
return ret_boxes, rotated_boxes
|
139 |
+
|
140 |
+
|
khandy/boxes/boxes_transform_scale.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from .boxes_utils import assert_and_normalize_shape
|
3 |
+
|
4 |
+
|
5 |
+
def scale_boxes(boxes, x_scale=1, y_scale=1, x_center=0, y_center=0, copy=True):
|
6 |
+
"""Scale boxes coordinates in x and y dimensions.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
boxes: (N, 4+K)
|
10 |
+
x_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
11 |
+
scale factor in x dimension
|
12 |
+
y_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
13 |
+
scale factor in y dimension
|
14 |
+
x_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
15 |
+
y_center: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
16 |
+
|
17 |
+
References:
|
18 |
+
`core.box_list_ops.scale` in TensorFlow object detection API
|
19 |
+
`utils.box_list_ops.scale` in TensorFlow object detection API
|
20 |
+
`datasets.pipelines.Resize._resize_bboxes` in mmdetection
|
21 |
+
`core.anchor.guided_anchor_target.calc_region` in mmdetection where comments may be misleading!
|
22 |
+
`layers.mask_ops.scale_boxes` in detectron2
|
23 |
+
`mmcv.bbox_scaling`
|
24 |
+
"""
|
25 |
+
boxes = np.array(boxes, dtype=np.float32, copy=copy)
|
26 |
+
|
27 |
+
x_scale = np.asarray(x_scale, np.float32)
|
28 |
+
y_scale = np.asarray(y_scale, np.float32)
|
29 |
+
x_scale = assert_and_normalize_shape(x_scale, boxes.shape[0])
|
30 |
+
y_scale = assert_and_normalize_shape(y_scale, boxes.shape[0])
|
31 |
+
|
32 |
+
x_center = np.asarray(x_center, np.float32)
|
33 |
+
y_center = np.asarray(y_center, np.float32)
|
34 |
+
x_center = assert_and_normalize_shape(x_center, boxes.shape[0])
|
35 |
+
y_center = assert_and_normalize_shape(y_center, boxes.shape[0])
|
36 |
+
|
37 |
+
x_shift = 1 - x_scale
|
38 |
+
y_shift = 1 - y_scale
|
39 |
+
x_shift *= x_center
|
40 |
+
y_shift *= y_center
|
41 |
+
|
42 |
+
boxes[:, 0] *= x_scale
|
43 |
+
boxes[:, 1] *= y_scale
|
44 |
+
boxes[:, 2] *= x_scale
|
45 |
+
boxes[:, 3] *= y_scale
|
46 |
+
boxes[:, 0] += x_shift
|
47 |
+
boxes[:, 1] += y_shift
|
48 |
+
boxes[:, 2] += x_shift
|
49 |
+
boxes[:, 3] += y_shift
|
50 |
+
return boxes
|
51 |
+
|
52 |
+
|
53 |
+
def scale_boxes_wrt_centers(boxes, x_scale=1, y_scale=1, copy=True):
|
54 |
+
"""
|
55 |
+
Args:
|
56 |
+
boxes: (N, 4+K)
|
57 |
+
x_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
58 |
+
scale factor in x dimension
|
59 |
+
y_scale: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
60 |
+
scale factor in y dimension
|
61 |
+
|
62 |
+
References:
|
63 |
+
`core.anchor.guided_anchor_target.calc_region` in mmdetection where comments may be misleading!
|
64 |
+
`layers.mask_ops.scale_boxes` in detectron2
|
65 |
+
`mmcv.bbox_scaling`
|
66 |
+
"""
|
67 |
+
boxes = np.array(boxes, dtype=np.float32, copy=copy)
|
68 |
+
|
69 |
+
x_scale = np.asarray(x_scale, np.float32)
|
70 |
+
y_scale = np.asarray(y_scale, np.float32)
|
71 |
+
x_scale = assert_and_normalize_shape(x_scale, boxes.shape[0])
|
72 |
+
y_scale = assert_and_normalize_shape(y_scale, boxes.shape[0])
|
73 |
+
|
74 |
+
x_factor = (x_scale - 1) * 0.5
|
75 |
+
y_factor = (y_scale - 1) * 0.5
|
76 |
+
x_deltas = boxes[:, 2] - boxes[:, 0]
|
77 |
+
y_deltas = boxes[:, 3] - boxes[:, 1]
|
78 |
+
x_deltas *= x_factor
|
79 |
+
y_deltas *= y_factor
|
80 |
+
|
81 |
+
boxes[:, 0] -= x_deltas
|
82 |
+
boxes[:, 1] -= y_deltas
|
83 |
+
boxes[:, 2] += x_deltas
|
84 |
+
boxes[:, 3] += y_deltas
|
85 |
+
return boxes
|
86 |
+
|
khandy/boxes/boxes_transform_translate.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from .boxes_utils import assert_and_normalize_shape
|
3 |
+
|
4 |
+
|
5 |
+
def translate_boxes(boxes, x_shift=0, y_shift=0, copy=True):
|
6 |
+
"""translate boxes coordinates in x and y dimensions.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
boxes: (N, 4+K)
|
10 |
+
x_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
11 |
+
shift in x dimension
|
12 |
+
y_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
13 |
+
shift in y dimension
|
14 |
+
copy: bool
|
15 |
+
|
16 |
+
References:
|
17 |
+
`datasets.pipelines.RandomCrop` in mmdetection
|
18 |
+
"""
|
19 |
+
boxes = np.array(boxes, dtype=np.float32, copy=copy)
|
20 |
+
|
21 |
+
x_shift = np.asarray(x_shift, np.float32)
|
22 |
+
y_shift = np.asarray(y_shift, np.float32)
|
23 |
+
|
24 |
+
x_shift = assert_and_normalize_shape(x_shift, boxes.shape[0])
|
25 |
+
y_shift = assert_and_normalize_shape(y_shift, boxes.shape[0])
|
26 |
+
|
27 |
+
boxes[:, 0] += x_shift
|
28 |
+
boxes[:, 1] += y_shift
|
29 |
+
boxes[:, 2] += x_shift
|
30 |
+
boxes[:, 3] += y_shift
|
31 |
+
return boxes
|
32 |
+
|
33 |
+
|
34 |
+
def adjust_boxes(boxes, x_min_shift, y_min_shift, x_max_shift, y_max_shift, copy=True):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
boxes: (N, 4+K)
|
38 |
+
x_min_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
39 |
+
shift (x_min, y_min) in x dimension
|
40 |
+
y_min_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
41 |
+
shift (x_min, y_min) in y dimension
|
42 |
+
x_max_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
43 |
+
shift (x_max, y_max) in x dimension
|
44 |
+
y_max_shift: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
45 |
+
shift (x_max, y_max) in y dimension
|
46 |
+
copy: bool
|
47 |
+
"""
|
48 |
+
boxes = np.array(boxes, dtype=np.float32, copy=copy)
|
49 |
+
|
50 |
+
x_min_shift = np.asarray(x_min_shift, np.float32)
|
51 |
+
y_min_shift = np.asarray(y_min_shift, np.float32)
|
52 |
+
x_max_shift = np.asarray(x_max_shift, np.float32)
|
53 |
+
y_max_shift = np.asarray(y_max_shift, np.float32)
|
54 |
+
|
55 |
+
x_min_shift = assert_and_normalize_shape(x_min_shift, boxes.shape[0])
|
56 |
+
y_min_shift = assert_and_normalize_shape(y_min_shift, boxes.shape[0])
|
57 |
+
x_max_shift = assert_and_normalize_shape(x_max_shift, boxes.shape[0])
|
58 |
+
y_max_shift = assert_and_normalize_shape(y_max_shift, boxes.shape[0])
|
59 |
+
|
60 |
+
boxes[:, 0] += x_min_shift
|
61 |
+
boxes[:, 1] += y_min_shift
|
62 |
+
boxes[:, 2] += x_max_shift
|
63 |
+
boxes[:, 3] += y_max_shift
|
64 |
+
return boxes
|
65 |
+
|
66 |
+
|
67 |
+
def inflate_or_deflate_boxes(boxes, width_delta=0, height_delta=0, copy=True):
|
68 |
+
"""
|
69 |
+
Args:
|
70 |
+
boxes: (N, 4+K)
|
71 |
+
width_delta: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
72 |
+
height_delta: array-like whose shape is (), (1,), (N,), (1, 1) or (N, 1)
|
73 |
+
copy: bool
|
74 |
+
"""
|
75 |
+
boxes = np.array(boxes, dtype=np.float32, copy=copy)
|
76 |
+
|
77 |
+
width_delta = np.asarray(width_delta, np.float32)
|
78 |
+
height_delta = np.asarray(height_delta, np.float32)
|
79 |
+
|
80 |
+
width_delta = assert_and_normalize_shape(width_delta, boxes.shape[0])
|
81 |
+
height_delta = assert_and_normalize_shape(height_delta, boxes.shape[0])
|
82 |
+
|
83 |
+
half_width_delta = width_delta * 0.5
|
84 |
+
half_height_delta = height_delta * 0.5
|
85 |
+
boxes[:, 0] -= half_width_delta
|
86 |
+
boxes[:, 1] -= half_height_delta
|
87 |
+
boxes[:, 2] += half_width_delta
|
88 |
+
boxes[:, 3] += half_height_delta
|
89 |
+
return boxes
|
90 |
+
|
91 |
+
|
92 |
+
def inflate_boxes_to_square(boxes, copy=True):
|
93 |
+
"""Inflate boxes to square
|
94 |
+
Args:
|
95 |
+
boxes: (N, 4+K)
|
96 |
+
copy: bool
|
97 |
+
"""
|
98 |
+
boxes = np.array(boxes, dtype=np.float32, copy=copy)
|
99 |
+
|
100 |
+
widths = boxes[:, 2] - boxes[:, 0]
|
101 |
+
heights = boxes[:, 3] - boxes[:, 1]
|
102 |
+
max_side_lengths = np.maximum(widths, heights)
|
103 |
+
|
104 |
+
width_deltas = np.subtract(max_side_lengths, widths, widths)
|
105 |
+
height_deltas = np.subtract(max_side_lengths, heights, heights)
|
106 |
+
width_deltas *= 0.5
|
107 |
+
height_deltas *= 0.5
|
108 |
+
boxes[:, 0] -= width_deltas
|
109 |
+
boxes[:, 1] -= height_deltas
|
110 |
+
boxes[:, 2] += width_deltas
|
111 |
+
boxes[:, 3] += height_deltas
|
112 |
+
return boxes
|
113 |
+
|
114 |
+
|
115 |
+
def deflate_boxes_to_square(boxes, copy=True):
|
116 |
+
"""Deflate boxes to square
|
117 |
+
Args:
|
118 |
+
boxes: (N, 4+K)
|
119 |
+
copy: bool
|
120 |
+
"""
|
121 |
+
boxes = np.array(boxes, dtype=np.float32, copy=copy)
|
122 |
+
|
123 |
+
widths = boxes[:, 2] - boxes[:, 0]
|
124 |
+
heights = boxes[:, 3] - boxes[:, 1]
|
125 |
+
min_side_lengths = np.minimum(widths, heights)
|
126 |
+
|
127 |
+
width_deltas = np.subtract(min_side_lengths, widths, widths)
|
128 |
+
height_deltas = np.subtract(min_side_lengths, heights, heights)
|
129 |
+
width_deltas *= 0.5
|
130 |
+
height_deltas *= 0.5
|
131 |
+
boxes[:, 0] -= width_deltas
|
132 |
+
boxes[:, 1] -= height_deltas
|
133 |
+
boxes[:, 2] += width_deltas
|
134 |
+
boxes[:, 3] += height_deltas
|
135 |
+
return boxes
|
136 |
+
|
khandy/boxes/boxes_utils.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def assert_and_normalize_shape(x, length):
|
5 |
+
"""
|
6 |
+
Args:
|
7 |
+
x: ndarray
|
8 |
+
length: int
|
9 |
+
"""
|
10 |
+
if x.ndim == 0:
|
11 |
+
return x
|
12 |
+
elif x.ndim == 1:
|
13 |
+
if len(x) == 1:
|
14 |
+
return x
|
15 |
+
elif len(x) == length:
|
16 |
+
return x
|
17 |
+
else:
|
18 |
+
raise ValueError('Incompatible shape!')
|
19 |
+
elif x.ndim == 2:
|
20 |
+
if x.shape == (1, 1):
|
21 |
+
return np.squeeze(x, axis=-1)
|
22 |
+
elif x.shape == (length, 1):
|
23 |
+
return np.squeeze(x, axis=-1)
|
24 |
+
else:
|
25 |
+
raise ValueError('Incompatible shape!')
|
26 |
+
else:
|
27 |
+
raise ValueError('Incompatible ndim!')
|
28 |
+
|
khandy/dict_utils.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from collections import OrderedDict
|
3 |
+
|
4 |
+
|
5 |
+
def get_dict_first_item(dict_obj):
|
6 |
+
for key in dict_obj:
|
7 |
+
return key, dict_obj[key]
|
8 |
+
|
9 |
+
|
10 |
+
def sort_dict(dict_obj, key=None, reverse=False):
|
11 |
+
return OrderedDict(sorted(dict_obj.items(), key=key, reverse=reverse))
|
12 |
+
|
13 |
+
|
14 |
+
def create_multidict(key_list, value_list):
|
15 |
+
assert len(key_list) == len(value_list)
|
16 |
+
multidict_obj = {}
|
17 |
+
for key, value in zip(key_list, value_list):
|
18 |
+
multidict_obj.setdefault(key, []).append(value)
|
19 |
+
return multidict_obj
|
20 |
+
|
21 |
+
|
22 |
+
def convert_multidict_to_list(multidict_obj):
|
23 |
+
key_list, value_list = [], []
|
24 |
+
for key, value in multidict_obj.items():
|
25 |
+
key_list += [key] * len(value)
|
26 |
+
value_list += value
|
27 |
+
return key_list, value_list
|
28 |
+
|
29 |
+
|
30 |
+
def convert_multidict_to_records(multidict_obj, key_map=None, raise_if_key_error=True):
|
31 |
+
records = []
|
32 |
+
if key_map is None:
|
33 |
+
for key in multidict_obj:
|
34 |
+
for value in multidict_obj[key]:
|
35 |
+
records.append('{},{}'.format(value, key))
|
36 |
+
else:
|
37 |
+
for key in multidict_obj:
|
38 |
+
if raise_if_key_error:
|
39 |
+
mapped_key = key_map[key]
|
40 |
+
else:
|
41 |
+
mapped_key = key_map.get(key, key)
|
42 |
+
for value in multidict_obj[key]:
|
43 |
+
records.append('{},{}'.format(value, mapped_key))
|
44 |
+
return records
|
45 |
+
|
46 |
+
|
47 |
+
def sample_multidict(multidict_obj, num_keys, num_per_key=None):
|
48 |
+
num_keys = min(num_keys, len(multidict_obj))
|
49 |
+
sub_keys = random.sample(list(multidict_obj), num_keys)
|
50 |
+
if num_per_key is None:
|
51 |
+
sub_mdict = {key: multidict_obj[key] for key in sub_keys}
|
52 |
+
else:
|
53 |
+
sub_mdict = {}
|
54 |
+
for key in sub_keys:
|
55 |
+
num_examples_inner = min(num_per_key, len(multidict_obj[key]))
|
56 |
+
sub_mdict[key] = random.sample(multidict_obj[key], num_examples_inner)
|
57 |
+
return sub_mdict
|
58 |
+
|
59 |
+
|
60 |
+
def split_multidict_on_key(multidict_obj, split_ratio, use_shuffle=False):
|
61 |
+
"""Split multidict_obj on its key.
|
62 |
+
"""
|
63 |
+
assert isinstance(multidict_obj, dict)
|
64 |
+
assert isinstance(split_ratio, (list, tuple))
|
65 |
+
|
66 |
+
pdf = [k / float(sum(split_ratio)) for k in split_ratio]
|
67 |
+
cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
|
68 |
+
indices = [int(round(len(multidict_obj) * k)) for k in cdf]
|
69 |
+
dict_keys = list(multidict_obj)
|
70 |
+
if use_shuffle:
|
71 |
+
random.shuffle(dict_keys)
|
72 |
+
|
73 |
+
be_split_list = []
|
74 |
+
for i in range(len(split_ratio)):
|
75 |
+
part_keys = dict_keys[indices[i]: indices[i + 1]]
|
76 |
+
part_dict = dict([(key, multidict_obj[key]) for key in part_keys])
|
77 |
+
be_split_list.append(part_dict)
|
78 |
+
return be_split_list
|
79 |
+
|
80 |
+
|
81 |
+
def split_multidict_on_value(multidict_obj, split_ratio, use_shuffle=False):
|
82 |
+
"""Split multidict_obj on its value.
|
83 |
+
"""
|
84 |
+
assert isinstance(multidict_obj, dict)
|
85 |
+
assert isinstance(split_ratio, (list, tuple))
|
86 |
+
|
87 |
+
pdf = [k / float(sum(split_ratio)) for k in split_ratio]
|
88 |
+
cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
|
89 |
+
be_split_list = [dict() for k in range(len(split_ratio))]
|
90 |
+
for key, value in multidict_obj.items():
|
91 |
+
indices = [int(round(len(value) * k)) for k in cdf]
|
92 |
+
cloned = value[:]
|
93 |
+
if use_shuffle:
|
94 |
+
random.shuffle(cloned)
|
95 |
+
for i in range(len(split_ratio)):
|
96 |
+
be_split_list[i][key] = cloned[indices[i]: indices[i + 1]]
|
97 |
+
return be_split_list
|
98 |
+
|
99 |
+
|
100 |
+
def get_multidict_info(multidict_obj, with_print=False, desc=None):
|
101 |
+
num_list = [len(val) for val in multidict_obj.values()]
|
102 |
+
num_keys = len(num_list)
|
103 |
+
num_values = sum(num_list)
|
104 |
+
max_values_per_key = max(num_list)
|
105 |
+
min_values_per_key = min(num_list)
|
106 |
+
if num_keys == 0:
|
107 |
+
avg_values_per_key = 0
|
108 |
+
else:
|
109 |
+
avg_values_per_key = num_values / num_keys
|
110 |
+
info = {
|
111 |
+
'num_keys': num_keys,
|
112 |
+
'num_values': num_values,
|
113 |
+
'max_values_per_key': max_values_per_key,
|
114 |
+
'min_values_per_key': min_values_per_key,
|
115 |
+
'avg_values_per_key': avg_values_per_key,
|
116 |
+
}
|
117 |
+
if with_print:
|
118 |
+
desc = desc or '<unknown>'
|
119 |
+
print('{} key number: {}'.format(desc, info['num_keys']))
|
120 |
+
print('{} value number: {}'.format(desc, info['num_values']))
|
121 |
+
print('{} max number per-key: {}'.format(desc, info['max_values_per_key']))
|
122 |
+
print('{} min number per-key: {}'.format(desc, info['min_values_per_key']))
|
123 |
+
print('{} avg number per-key: {:.2f}'.format(desc, info['avg_values_per_key']))
|
124 |
+
return info
|
125 |
+
|
126 |
+
|
127 |
+
def filter_multidict_by_number(multidict_obj, lower, upper=None):
|
128 |
+
if upper is None:
|
129 |
+
return {key: value for key, value in multidict_obj.items()
|
130 |
+
if lower <= len(value) }
|
131 |
+
else:
|
132 |
+
assert lower <= upper, 'lower must not be greater than upper'
|
133 |
+
return {key: value for key, value in multidict_obj.items()
|
134 |
+
if lower <= len(value) <= upper }
|
135 |
+
|
136 |
+
|
137 |
+
def sort_multidict_by_number(multidict_obj, num_keys_to_keep=None, reverse=True):
|
138 |
+
"""
|
139 |
+
Args:
|
140 |
+
reverse: sort in ascending order when is True.
|
141 |
+
"""
|
142 |
+
if num_keys_to_keep is None:
|
143 |
+
num_keys_to_keep = len(multidict_obj)
|
144 |
+
else:
|
145 |
+
num_keys_to_keep = min(num_keys_to_keep, len(multidict_obj))
|
146 |
+
sorted_items = sorted(multidict_obj.items(), key=lambda x: len(x[1]), reverse=reverse)
|
147 |
+
filtered_dict = OrderedDict()
|
148 |
+
for i in range(num_keys_to_keep):
|
149 |
+
filtered_dict[sorted_items[i][0]] = sorted_items[i][1]
|
150 |
+
return filtered_dict
|
151 |
+
|
152 |
+
|
153 |
+
def merge_multidict(*mdicts):
|
154 |
+
merged_multidict = {}
|
155 |
+
for item in mdicts:
|
156 |
+
for key, value in item.items():
|
157 |
+
merged_multidict.setdefault(key, []).extend(value)
|
158 |
+
return merged_multidict
|
159 |
+
|
160 |
+
|
161 |
+
def invert_multidict(multidict_obj):
|
162 |
+
inverted_dict = {}
|
163 |
+
for key, value in multidict_obj.items():
|
164 |
+
for item in value:
|
165 |
+
inverted_dict.setdefault(item, []).append(key)
|
166 |
+
return inverted_dict
|
167 |
+
|
168 |
+
|
khandy/draw_utils.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import PIL
|
3 |
+
from PIL import Image
|
4 |
+
from PIL import ImageDraw
|
5 |
+
from PIL import ImageFont
|
6 |
+
from PIL import ImageColor
|
7 |
+
|
8 |
+
|
9 |
+
def _is_legal_color(color):
|
10 |
+
if color is None:
|
11 |
+
return True
|
12 |
+
if isinstance(color, str):
|
13 |
+
return True
|
14 |
+
return isinstance(color, (tuple, list)) and len(color) == 3
|
15 |
+
|
16 |
+
|
17 |
+
def _normalize_color(color, pil_mode, swap_rgb=False):
|
18 |
+
if color is None:
|
19 |
+
return color
|
20 |
+
if isinstance(color, str):
|
21 |
+
color = ImageColor.getrgb(color)
|
22 |
+
gray = color[0]
|
23 |
+
if swap_rgb:
|
24 |
+
color = (color[2], color[1], color[0])
|
25 |
+
if pil_mode == 'L':
|
26 |
+
color = gray
|
27 |
+
return color
|
28 |
+
|
29 |
+
|
30 |
+
def draw_text(image, text, position, color=(255,0,0), font=None, font_size=15):
|
31 |
+
"""Draws text on given image.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
image (ndarray).
|
35 |
+
text (str): text to be drawn.
|
36 |
+
position (Tuple[int, int]): position where to be drawn.
|
37 |
+
color (List[Union[str, Tuple[int, int, int]]]): text color.
|
38 |
+
font (str): A filename or file-like object containing a TrueType font. If the file is not found in this
|
39 |
+
filename, the loader may also search in other directories, such as the `fonts/` directory on Windows
|
40 |
+
or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
|
41 |
+
font_size (int): The requested font size in points.
|
42 |
+
|
43 |
+
References:
|
44 |
+
torchvision.utils.draw_bounding_boxes
|
45 |
+
"""
|
46 |
+
if isinstance(image, np.ndarray):
|
47 |
+
# PIL.Image.fromarray fails with uint16 arrays
|
48 |
+
# https://github.com/python-pillow/Pillow/issues/1514
|
49 |
+
if (image.dtype == np.uint16) and (image.ndim != 2):
|
50 |
+
image = (image / 256).astype(np.uint8)
|
51 |
+
pil_image = Image.fromarray(image)
|
52 |
+
elif isinstance(image, PIL.Image.Image):
|
53 |
+
pil_image = image
|
54 |
+
else:
|
55 |
+
raise TypeError('Unsupported image type!')
|
56 |
+
assert pil_image.mode in ['L', 'RGB', 'RGBA']
|
57 |
+
|
58 |
+
assert _is_legal_color(color)
|
59 |
+
color = _normalize_color(color, pil_image.mode, isinstance(image, np.ndarray))
|
60 |
+
|
61 |
+
if font is None:
|
62 |
+
font_object = ImageFont.load_default()
|
63 |
+
else:
|
64 |
+
font_object = ImageFont.truetype(font, size=font_size)
|
65 |
+
|
66 |
+
draw = ImageDraw.Draw(pil_image)
|
67 |
+
draw.text((position[0], position[1]), text,
|
68 |
+
fill=color, font=font_object)
|
69 |
+
|
70 |
+
if isinstance(image, np.ndarray):
|
71 |
+
return np.asarray(pil_image)
|
72 |
+
return pil_image
|
73 |
+
|
74 |
+
|
75 |
+
def draw_bounding_boxes(image, boxes, labels=None, colors=None,
|
76 |
+
fill=False, width=1, font=None, font_size=15):
|
77 |
+
"""Draws bounding boxes on given image.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
image (ndarray).
|
81 |
+
boxes (ndarray): ndarray of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format.
|
82 |
+
labels (List[str]): List containing the labels of bounding boxes.
|
83 |
+
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes or labels.
|
84 |
+
fill (bool): If `True` fills the bounding box with specified color.
|
85 |
+
width (int): Width of bounding box.
|
86 |
+
font (str): A filename or file-like object containing a TrueType font. If the file is not found in this
|
87 |
+
filename, the loader may also search in other directories, such as the `fonts/` directory on Windows
|
88 |
+
or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
|
89 |
+
font_size (int): The requested font size in points.
|
90 |
+
|
91 |
+
References:
|
92 |
+
torchvision.utils.draw_bounding_boxes
|
93 |
+
"""
|
94 |
+
if isinstance(image, np.ndarray):
|
95 |
+
# PIL.Image.fromarray fails with uint16 arrays
|
96 |
+
# https://github.com/python-pillow/Pillow/issues/1514
|
97 |
+
if (image.dtype == np.uint16) and (image.ndim != 2):
|
98 |
+
image = (image / 256).astype(np.uint8)
|
99 |
+
pil_image = Image.fromarray(image)
|
100 |
+
elif isinstance(image, PIL.Image.Image):
|
101 |
+
pil_image = image
|
102 |
+
else:
|
103 |
+
raise TypeError('Unsupported image type!')
|
104 |
+
pil_image = pil_image.convert('RGB')
|
105 |
+
|
106 |
+
if font is None:
|
107 |
+
font_object = ImageFont.load_default()
|
108 |
+
else:
|
109 |
+
font_object = ImageFont.truetype(font, size=font_size)
|
110 |
+
|
111 |
+
if fill:
|
112 |
+
draw = ImageDraw.Draw(pil_image, "RGBA")
|
113 |
+
else:
|
114 |
+
draw = ImageDraw.Draw(pil_image)
|
115 |
+
|
116 |
+
for i, bbox in enumerate(boxes):
|
117 |
+
if colors is None:
|
118 |
+
color = None
|
119 |
+
else:
|
120 |
+
color = colors[i]
|
121 |
+
|
122 |
+
assert _is_legal_color(color)
|
123 |
+
color = _normalize_color(color, pil_image.mode, isinstance(image, np.ndarray))
|
124 |
+
|
125 |
+
if fill:
|
126 |
+
if color is None:
|
127 |
+
fill_color = (255, 255, 255, 100)
|
128 |
+
elif isinstance(color, str):
|
129 |
+
# This will automatically raise Error if rgb cannot be parsed.
|
130 |
+
fill_color = ImageColor.getrgb(color) + (100,)
|
131 |
+
elif isinstance(color, tuple):
|
132 |
+
fill_color = color + (100,)
|
133 |
+
# the first argument of ImageDraw.rectangle:
|
134 |
+
# in old version only supports [(x0, y0), (x1, y1)]
|
135 |
+
# in new version supports either [(x0, y0), (x1, y1)] or [x0, y0, x1, y1]
|
136 |
+
draw.rectangle([(bbox[0], bbox[1]), (bbox[2], bbox[3])], width=width, outline=color, fill=fill_color)
|
137 |
+
else:
|
138 |
+
draw.rectangle([(bbox[0], bbox[1]), (bbox[2], bbox[3])], width=width, outline=color)
|
139 |
+
|
140 |
+
if labels is not None:
|
141 |
+
margin = width + 1
|
142 |
+
draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=font_object)
|
143 |
+
|
144 |
+
if isinstance(image, np.ndarray):
|
145 |
+
return np.asarray(pil_image)
|
146 |
+
return pil_image
|
147 |
+
|
148 |
+
|
khandy/feature_utils.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import khandy
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def convert_feature_dict_to_array(feature_dict):
|
8 |
+
one_feature = khandy.get_dict_first_item(feature_dict)[1]
|
9 |
+
num_features = sum([len(item) for item in feature_dict.values()])
|
10 |
+
|
11 |
+
key_list = []
|
12 |
+
start_index = 0
|
13 |
+
feature_array = np.empty((num_features, one_feature.shape[-1]), one_feature.dtype)
|
14 |
+
for key, value in feature_dict.items():
|
15 |
+
feature_array[start_index: start_index + len(value)]= value
|
16 |
+
key_list += [key] * len(value)
|
17 |
+
start_index += len(value)
|
18 |
+
return key_list, feature_array
|
19 |
+
|
20 |
+
|
21 |
+
def convert_feature_array_to_dict(key_list, feature_array):
|
22 |
+
assert len(key_list) == len(feature_array)
|
23 |
+
feature_dict = OrderedDict()
|
24 |
+
for key, feat in zip(key_list, feature_array):
|
25 |
+
feature_dict.setdefault(key, []).append(feat)
|
26 |
+
for label in feature_dict.keys():
|
27 |
+
feature_dict[label] = np.vstack(feature_dict[label])
|
28 |
+
return feature_dict
|
29 |
+
|
30 |
+
|
31 |
+
def pairwise_distances(x, y, squared=True):
|
32 |
+
"""Compute pairwise (squared) Euclidean distances.
|
33 |
+
|
34 |
+
References:
|
35 |
+
[2016 CVPR] Deep Metric Learning via Lifted Structured Feature Embedding
|
36 |
+
`euclidean_distances` from sklearn
|
37 |
+
"""
|
38 |
+
assert isinstance(x, np.ndarray) and x.ndim == 2
|
39 |
+
assert isinstance(y, np.ndarray) and y.ndim == 2
|
40 |
+
assert x.shape[1] == y.shape[1]
|
41 |
+
|
42 |
+
x_square = np.expand_dims(np.einsum('ij,ij->i', x, x), axis=1)
|
43 |
+
if x is y:
|
44 |
+
y_square = x_square.T
|
45 |
+
else:
|
46 |
+
y_square = np.expand_dims(np.einsum('ij,ij->i', y, y), axis=0)
|
47 |
+
distances = np.dot(x, y.T)
|
48 |
+
# use inplace operation to accelerate
|
49 |
+
distances *= -2
|
50 |
+
distances += x_square
|
51 |
+
distances += y_square
|
52 |
+
# result maybe less than 0 due to floating point rounding errors.
|
53 |
+
np.maximum(distances, 0, distances)
|
54 |
+
if x is y:
|
55 |
+
# Ensure that distances between vectors and themselves are set to 0.0.
|
56 |
+
# This may not be the case due to floating point rounding errors.
|
57 |
+
distances.flat[::distances.shape[0] + 1] = 0.0
|
58 |
+
if not squared:
|
59 |
+
np.sqrt(distances, distances)
|
60 |
+
return distances
|
61 |
+
|
62 |
+
|
khandy/file_io_utils.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
import numbers
|
4 |
+
import pickle
|
5 |
+
import warnings
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
|
9 |
+
def load_list(filename, encoding='utf-8', start=0, stop=None):
|
10 |
+
assert isinstance(start, numbers.Integral) and start >= 0
|
11 |
+
assert (stop is None) or (isinstance(stop, numbers.Integral) and stop > start)
|
12 |
+
|
13 |
+
lines = []
|
14 |
+
with open(filename, 'r', encoding=encoding) as f:
|
15 |
+
for _ in range(start):
|
16 |
+
f.readline()
|
17 |
+
for k, line in enumerate(f):
|
18 |
+
if (stop is not None) and (k + start > stop):
|
19 |
+
break
|
20 |
+
lines.append(line.rstrip('\n'))
|
21 |
+
return lines
|
22 |
+
|
23 |
+
|
24 |
+
def save_list(filename, list_obj, encoding='utf-8', append_break=True):
|
25 |
+
with open(filename, 'w', encoding=encoding) as f:
|
26 |
+
if append_break:
|
27 |
+
for item in list_obj:
|
28 |
+
f.write(str(item) + '\n')
|
29 |
+
else:
|
30 |
+
for item in list_obj:
|
31 |
+
f.write(str(item))
|
32 |
+
|
33 |
+
|
34 |
+
def load_json(filename, encoding='utf-8'):
|
35 |
+
with open(filename, 'r', encoding=encoding) as f:
|
36 |
+
data = json.load(f, object_pairs_hook=OrderedDict)
|
37 |
+
return data
|
38 |
+
|
39 |
+
|
40 |
+
def save_json(filename, data, encoding='utf-8', indent=4, cls=None, sort_keys=False):
|
41 |
+
if not filename.endswith('.json'):
|
42 |
+
filename = filename + '.json'
|
43 |
+
with open(filename, 'w', encoding=encoding) as f:
|
44 |
+
json.dump(data, f, indent=indent, separators=(',',': '),
|
45 |
+
ensure_ascii=False, cls=cls, sort_keys=sort_keys)
|
46 |
+
|
47 |
+
|
48 |
+
def load_bytes(filename, use_base64: bool = False) -> bytes:
|
49 |
+
"""Open the file in bytes mode, read it, and close the file.
|
50 |
+
|
51 |
+
References:
|
52 |
+
pathlib.Path.read_bytes
|
53 |
+
"""
|
54 |
+
with open(filename, 'rb') as f:
|
55 |
+
data = f.read()
|
56 |
+
if use_base64:
|
57 |
+
data = base64.b64encode(data)
|
58 |
+
return data
|
59 |
+
|
60 |
+
|
61 |
+
def save_bytes(filename, data: bytes, use_base64: bool = False) -> int:
|
62 |
+
"""Open the file in bytes mode, write to it, and close the file.
|
63 |
+
|
64 |
+
References:
|
65 |
+
pathlib.Path.write_bytes
|
66 |
+
"""
|
67 |
+
if use_base64:
|
68 |
+
data = base64.b64decode(data)
|
69 |
+
with open(filename, 'wb') as f:
|
70 |
+
ret = f.write(data)
|
71 |
+
return ret
|
72 |
+
|
73 |
+
|
74 |
+
def load_as_base64(filename) -> bytes:
|
75 |
+
warnings.warn('khandy.load_as_base64 will be deprecated, use khandy.load_bytes instead!')
|
76 |
+
return load_bytes(filename, True)
|
77 |
+
|
78 |
+
|
79 |
+
def load_object(filename):
|
80 |
+
with open(filename, 'rb') as f:
|
81 |
+
return pickle.load(f)
|
82 |
+
|
83 |
+
|
84 |
+
def save_object(filename, obj):
|
85 |
+
with open(filename, 'wb') as f:
|
86 |
+
pickle.dump(obj, f)
|
87 |
+
|
khandy/fs_utils.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import shutil
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
|
7 |
+
def get_path_stem(path):
|
8 |
+
"""
|
9 |
+
References:
|
10 |
+
`std::filesystem::path::stem` since C++17
|
11 |
+
"""
|
12 |
+
return os.path.splitext(os.path.basename(path))[0]
|
13 |
+
|
14 |
+
|
15 |
+
def replace_path_stem(path, new_stem):
|
16 |
+
dirname, basename = os.path.split(path)
|
17 |
+
stem, extension = os.path.splitext(basename)
|
18 |
+
if isinstance(new_stem, str):
|
19 |
+
return os.path.join(dirname, new_stem + extension)
|
20 |
+
elif hasattr(new_stem, '__call__'):
|
21 |
+
return os.path.join(dirname, new_stem(stem) + extension)
|
22 |
+
else:
|
23 |
+
raise TypeError('Unsupported Type!')
|
24 |
+
|
25 |
+
|
26 |
+
def get_path_extension(path):
|
27 |
+
"""
|
28 |
+
References:
|
29 |
+
`std::filesystem::path::extension` since C++17
|
30 |
+
|
31 |
+
Notes:
|
32 |
+
Not fully consistent with `std::filesystem::path::extension`
|
33 |
+
"""
|
34 |
+
return os.path.splitext(os.path.basename(path))[1]
|
35 |
+
|
36 |
+
|
37 |
+
def replace_path_extension(path, new_extension=None):
|
38 |
+
"""Replaces the extension with new_extension or removes it when the default value is used.
|
39 |
+
Firstly, if this path has an extension, it is removed. Then, a dot character is appended
|
40 |
+
to the pathname, if new_extension is not empty or does not begin with a dot character.
|
41 |
+
|
42 |
+
References:
|
43 |
+
`std::filesystem::path::replace_extension` since C++17
|
44 |
+
"""
|
45 |
+
filename_wo_ext = os.path.splitext(path)[0]
|
46 |
+
if new_extension == '' or new_extension is None:
|
47 |
+
return filename_wo_ext
|
48 |
+
elif new_extension.startswith('.'):
|
49 |
+
return ''.join([filename_wo_ext, new_extension])
|
50 |
+
else:
|
51 |
+
return '.'.join([filename_wo_ext, new_extension])
|
52 |
+
|
53 |
+
|
54 |
+
def normalize_extension(extension):
|
55 |
+
if extension.startswith('.'):
|
56 |
+
new_extension = extension.lower()
|
57 |
+
else:
|
58 |
+
new_extension = '.' + extension.lower()
|
59 |
+
return new_extension
|
60 |
+
|
61 |
+
|
62 |
+
def is_path_in_extensions(path, extensions):
|
63 |
+
if isinstance(extensions, str):
|
64 |
+
extensions = [extensions]
|
65 |
+
extensions = [normalize_extension(item) for item in extensions]
|
66 |
+
extension = get_path_extension(path)
|
67 |
+
return extension.lower() in extensions
|
68 |
+
|
69 |
+
|
70 |
+
def normalize_path(path, norm_case=True):
|
71 |
+
"""
|
72 |
+
References:
|
73 |
+
https://en.cppreference.com/w/cpp/filesystem/canonical
|
74 |
+
"""
|
75 |
+
# On Unix and Windows, return the argument with an initial
|
76 |
+
# component of ~ or ~user replaced by that user's home directory.
|
77 |
+
path = os.path.expanduser(path)
|
78 |
+
# Return a normalized absolutized version of the pathname path.
|
79 |
+
# On most platforms, this is equivalent to calling the function
|
80 |
+
# normpath() as follows: normpath(join(os.getcwd(), path)).
|
81 |
+
path = os.path.abspath(path)
|
82 |
+
if norm_case:
|
83 |
+
# Normalize the case of a pathname. On Windows,
|
84 |
+
# convert all characters in the pathname to lowercase,
|
85 |
+
# and also convert forward slashes to backward slashes.
|
86 |
+
# On other operating systems, return the path unchanged.
|
87 |
+
path = os.path.normcase(path)
|
88 |
+
return path
|
89 |
+
|
90 |
+
|
91 |
+
def makedirs(name, mode=0o755):
|
92 |
+
"""
|
93 |
+
References:
|
94 |
+
mmcv.mkdir_or_exist
|
95 |
+
"""
|
96 |
+
warnings.warn('`makedirs` will be deprecated!')
|
97 |
+
if name == '':
|
98 |
+
return
|
99 |
+
name = os.path.expanduser(name)
|
100 |
+
os.makedirs(name, mode=mode, exist_ok=True)
|
101 |
+
|
102 |
+
|
103 |
+
def listdirs(paths, path_sep=None, full_path=True):
|
104 |
+
"""Enhancement on `os.listdir`
|
105 |
+
"""
|
106 |
+
warnings.warn('`listdirs` will be deprecated!')
|
107 |
+
assert isinstance(paths, (str, tuple, list))
|
108 |
+
if isinstance(paths, str):
|
109 |
+
path_sep = path_sep or os.path.pathsep
|
110 |
+
paths = paths.split(path_sep)
|
111 |
+
|
112 |
+
all_filenames = []
|
113 |
+
for path in paths:
|
114 |
+
path_ex = os.path.expanduser(path)
|
115 |
+
filenames = os.listdir(path_ex)
|
116 |
+
if full_path:
|
117 |
+
filenames = [os.path.join(path_ex, filename) for filename in filenames]
|
118 |
+
all_filenames.extend(filenames)
|
119 |
+
return all_filenames
|
120 |
+
|
121 |
+
|
122 |
+
def get_all_filenames(path, extensions=None, is_valid_file=None):
|
123 |
+
warnings.warn('`get_all_filenames` will be deprecated, use `list_files_in_dir` with `recursive=True` instead!')
|
124 |
+
if (extensions is not None) and (is_valid_file is not None):
|
125 |
+
raise ValueError("Both extensions and is_valid_file cannot "
|
126 |
+
"be not None at the same time")
|
127 |
+
if is_valid_file is None:
|
128 |
+
if extensions is not None:
|
129 |
+
def is_valid_file(filename):
|
130 |
+
return is_path_in_extensions(filename, extensions)
|
131 |
+
else:
|
132 |
+
def is_valid_file(filename):
|
133 |
+
return True
|
134 |
+
|
135 |
+
all_filenames = []
|
136 |
+
path_ex = os.path.expanduser(path)
|
137 |
+
for root, _, filenames in sorted(os.walk(path_ex, followlinks=True)):
|
138 |
+
for filename in sorted(filenames):
|
139 |
+
fullname = os.path.join(root, filename)
|
140 |
+
if is_valid_file(fullname):
|
141 |
+
all_filenames.append(fullname)
|
142 |
+
return all_filenames
|
143 |
+
|
144 |
+
|
145 |
+
def get_top_level_dirs(path, full_path=True):
|
146 |
+
warnings.warn('`get_top_level_dirs` will be deprecated, use `list_dirs_in_dir` instead!')
|
147 |
+
if path is None:
|
148 |
+
path = os.getcwd()
|
149 |
+
path_ex = os.path.expanduser(path)
|
150 |
+
filenames = os.listdir(path_ex)
|
151 |
+
if full_path:
|
152 |
+
return [os.path.join(path_ex, item) for item in filenames
|
153 |
+
if os.path.isdir(os.path.join(path_ex, item))]
|
154 |
+
else:
|
155 |
+
return [item for item in filenames
|
156 |
+
if os.path.isdir(os.path.join(path_ex, item))]
|
157 |
+
|
158 |
+
|
159 |
+
def get_top_level_files(path, full_path=True):
|
160 |
+
warnings.warn('`get_top_level_files` will be deprecated, use `list_files_in_dir` instead!')
|
161 |
+
if path is None:
|
162 |
+
path = os.getcwd()
|
163 |
+
path_ex = os.path.expanduser(path)
|
164 |
+
filenames = os.listdir(path_ex)
|
165 |
+
if full_path:
|
166 |
+
return [os.path.join(path_ex, item) for item in filenames
|
167 |
+
if os.path.isfile(os.path.join(path_ex, item))]
|
168 |
+
else:
|
169 |
+
return [item for item in filenames
|
170 |
+
if os.path.isfile(os.path.join(path_ex, item))]
|
171 |
+
|
172 |
+
|
173 |
+
def list_items_in_dir(path=None, recursive=False, full_path=True):
|
174 |
+
"""List all entries in directory
|
175 |
+
"""
|
176 |
+
if path is None:
|
177 |
+
path = os.getcwd()
|
178 |
+
path_ex = os.path.expanduser(path)
|
179 |
+
|
180 |
+
if not recursive:
|
181 |
+
names = os.listdir(path_ex)
|
182 |
+
if full_path:
|
183 |
+
return [os.path.join(path_ex, name) for name in sorted(names)]
|
184 |
+
else:
|
185 |
+
return sorted(names)
|
186 |
+
else:
|
187 |
+
all_names = []
|
188 |
+
for root, dirnames, filenames in sorted(os.walk(path_ex, followlinks=True)):
|
189 |
+
all_names += [os.path.join(root, name) for name in sorted(dirnames)]
|
190 |
+
all_names += [os.path.join(root, name) for name in sorted(filenames)]
|
191 |
+
return all_names
|
192 |
+
|
193 |
+
|
194 |
+
def list_dirs_in_dir(path=None, recursive=False, full_path=True):
|
195 |
+
"""List all dirs in directory
|
196 |
+
"""
|
197 |
+
if path is None:
|
198 |
+
path = os.getcwd()
|
199 |
+
path_ex = os.path.expanduser(path)
|
200 |
+
|
201 |
+
if not recursive:
|
202 |
+
names = os.listdir(path_ex)
|
203 |
+
if full_path:
|
204 |
+
return [os.path.join(path_ex, name) for name in sorted(names)
|
205 |
+
if os.path.isdir(os.path.join(path_ex, name))]
|
206 |
+
else:
|
207 |
+
return [name for name in sorted(names)
|
208 |
+
if os.path.isdir(os.path.join(path_ex, name))]
|
209 |
+
else:
|
210 |
+
all_names = []
|
211 |
+
for root, dirnames, _ in sorted(os.walk(path_ex, followlinks=True)):
|
212 |
+
all_names += [os.path.join(root, name) for name in sorted(dirnames)]
|
213 |
+
return all_names
|
214 |
+
|
215 |
+
|
216 |
+
def list_files_in_dir(path=None, recursive=False, full_path=True):
|
217 |
+
"""List all files in directory
|
218 |
+
"""
|
219 |
+
if path is None:
|
220 |
+
path = os.getcwd()
|
221 |
+
path_ex = os.path.expanduser(path)
|
222 |
+
|
223 |
+
if not recursive:
|
224 |
+
names = os.listdir(path_ex)
|
225 |
+
if full_path:
|
226 |
+
return [os.path.join(path_ex, name) for name in sorted(names)
|
227 |
+
if os.path.isfile(os.path.join(path_ex, name))]
|
228 |
+
else:
|
229 |
+
return [name for name in sorted(names)
|
230 |
+
if os.path.isfile(os.path.join(path_ex, name))]
|
231 |
+
else:
|
232 |
+
all_names = []
|
233 |
+
for root, _, filenames in sorted(os.walk(path_ex, followlinks=True)):
|
234 |
+
all_names += [os.path.join(root, name) for name in sorted(filenames)]
|
235 |
+
return all_names
|
236 |
+
|
237 |
+
|
238 |
+
def get_folder_size(dirname):
|
239 |
+
if not os.path.exists(dirname):
|
240 |
+
raise ValueError("Incorrect path: {}".format(dirname))
|
241 |
+
total_size = 0
|
242 |
+
for root, _, filenames in os.walk(dirname):
|
243 |
+
for name in filenames:
|
244 |
+
total_size += os.path.getsize(os.path.join(root, name))
|
245 |
+
return total_size
|
246 |
+
|
247 |
+
|
248 |
+
def escape_filename(filename, new_char='_'):
|
249 |
+
assert isinstance(new_char, str)
|
250 |
+
control_chars = ''.join((map(chr, range(0x00, 0x20))))
|
251 |
+
pattern = r'[\\/*?:"<>|{}]'.format(control_chars)
|
252 |
+
return re.sub(pattern, new_char, filename)
|
253 |
+
|
254 |
+
|
255 |
+
def replace_invalid_filename_char(filename, new_char='_'):
|
256 |
+
warnings.warn('`replace_invalid_filename_char` will be deprecated, use `escape_filename` instead!')
|
257 |
+
return escape_filename(filename, new_char)
|
258 |
+
|
259 |
+
|
260 |
+
def copy_file(src, dst_dir, action_if_exist='rename'):
|
261 |
+
"""
|
262 |
+
Args:
|
263 |
+
src: source file path
|
264 |
+
dst_dir: dest dir
|
265 |
+
action_if_exist:
|
266 |
+
None: same as shutil.copy
|
267 |
+
ignore: when dest file exists, don't copy and return None
|
268 |
+
rename: when dest file exists, copy after rename
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
dest filename
|
272 |
+
"""
|
273 |
+
dst = os.path.join(dst_dir, os.path.basename(src))
|
274 |
+
|
275 |
+
if action_if_exist is None:
|
276 |
+
os.makedirs(dst_dir, exist_ok=True)
|
277 |
+
shutil.copy(src, dst)
|
278 |
+
elif action_if_exist.lower() == 'ignore':
|
279 |
+
if os.path.exists(dst):
|
280 |
+
warnings.warn(f'{dst} already exists, do not copy!')
|
281 |
+
return dst
|
282 |
+
os.makedirs(dst_dir, exist_ok=True)
|
283 |
+
shutil.copy(src, dst)
|
284 |
+
elif action_if_exist.lower() == 'rename':
|
285 |
+
suffix = 2
|
286 |
+
stem, extension = os.path.splitext(os.path.basename(src))
|
287 |
+
while os.path.exists(dst):
|
288 |
+
dst = os.path.join(dst_dir, f'{stem} ({suffix}){extension}')
|
289 |
+
suffix += 1
|
290 |
+
os.makedirs(dst_dir, exist_ok=True)
|
291 |
+
shutil.copy(src, dst)
|
292 |
+
else:
|
293 |
+
raise ValueError('Invalid action_if_exist, got {}.'.format(action_if_exist))
|
294 |
+
|
295 |
+
return dst
|
296 |
+
|
297 |
+
|
298 |
+
def move_file(src, dst_dir, action_if_exist='rename'):
|
299 |
+
"""
|
300 |
+
Args:
|
301 |
+
src: source file path
|
302 |
+
dst_dir: dest dir
|
303 |
+
action_if_exist:
|
304 |
+
None: same as shutil.move
|
305 |
+
ignore: when dest file exists, don't move and return None
|
306 |
+
rename: when dest file exists, move after rename
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
dest filename
|
310 |
+
"""
|
311 |
+
dst = os.path.join(dst_dir, os.path.basename(src))
|
312 |
+
|
313 |
+
if action_if_exist is None:
|
314 |
+
os.makedirs(dst_dir, exist_ok=True)
|
315 |
+
shutil.move(src, dst)
|
316 |
+
elif action_if_exist.lower() == 'ignore':
|
317 |
+
if os.path.exists(dst):
|
318 |
+
warnings.warn(f'{dst} already exists, do not move!')
|
319 |
+
return dst
|
320 |
+
os.makedirs(dst_dir, exist_ok=True)
|
321 |
+
shutil.move(src, dst)
|
322 |
+
elif action_if_exist.lower() == 'rename':
|
323 |
+
suffix = 2
|
324 |
+
stem, extension = os.path.splitext(os.path.basename(src))
|
325 |
+
while os.path.exists(dst):
|
326 |
+
dst = os.path.join(dst_dir, f'{stem} ({suffix}){extension}')
|
327 |
+
suffix += 1
|
328 |
+
os.makedirs(dst_dir, exist_ok=True)
|
329 |
+
shutil.move(src, dst)
|
330 |
+
else:
|
331 |
+
raise ValueError('Invalid action_if_exist, got {}.'.format(action_if_exist))
|
332 |
+
|
333 |
+
return dst
|
334 |
+
|
335 |
+
|
336 |
+
def rename_file(src, dst, action_if_exist='rename'):
|
337 |
+
"""
|
338 |
+
Args:
|
339 |
+
src: source file path
|
340 |
+
dst: dest file path
|
341 |
+
action_if_exist:
|
342 |
+
None: same as os.rename
|
343 |
+
ignore: when dest file exists, don't rename and return None
|
344 |
+
rename: when dest file exists, rename it
|
345 |
+
|
346 |
+
Returns:
|
347 |
+
dest filename
|
348 |
+
"""
|
349 |
+
if dst == src:
|
350 |
+
return dst
|
351 |
+
dst_dir = os.path.dirname(os.path.abspath(dst))
|
352 |
+
|
353 |
+
if action_if_exist is None:
|
354 |
+
os.makedirs(dst_dir, exist_ok=True)
|
355 |
+
os.rename(src, dst)
|
356 |
+
elif action_if_exist.lower() == 'ignore':
|
357 |
+
if os.path.exists(dst):
|
358 |
+
warnings.warn(f'{dst} already exists, do not rename!')
|
359 |
+
return dst
|
360 |
+
os.makedirs(dst_dir, exist_ok=True)
|
361 |
+
os.rename(src, dst)
|
362 |
+
elif action_if_exist.lower() == 'rename':
|
363 |
+
suffix = 2
|
364 |
+
stem, extension = os.path.splitext(os.path.basename(dst))
|
365 |
+
while os.path.exists(dst):
|
366 |
+
dst = os.path.join(dst_dir, f'{stem} ({suffix}){extension}')
|
367 |
+
suffix += 1
|
368 |
+
os.makedirs(dst_dir, exist_ok=True)
|
369 |
+
os.rename(src, dst)
|
370 |
+
else:
|
371 |
+
raise ValueError('Invalid action_if_exist, got {}.'.format(action_if_exist))
|
372 |
+
|
373 |
+
return dst
|
374 |
+
|
375 |
+
|
khandy/hash_utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
|
3 |
+
|
4 |
+
def calc_hash(content, hash_object=None):
|
5 |
+
hash_object = hash_object or hashlib.md5()
|
6 |
+
if isinstance(hash_object, str):
|
7 |
+
hash_object = hashlib.new(hash_object)
|
8 |
+
hash_object.update(content)
|
9 |
+
return hash_object.hexdigest()
|
10 |
+
|
11 |
+
|
12 |
+
def calc_file_hash(filename, hash_object=None, chunk_size=1024 * 1024):
|
13 |
+
hash_object = hash_object or hashlib.md5()
|
14 |
+
if isinstance(hash_object, str):
|
15 |
+
hash_object = hashlib.new(hash_object)
|
16 |
+
|
17 |
+
with open(filename, "rb") as f:
|
18 |
+
while True:
|
19 |
+
chunk = f.read(chunk_size)
|
20 |
+
if not chunk:
|
21 |
+
break
|
22 |
+
hash_object.update(chunk)
|
23 |
+
return hash_object.hexdigest()
|
24 |
+
|
25 |
+
|
khandy/image/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .align_and_crop import *
|
2 |
+
from .crop_or_pad import *
|
3 |
+
from .flip import *
|
4 |
+
from .image_hash import *
|
5 |
+
from .resize import *
|
6 |
+
from .rotate import *
|
7 |
+
from .translate import *
|
8 |
+
|
9 |
+
from .misc import *
|
10 |
+
|
khandy/image/align_and_crop.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def get_similarity_transform(src_pts, dst_pts):
|
6 |
+
"""Get similarity transform matrix from src_pts to dst_pts
|
7 |
+
|
8 |
+
Args:
|
9 |
+
src_pts: Kx2 np.array
|
10 |
+
source points matrix, each row is a pair of coordinates (x, y)
|
11 |
+
dst_pts: Kx2 np.array
|
12 |
+
destination points matrix, each row is a pair of coordinates (x, y)
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
xform_matrix: 3x3 np.array
|
16 |
+
transform matrix from src_pts to dst_pts
|
17 |
+
"""
|
18 |
+
src_pts = np.asarray(src_pts)
|
19 |
+
dst_pts = np.asarray(dst_pts)
|
20 |
+
assert src_pts.shape == dst_pts.shape
|
21 |
+
assert (src_pts.ndim == 2) and (src_pts.shape[-1] == 2)
|
22 |
+
|
23 |
+
npts = src_pts.shape[0]
|
24 |
+
src_x = src_pts[:, 0].reshape((-1, 1))
|
25 |
+
src_y = src_pts[:, 1].reshape((-1, 1))
|
26 |
+
tmp1 = np.hstack((src_x, -src_y, np.ones((npts, 1)), np.zeros((npts, 1))))
|
27 |
+
tmp2 = np.hstack((src_y, src_x, np.zeros((npts, 1)), np.ones((npts, 1))))
|
28 |
+
A = np.vstack((tmp1, tmp2))
|
29 |
+
|
30 |
+
dst_x = dst_pts[:, 0].reshape((-1, 1))
|
31 |
+
dst_y = dst_pts[:, 1].reshape((-1, 1))
|
32 |
+
b = np.vstack((dst_x, dst_y))
|
33 |
+
|
34 |
+
x = np.linalg.lstsq(A, b, rcond=-1)[0]
|
35 |
+
x = np.squeeze(x)
|
36 |
+
sc, ss, tx, ty = x[0], x[1], x[2], x[3]
|
37 |
+
xform_matrix = np.array([
|
38 |
+
[sc, -ss, tx],
|
39 |
+
[ss, sc, ty],
|
40 |
+
[ 0, 0, 1]
|
41 |
+
])
|
42 |
+
return xform_matrix
|
43 |
+
|
44 |
+
|
45 |
+
def align_and_crop(image, landmarks, std_landmarks, align_size,
|
46 |
+
border_value=0, return_transform_matrix=False):
|
47 |
+
landmarks = np.asarray(landmarks)
|
48 |
+
std_landmarks = np.asarray(std_landmarks)
|
49 |
+
xform_matrix = get_similarity_transform(landmarks, std_landmarks)
|
50 |
+
|
51 |
+
landmarks_ex = np.pad(landmarks, ((0,0),(0,1)), mode='constant', constant_values=1)
|
52 |
+
dst_landmarks = np.dot(landmarks_ex, xform_matrix[:2,:].T)
|
53 |
+
dst_image = cv2.warpAffine(image, xform_matrix[:2,:], dsize=align_size,
|
54 |
+
borderValue=border_value)
|
55 |
+
if return_transform_matrix:
|
56 |
+
return dst_image, dst_landmarks, xform_matrix
|
57 |
+
else:
|
58 |
+
return dst_image, dst_landmarks
|
59 |
+
|
60 |
+
|
khandy/image/crop_or_pad.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numbers
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import khandy
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def crop(image, x_min, y_min, x_max, y_max, border_value=0):
|
9 |
+
"""Crop the given image at specified rectangular area.
|
10 |
+
|
11 |
+
See Also:
|
12 |
+
translate_image
|
13 |
+
|
14 |
+
References:
|
15 |
+
PIL.Image.crop
|
16 |
+
tf.image.resize_image_with_crop_or_pad
|
17 |
+
"""
|
18 |
+
assert khandy.is_numpy_image(image)
|
19 |
+
assert isinstance(x_min, numbers.Integral) and isinstance(y_min, numbers.Integral)
|
20 |
+
assert isinstance(x_max, numbers.Integral) and isinstance(y_max, numbers.Integral)
|
21 |
+
assert (x_min <= x_max) and (y_min <= y_max)
|
22 |
+
|
23 |
+
src_height, src_width = image.shape[:2]
|
24 |
+
dst_height, dst_width = y_max - y_min + 1, x_max - x_min + 1
|
25 |
+
channels = 1 if image.ndim == 2 else image.shape[2]
|
26 |
+
|
27 |
+
if isinstance(border_value, (tuple, list)):
|
28 |
+
assert len(border_value) == channels, \
|
29 |
+
'Expected the num of elements in tuple equals the channels ' \
|
30 |
+
'of input image. Found {} vs {}'.format(
|
31 |
+
len(border_value), channels)
|
32 |
+
else:
|
33 |
+
border_value = (border_value,) * channels
|
34 |
+
dst_image = khandy.create_solid_color_image(
|
35 |
+
dst_width, dst_height, border_value, dtype=image.dtype)
|
36 |
+
|
37 |
+
src_x_begin = max(x_min, 0)
|
38 |
+
src_x_end = min(x_max + 1, src_width)
|
39 |
+
dst_x_begin = src_x_begin - x_min
|
40 |
+
dst_x_end = src_x_end - x_min
|
41 |
+
|
42 |
+
src_y_begin = max(y_min, 0)
|
43 |
+
src_y_end = min(y_max + 1, src_height)
|
44 |
+
dst_y_begin = src_y_begin - y_min
|
45 |
+
dst_y_end = src_y_end - y_min
|
46 |
+
|
47 |
+
if (src_x_begin >= src_x_end) or (src_y_begin >= src_y_end):
|
48 |
+
return dst_image
|
49 |
+
dst_image[dst_y_begin: dst_y_end, dst_x_begin: dst_x_end, ...] = \
|
50 |
+
image[src_y_begin: src_y_end, src_x_begin: src_x_end, ...]
|
51 |
+
return dst_image
|
52 |
+
|
53 |
+
|
54 |
+
def crop_or_pad(image, x_min, y_min, x_max, y_max, border_value=0):
|
55 |
+
warnings.warn('crop_or_pad will be deprecated, use crop instead!')
|
56 |
+
return crop(image, x_min, y_min, x_max, y_max, border_value)
|
57 |
+
|
58 |
+
|
59 |
+
def crop_coords(boxes, image_width, image_height):
|
60 |
+
"""
|
61 |
+
References:
|
62 |
+
`mmcv.impad`
|
63 |
+
`pad` in https://github.com/kpzhang93/MTCNN_face_detection_alignment
|
64 |
+
`MtcnnDetector.pad` in https://github.com/AITTSMD/MTCNN-Tensorflow
|
65 |
+
"""
|
66 |
+
x_mins = boxes[:, 0]
|
67 |
+
y_mins = boxes[:, 1]
|
68 |
+
x_maxs = boxes[:, 2]
|
69 |
+
y_maxs = boxes[:, 3]
|
70 |
+
dst_widths = x_maxs - x_mins + 1
|
71 |
+
dst_heights = y_maxs - y_mins + 1
|
72 |
+
|
73 |
+
src_x_begin = np.maximum(x_mins, 0)
|
74 |
+
src_x_end = np.minimum(x_maxs + 1, image_width)
|
75 |
+
dst_x_begin = src_x_begin - x_mins
|
76 |
+
dst_x_end = src_x_end - x_mins
|
77 |
+
|
78 |
+
src_y_begin = np.maximum(y_mins, 0)
|
79 |
+
src_y_end = np.minimum(y_maxs + 1, image_height)
|
80 |
+
dst_y_begin = src_y_begin - y_mins
|
81 |
+
dst_y_end = src_y_end - y_mins
|
82 |
+
|
83 |
+
coords = np.stack([dst_y_begin, dst_y_end, dst_x_begin, dst_x_end,
|
84 |
+
src_y_begin, src_y_end, src_x_begin, src_x_end,
|
85 |
+
dst_heights, dst_widths], axis=0)
|
86 |
+
return coords
|
87 |
+
|
88 |
+
|
89 |
+
def crop_or_pad_coords(boxes, image_width, image_height):
|
90 |
+
warnings.warn('crop_or_pad_coords will be deprecated, use crop_coords instead!')
|
91 |
+
return crop_coords(boxes, image_width, image_height)
|
92 |
+
|
93 |
+
|
94 |
+
def center_crop(image, dst_width, dst_height, strict=True):
|
95 |
+
"""
|
96 |
+
strict:
|
97 |
+
when True, raise error if src size is less than dst size.
|
98 |
+
when False, remain unchanged if src size is less than dst size, otherwise center crop.
|
99 |
+
"""
|
100 |
+
assert khandy.is_numpy_image(image)
|
101 |
+
assert isinstance(dst_width, numbers.Integral) and isinstance(dst_height, numbers.Integral)
|
102 |
+
src_height, src_width = image.shape[:2]
|
103 |
+
if strict:
|
104 |
+
assert (src_height >= dst_height) and (src_width >= dst_width)
|
105 |
+
|
106 |
+
crop_top = max((src_height - dst_height) // 2, 0)
|
107 |
+
crop_left = max((src_width - dst_width) // 2, 0)
|
108 |
+
cropped = image[crop_top: dst_height + crop_top,
|
109 |
+
crop_left: dst_width + crop_left, ...]
|
110 |
+
return cropped
|
111 |
+
|
112 |
+
|
113 |
+
def center_pad(image, dst_width, dst_height, strict=True):
|
114 |
+
"""
|
115 |
+
strict:
|
116 |
+
when True, raise error if src size is greater than dst size.
|
117 |
+
when False, remain unchanged if src size is greater than dst size, otherwise center pad.
|
118 |
+
"""
|
119 |
+
assert khandy.is_numpy_image(image)
|
120 |
+
assert isinstance(dst_width, numbers.Integral) and isinstance(dst_height, numbers.Integral)
|
121 |
+
|
122 |
+
src_height, src_width = image.shape[:2]
|
123 |
+
if strict:
|
124 |
+
assert (src_height <= dst_height) and (src_width <= dst_width)
|
125 |
+
|
126 |
+
padding_x = max(dst_width - src_width, 0)
|
127 |
+
padding_y = max(dst_height - src_height, 0)
|
128 |
+
padding_top = padding_y // 2
|
129 |
+
padding_left = padding_x // 2
|
130 |
+
if image.ndim == 2:
|
131 |
+
padding = ((padding_top, padding_y - padding_top),
|
132 |
+
(padding_left, padding_x - padding_left))
|
133 |
+
else:
|
134 |
+
padding = ((padding_top, padding_y - padding_top),
|
135 |
+
(padding_left, padding_x - padding_left), (0, 0))
|
136 |
+
return np.pad(image, padding, 'constant')
|
137 |
+
|
138 |
+
|
khandy/image/flip.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import khandy
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def flip_image(image, direction='h', copy=True):
|
6 |
+
"""
|
7 |
+
References:
|
8 |
+
np.flipud, np.fliplr, np.flip
|
9 |
+
cv2.flip
|
10 |
+
tf.image.flip_up_down
|
11 |
+
tf.image.flip_left_right
|
12 |
+
"""
|
13 |
+
assert khandy.is_numpy_image(image)
|
14 |
+
assert direction in ['x', 'h', 'horizontal',
|
15 |
+
'y', 'v', 'vertical',
|
16 |
+
'o', 'b', 'both']
|
17 |
+
if copy:
|
18 |
+
image = image.copy()
|
19 |
+
if direction in ['o', 'b', 'both', 'x', 'h', 'horizontal']:
|
20 |
+
image = np.fliplr(image)
|
21 |
+
if direction in ['o', 'b', 'both', 'y', 'v', 'vertical']:
|
22 |
+
image = np.flipud(image)
|
23 |
+
return image
|
24 |
+
|
25 |
+
|
26 |
+
def transpose_image(image, copy=True):
|
27 |
+
"""Transpose image.
|
28 |
+
|
29 |
+
References:
|
30 |
+
np.transpose
|
31 |
+
cv2.transpose
|
32 |
+
tf.image.transpose
|
33 |
+
"""
|
34 |
+
assert khandy.is_numpy_image(image)
|
35 |
+
if copy:
|
36 |
+
image = image.copy()
|
37 |
+
if image.ndim == 2:
|
38 |
+
transpose_axes = (1, 0)
|
39 |
+
else:
|
40 |
+
transpose_axes = (1, 0, 2)
|
41 |
+
image = np.transpose(image, transpose_axes)
|
42 |
+
return image
|
43 |
+
|
44 |
+
|
45 |
+
def rot90_image(image, n=1, copy=True):
|
46 |
+
"""Rotate image counter-clockwise by 90 degrees.
|
47 |
+
|
48 |
+
References:
|
49 |
+
np.rot90
|
50 |
+
cv2.rotate
|
51 |
+
tf.image.rot90
|
52 |
+
"""
|
53 |
+
assert khandy.is_numpy_image(image)
|
54 |
+
if copy:
|
55 |
+
image = image.copy()
|
56 |
+
if image.ndim == 2:
|
57 |
+
transpose_axes = (1, 0)
|
58 |
+
else:
|
59 |
+
transpose_axes = (1, 0, 2)
|
60 |
+
|
61 |
+
n = n % 4
|
62 |
+
if n == 0:
|
63 |
+
return image[:]
|
64 |
+
elif n == 1:
|
65 |
+
image = np.transpose(image, transpose_axes)
|
66 |
+
image = np.flipud(image)
|
67 |
+
elif n == 2:
|
68 |
+
image = np.fliplr(np.flipud(image))
|
69 |
+
else:
|
70 |
+
image = np.transpose(image, transpose_axes)
|
71 |
+
image = np.fliplr(image)
|
72 |
+
return image
|
khandy/image/image_hash.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import khandy
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def _convert_bool_matrix_to_int(bool_mat):
|
7 |
+
hash_val = int(0)
|
8 |
+
for item in bool_mat.flatten():
|
9 |
+
hash_val <<= 1
|
10 |
+
hash_val |= int(item)
|
11 |
+
return hash_val
|
12 |
+
|
13 |
+
|
14 |
+
def calc_image_ahash(image):
|
15 |
+
"""Average Hashing
|
16 |
+
|
17 |
+
References:
|
18 |
+
http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
|
19 |
+
"""
|
20 |
+
assert khandy.is_numpy_image(image)
|
21 |
+
if image.ndim == 3:
|
22 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
23 |
+
resized = cv2.resize(image, (8, 8))
|
24 |
+
|
25 |
+
mean_val = np.mean(resized)
|
26 |
+
hash_mat = resized >= mean_val
|
27 |
+
hash_val = _convert_bool_matrix_to_int(hash_mat)
|
28 |
+
return f'{hash_val:016x}'
|
29 |
+
|
30 |
+
|
31 |
+
def calc_image_dhash(image):
|
32 |
+
"""Difference Hashing
|
33 |
+
|
34 |
+
References:
|
35 |
+
http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
|
36 |
+
"""
|
37 |
+
assert khandy.is_numpy_image(image)
|
38 |
+
if image.ndim == 3:
|
39 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
40 |
+
resized = cv2.resize(image, (9, 8))
|
41 |
+
|
42 |
+
hash_mat = resized[:,:-1] >= resized[:,1:]
|
43 |
+
hash_val = _convert_bool_matrix_to_int(hash_mat)
|
44 |
+
return f'{hash_val:016x}'
|
45 |
+
|
46 |
+
|
47 |
+
def calc_image_phash(image):
|
48 |
+
"""Perceptual Hashing
|
49 |
+
|
50 |
+
References:
|
51 |
+
http://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html
|
52 |
+
"""
|
53 |
+
assert khandy.is_numpy_image(image)
|
54 |
+
if image.ndim == 3:
|
55 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
56 |
+
resized = cv2.resize(image, (32, 32))
|
57 |
+
|
58 |
+
dct_coeff = cv2.dct(resized.astype(np.float32))
|
59 |
+
reduced_dct_coeff = dct_coeff[:8, :8]
|
60 |
+
|
61 |
+
# # mean of coefficients excluding the DC term (0th term)
|
62 |
+
# mean_val = np.mean(reduced_dct_coeff.flatten()[1:])
|
63 |
+
# median of coefficients
|
64 |
+
median_val = np.median(reduced_dct_coeff)
|
65 |
+
|
66 |
+
hash_mat = reduced_dct_coeff >= median_val
|
67 |
+
hash_val = _convert_bool_matrix_to_int(hash_mat)
|
68 |
+
return f'{hash_val:016x}'
|
69 |
+
|
khandy/image/misc.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import imghdr
|
3 |
+
import numbers
|
4 |
+
import warnings
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import khandy
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
|
13 |
+
def imread(file_or_buffer, flags=-1):
|
14 |
+
"""Improvement on cv2.imread, make it support filename including chinese character.
|
15 |
+
"""
|
16 |
+
try:
|
17 |
+
if isinstance(file_or_buffer, bytes):
|
18 |
+
return cv2.imdecode(np.frombuffer(file_or_buffer, dtype=np.uint8), flags)
|
19 |
+
else:
|
20 |
+
# support type: file or str or Path
|
21 |
+
return cv2.imdecode(np.fromfile(file_or_buffer, dtype=np.uint8), flags)
|
22 |
+
except Exception as e:
|
23 |
+
print(e)
|
24 |
+
return None
|
25 |
+
|
26 |
+
|
27 |
+
def imread_cv(file_or_buffer, flags=-1):
|
28 |
+
warnings.warn('khandy.imread_cv will be deprecated, use khandy.imread instead!')
|
29 |
+
return imread(file_or_buffer, flags)
|
30 |
+
|
31 |
+
|
32 |
+
def imwrite(filename, image, params=None):
|
33 |
+
"""Improvement on cv2.imwrite, make it support filename including chinese character.
|
34 |
+
"""
|
35 |
+
cv2.imencode(os.path.splitext(filename)[-1], image, params)[1].tofile(filename)
|
36 |
+
|
37 |
+
|
38 |
+
def imwrite_cv(filename, image, params=None):
|
39 |
+
warnings.warn('khandy.imwrite_cv will be deprecated, use khandy.imwrite instead!')
|
40 |
+
return imwrite(filename, image, params)
|
41 |
+
|
42 |
+
|
43 |
+
def imread_pil(file_or_buffer, to_mode=None):
|
44 |
+
"""Improvement on Image.open to avoid ResourceWarning.
|
45 |
+
"""
|
46 |
+
try:
|
47 |
+
if isinstance(file_or_buffer, bytes):
|
48 |
+
buffer = BytesIO()
|
49 |
+
buffer.write(file_or_buffer)
|
50 |
+
buffer.seek(0)
|
51 |
+
file_or_buffer = buffer
|
52 |
+
|
53 |
+
if hasattr(file_or_buffer, 'read'):
|
54 |
+
image = Image.open(file_or_buffer)
|
55 |
+
if to_mode is not None:
|
56 |
+
image = image.convert(to_mode)
|
57 |
+
else:
|
58 |
+
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
|
59 |
+
with open(file_or_buffer, 'rb') as f:
|
60 |
+
image = Image.open(f)
|
61 |
+
# If convert outside with statement, will raise "seek of closed file" as
|
62 |
+
# https://github.com/microsoft/Swin-Transformer/issues/66
|
63 |
+
if to_mode is not None:
|
64 |
+
image = image.convert(to_mode)
|
65 |
+
return image
|
66 |
+
except Exception as e:
|
67 |
+
print(e)
|
68 |
+
return None
|
69 |
+
|
70 |
+
|
71 |
+
def imwrite_bytes(filename, image_bytes: bytes, update_extension: bool = True):
|
72 |
+
"""Write image bytes to file.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
filename: str
|
76 |
+
filename which image_bytes is written into.
|
77 |
+
image_bytes: bytes
|
78 |
+
image content to be written.
|
79 |
+
update_extension: bool
|
80 |
+
whether update extension according to image_bytes or not.
|
81 |
+
the cost of update extension is smaller than update image format.
|
82 |
+
"""
|
83 |
+
extension = imghdr.what('', image_bytes)
|
84 |
+
file_extension = khandy.get_path_extension(filename)
|
85 |
+
# imghdr.what fails to determine image format sometimes!
|
86 |
+
# so when its return value is None, never update extension.
|
87 |
+
if extension is None:
|
88 |
+
image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
|
89 |
+
image_bytes = cv2.imencode(file_extension, image)[1]
|
90 |
+
elif (extension.lower() != file_extension.lower()[1:]):
|
91 |
+
if update_extension:
|
92 |
+
filename = khandy.replace_path_extension(filename, extension)
|
93 |
+
else:
|
94 |
+
image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
|
95 |
+
image_bytes = cv2.imencode(file_extension, image)[1]
|
96 |
+
|
97 |
+
with open(filename, "wb") as f:
|
98 |
+
f.write(image_bytes)
|
99 |
+
return filename
|
100 |
+
|
101 |
+
|
102 |
+
def rescale_image(image: np.ndarray, rescale_factor='auto', dst_dtype=np.float32):
|
103 |
+
"""Rescale image by rescale_factor.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
img (ndarray): Image to be rescaled.
|
107 |
+
rescale_factor (str, int or float, *optional*, defaults to `'auto'`):
|
108 |
+
rescale the image by the specified scale factor. When is `'auto'`,
|
109 |
+
rescale the image to [0, 1).
|
110 |
+
dtype (np.dtype, *optional*, defaults to `np.float32`):
|
111 |
+
The dtype of the output image. Defaults to `np.float32`.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
ndarray: The rescaled image.
|
115 |
+
"""
|
116 |
+
if rescale_factor == 'auto':
|
117 |
+
if np.issubdtype(image.dtype, np.unsignedinteger):
|
118 |
+
rescale_factor = 1. / np.iinfo(image.dtype).max
|
119 |
+
else:
|
120 |
+
raise TypeError(f'Only support uint dtype ndarray when `rescale_factor` is `auto`, got {image.dtype}')
|
121 |
+
elif issubclass(rescale_factor, (int, float)):
|
122 |
+
pass
|
123 |
+
else:
|
124 |
+
raise TypeError('rescale_factor must be "auto", int or float')
|
125 |
+
image = image.astype(dst_dtype, copy=True)
|
126 |
+
image *= rescale_factor
|
127 |
+
image = image.astype(dst_dtype)
|
128 |
+
return image
|
129 |
+
|
130 |
+
|
131 |
+
def normalize_image_value(image: np.ndarray, mean, std, rescale_factor=None):
|
132 |
+
"""Normalize an image with mean and std, rescale optionally.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
image (ndarray): Image to be normalized.
|
136 |
+
mean (int, float, Sequence[int], Sequence[float], ndarray): The mean to be used for normalize.
|
137 |
+
std (int, float, Sequence[int], Sequence[float], ndarray): The std to be used for normalize.
|
138 |
+
rescale_factor (None, 'auto', int or float, *optional*, defaults to `None`):
|
139 |
+
rescale the image by the specified scale factor. When is `'auto'`,
|
140 |
+
rescale the image to [0, 1); When is `None`, do not rescale.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
ndarray: The normalized image which dtype is np.float32.
|
144 |
+
"""
|
145 |
+
dst_dtype = np.float32
|
146 |
+
mean = np.array(mean, dtype=dst_dtype).flatten()
|
147 |
+
std = np.array(std, dtype=dst_dtype).flatten()
|
148 |
+
if rescale_factor == 'auto':
|
149 |
+
if np.issubdtype(image.dtype, np.unsignedinteger):
|
150 |
+
mean *= np.iinfo(image.dtype).max
|
151 |
+
std *= np.iinfo(image.dtype).max
|
152 |
+
else:
|
153 |
+
raise TypeError(f'Only support uint dtype ndarray when `rescale_factor` is `auto`, got {image.dtype}')
|
154 |
+
elif isinstance(rescale_factor, (int, float)):
|
155 |
+
mean *= rescale_factor
|
156 |
+
std *= rescale_factor
|
157 |
+
image = image.astype(dst_dtype, copy=True)
|
158 |
+
image -= mean
|
159 |
+
image /= std
|
160 |
+
return image
|
161 |
+
|
162 |
+
|
163 |
+
def normalize_image_dtype(image, keep_num_channels=False):
|
164 |
+
"""Normalize image dtype to uint8 (usually for visualization).
|
165 |
+
|
166 |
+
Args:
|
167 |
+
image : ndarray
|
168 |
+
Input image.
|
169 |
+
keep_num_channels : bool, optional
|
170 |
+
If this is set to True, the result is an array which has
|
171 |
+
the same shape as input image, otherwise the result is
|
172 |
+
an array whose channels number is 3.
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
out: ndarray
|
176 |
+
Image whose dtype is np.uint8.
|
177 |
+
"""
|
178 |
+
assert (image.ndim == 3 and image.shape[-1] in [1, 3]) or (image.ndim == 2)
|
179 |
+
|
180 |
+
image = image.astype(np.float32)
|
181 |
+
image = khandy.minmax_normalize(image, axis=None, copy=False)
|
182 |
+
image = np.array(image * 255, dtype=np.uint8)
|
183 |
+
|
184 |
+
if not keep_num_channels:
|
185 |
+
if image.ndim == 2:
|
186 |
+
image = np.expand_dims(image, -1)
|
187 |
+
if image.shape[-1] == 1:
|
188 |
+
image = np.tile(image, (1,1,3))
|
189 |
+
return image
|
190 |
+
|
191 |
+
|
192 |
+
def normalize_image_channel(image, swap_rb=False):
|
193 |
+
"""Normalize image channel number and order to RGB or BGR.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
image : ndarray
|
197 |
+
Input image.
|
198 |
+
swap_rb : bool, optional
|
199 |
+
whether swap red and blue channel or not
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
out: ndarray
|
203 |
+
Image whose shape is (..., 3).
|
204 |
+
"""
|
205 |
+
if image.ndim == 2:
|
206 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
207 |
+
elif image.ndim == 3:
|
208 |
+
num_channels = image.shape[-1]
|
209 |
+
if num_channels == 1:
|
210 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
211 |
+
elif num_channels == 3:
|
212 |
+
if swap_rb:
|
213 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
214 |
+
elif num_channels == 4:
|
215 |
+
if swap_rb:
|
216 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
217 |
+
else:
|
218 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
|
219 |
+
else:
|
220 |
+
raise ValueError(f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!')
|
221 |
+
else:
|
222 |
+
raise ValueError(f'Unsupported image ndarray ndim, only support 2 and 3, got {image.ndim}!')
|
223 |
+
return image
|
224 |
+
|
225 |
+
|
226 |
+
def normalize_image_shape(image, swap_rb=False):
|
227 |
+
warnings.warn('khandy.normalize_image_shape will be deprecated, use khandy.normalize_image_channel instead!')
|
228 |
+
return normalize_image_channel(image, swap_rb)
|
229 |
+
|
230 |
+
|
231 |
+
def stack_image_list(image_list, dtype=np.float32):
|
232 |
+
"""Join a sequence of image along a new axis before first axis.
|
233 |
+
|
234 |
+
References:
|
235 |
+
`im_list_to_blob` in `py-faster-rcnn-master/lib/utils/blob.py`
|
236 |
+
"""
|
237 |
+
assert isinstance(image_list, (tuple, list))
|
238 |
+
|
239 |
+
max_dimension = np.array([image.ndim for image in image_list]).max()
|
240 |
+
assert max_dimension in [2, 3]
|
241 |
+
max_shape = np.array([image.shape[:2] for image in image_list]).max(axis=0)
|
242 |
+
|
243 |
+
num_channels = []
|
244 |
+
for image in image_list:
|
245 |
+
if image.ndim == 2:
|
246 |
+
num_channels.append(1)
|
247 |
+
else:
|
248 |
+
num_channels.append(image.shape[-1])
|
249 |
+
assert len(set(num_channels) - set([1])) in [0, 1]
|
250 |
+
max_num_channels = np.max(num_channels)
|
251 |
+
|
252 |
+
blob = np.empty((len(image_list), max_shape[0], max_shape[1], max_num_channels), dtype=dtype)
|
253 |
+
for k, image in enumerate(image_list):
|
254 |
+
blob[k, :image.shape[0], :image.shape[1], :] = np.atleast_3d(image).astype(dtype, copy=False)
|
255 |
+
if max_dimension == 2:
|
256 |
+
blob = np.squeeze(blob, axis=-1)
|
257 |
+
return blob
|
258 |
+
|
259 |
+
|
260 |
+
def is_numpy_image(image):
|
261 |
+
return isinstance(image, np.ndarray) and image.ndim in {2, 3}
|
262 |
+
|
263 |
+
|
264 |
+
def is_gray_image(image, tol=3):
|
265 |
+
assert is_numpy_image(image)
|
266 |
+
if image.ndim == 2:
|
267 |
+
return True
|
268 |
+
elif image.ndim == 3:
|
269 |
+
num_channels = image.shape[-1]
|
270 |
+
if num_channels == 1:
|
271 |
+
return True
|
272 |
+
elif num_channels == 3:
|
273 |
+
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
274 |
+
gray3 = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
|
275 |
+
mae = np.mean(cv2.absdiff(image, gray3))
|
276 |
+
return mae <= tol
|
277 |
+
elif num_channels == 4:
|
278 |
+
rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
|
279 |
+
gray = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)
|
280 |
+
gray3 = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
|
281 |
+
mae = np.mean(cv2.absdiff(rgb, gray3))
|
282 |
+
return mae <= tol
|
283 |
+
else:
|
284 |
+
return False
|
285 |
+
else:
|
286 |
+
return False
|
287 |
+
|
288 |
+
|
289 |
+
def is_solid_color_image(image, tol=4):
|
290 |
+
assert is_numpy_image(image)
|
291 |
+
mean = np.array(cv2.mean(image)[:-1], dtype=np.float32)
|
292 |
+
|
293 |
+
if image.ndim == 2:
|
294 |
+
mae = np.mean(np.abs(image - mean[0]))
|
295 |
+
return mae <= tol
|
296 |
+
elif image.ndim == 3:
|
297 |
+
num_channels = image.shape[-1]
|
298 |
+
if num_channels == 1:
|
299 |
+
mae = np.mean(np.abs(image - mean[0]))
|
300 |
+
return mae <= tol
|
301 |
+
elif num_channels == 3:
|
302 |
+
mae = np.mean(np.abs(image - mean))
|
303 |
+
return mae <= tol
|
304 |
+
elif num_channels == 4:
|
305 |
+
mae = np.mean(np.abs(image[:,:,:-1] - mean))
|
306 |
+
return mae <= tol
|
307 |
+
else:
|
308 |
+
return False
|
309 |
+
else:
|
310 |
+
return False
|
311 |
+
|
312 |
+
|
313 |
+
def create_solid_color_image(image_width, image_height, color, dtype=None):
|
314 |
+
if isinstance(color, numbers.Real):
|
315 |
+
image = np.full((image_height, image_width), color, dtype=dtype)
|
316 |
+
elif isinstance(color, (tuple, list)):
|
317 |
+
if len(color) == 1:
|
318 |
+
image = np.full((image_height, image_width), color[0], dtype=dtype)
|
319 |
+
elif len(color) in (3, 4):
|
320 |
+
image = np.full((1, 1, len(color)), color, dtype=dtype)
|
321 |
+
image = cv2.copyMakeBorder(image, 0, image_height-1, 0, image_width-1,
|
322 |
+
cv2.BORDER_CONSTANT, value=color)
|
323 |
+
else:
|
324 |
+
color = np.asarray(color, dtype=dtype)
|
325 |
+
image = np.empty((image_height, image_width, len(color)), dtype=dtype)
|
326 |
+
image[:] = color
|
327 |
+
else:
|
328 |
+
raise TypeError(f'Invalid type {type(color)} for `color`.')
|
329 |
+
return image
|
khandy/image/resize.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import khandy
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
interp_codes = {
|
9 |
+
'nearest': cv2.INTER_NEAREST,
|
10 |
+
'bilinear': cv2.INTER_LINEAR,
|
11 |
+
'bicubic': cv2.INTER_CUBIC,
|
12 |
+
'area': cv2.INTER_AREA,
|
13 |
+
'lanczos': cv2.INTER_LANCZOS4
|
14 |
+
}
|
15 |
+
|
16 |
+
|
17 |
+
def scale_image(image, x_scale, y_scale, interpolation='bilinear'):
|
18 |
+
"""Scale image.
|
19 |
+
|
20 |
+
Reference:
|
21 |
+
mmcv.imrescale
|
22 |
+
"""
|
23 |
+
assert khandy.is_numpy_image(image)
|
24 |
+
src_height, src_width = image.shape[:2]
|
25 |
+
dst_width = int(round(x_scale * src_width))
|
26 |
+
dst_height = int(round(y_scale * src_height))
|
27 |
+
|
28 |
+
resized_image = cv2.resize(image, (dst_width, dst_height),
|
29 |
+
interpolation=interp_codes[interpolation])
|
30 |
+
return resized_image
|
31 |
+
|
32 |
+
|
33 |
+
def resize_image(image, dst_width, dst_height, return_scale=False, interpolation='bilinear'):
|
34 |
+
"""Resize image to a given size.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
image (ndarray): The input image.
|
38 |
+
dst_width (int): Target width.
|
39 |
+
dst_height (int): Target height.
|
40 |
+
return_scale (bool): Whether to return `x_scale` and `y_scale`.
|
41 |
+
interpolation (str): Interpolation method, accepted values are
|
42 |
+
"nearest", "bilinear", "bicubic", "area", "lanczos".
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
tuple or ndarray: (`resized_image`, `x_scale`, `y_scale`) or `resized_image`.
|
46 |
+
|
47 |
+
Reference:
|
48 |
+
mmcv.imresize
|
49 |
+
"""
|
50 |
+
assert khandy.is_numpy_image(image)
|
51 |
+
resized_image = cv2.resize(image, (dst_width, dst_height),
|
52 |
+
interpolation=interp_codes[interpolation])
|
53 |
+
if not return_scale:
|
54 |
+
return resized_image
|
55 |
+
else:
|
56 |
+
src_height, src_width = image.shape[:2]
|
57 |
+
x_scale = dst_width / src_width
|
58 |
+
y_scale = dst_height / src_height
|
59 |
+
return resized_image, x_scale, y_scale
|
60 |
+
|
61 |
+
|
62 |
+
def resize_image_short(image, dst_size, return_scale=False, interpolation='bilinear'):
|
63 |
+
"""Resize an image so that the length of shorter side is dst_size while
|
64 |
+
preserving the original aspect ratio.
|
65 |
+
|
66 |
+
References:
|
67 |
+
`resize_min` in `https://github.com/pjreddie/darknet/blob/master/src/image.c`
|
68 |
+
"""
|
69 |
+
assert khandy.is_numpy_image(image)
|
70 |
+
src_height, src_width = image.shape[:2]
|
71 |
+
scale = max(dst_size / src_width, dst_size / src_height)
|
72 |
+
dst_width = int(round(scale * src_width))
|
73 |
+
dst_height = int(round(scale * src_height))
|
74 |
+
|
75 |
+
resized_image = cv2.resize(image, (dst_width, dst_height),
|
76 |
+
interpolation=interp_codes[interpolation])
|
77 |
+
if not return_scale:
|
78 |
+
return resized_image
|
79 |
+
else:
|
80 |
+
return resized_image, scale
|
81 |
+
|
82 |
+
|
83 |
+
def resize_image_long(image, dst_size, return_scale=False, interpolation='bilinear'):
|
84 |
+
"""Resize an image so that the length of longer side is dst_size while
|
85 |
+
preserving the original aspect ratio.
|
86 |
+
|
87 |
+
References:
|
88 |
+
`resize_max` in `https://github.com/pjreddie/darknet/blob/master/src/image.c`
|
89 |
+
"""
|
90 |
+
assert khandy.is_numpy_image(image)
|
91 |
+
src_height, src_width = image.shape[:2]
|
92 |
+
scale = min(dst_size / src_width, dst_size / src_height)
|
93 |
+
dst_width = int(round(scale * src_width))
|
94 |
+
dst_height = int(round(scale * src_height))
|
95 |
+
|
96 |
+
resized_image = cv2.resize(image, (dst_width, dst_height),
|
97 |
+
interpolation=interp_codes[interpolation])
|
98 |
+
if not return_scale:
|
99 |
+
return resized_image
|
100 |
+
else:
|
101 |
+
return resized_image, scale
|
102 |
+
|
103 |
+
|
104 |
+
def resize_image_to_range(image, min_length, max_length, return_scale=False, interpolation='bilinear'):
|
105 |
+
"""Resizes an image so its dimensions are within the provided value.
|
106 |
+
|
107 |
+
Rescale the shortest side of the image up to `min_length` pixels
|
108 |
+
while keeping the largest side below `max_length` pixels without
|
109 |
+
changing the aspect ratio. Often used in object detection (e.g. RCNN and SSH.)
|
110 |
+
|
111 |
+
The output size can be described by two cases:
|
112 |
+
1. If the image can be rescaled so its shortest side is equal to the
|
113 |
+
`min_length` without the other side exceeding `max_length`, then do so.
|
114 |
+
2. Otherwise, resize so the longest side is equal to `max_length`.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
resized_image: resized image so that
|
118 |
+
min(dst_height, dst_width) == min_length or
|
119 |
+
max(dst_height, dst_width) == max_length.
|
120 |
+
|
121 |
+
References:
|
122 |
+
`resize_to_range` in `models-master/research/object_detection/core/preprocessor.py`
|
123 |
+
`prep_im_for_blob` in `py-faster-rcnn-master/lib/utils/blob.py`
|
124 |
+
mmcv.imrescale
|
125 |
+
"""
|
126 |
+
assert khandy.is_numpy_image(image)
|
127 |
+
assert min_length < max_length
|
128 |
+
src_height, src_width = image.shape[:2]
|
129 |
+
|
130 |
+
min_side_length = min(src_width, src_height)
|
131 |
+
max_side_length = max(src_width, src_height)
|
132 |
+
scale = min_length / min_side_length
|
133 |
+
if round(scale * max_side_length) > max_length:
|
134 |
+
scale = max_length / max_side_length
|
135 |
+
dst_width = int(round(scale * src_width))
|
136 |
+
dst_height = int(round(scale * src_height))
|
137 |
+
|
138 |
+
resized_image = cv2.resize(image, (dst_width, dst_height),
|
139 |
+
interpolation=interp_codes[interpolation])
|
140 |
+
if not return_scale:
|
141 |
+
return resized_image
|
142 |
+
else:
|
143 |
+
return resized_image, scale
|
144 |
+
|
145 |
+
|
146 |
+
def letterbox_image(image, dst_width, dst_height, border_value=0,
|
147 |
+
return_scale=False, interpolation='bilinear'):
|
148 |
+
"""Resize an image preserving the original aspect ratio using padding.
|
149 |
+
|
150 |
+
References:
|
151 |
+
`letterbox_image` in `https://github.com/pjreddie/darknet/blob/master/src/image.c`
|
152 |
+
"""
|
153 |
+
assert khandy.is_numpy_image(image)
|
154 |
+
src_height, src_width = image.shape[:2]
|
155 |
+
scale = min(dst_width / src_width, dst_height / src_height)
|
156 |
+
resize_w = int(round(scale * src_width))
|
157 |
+
resize_h = int(round(scale * src_height))
|
158 |
+
|
159 |
+
resized_image = cv2.resize(image, (resize_w, resize_h),
|
160 |
+
interpolation=interp_codes[interpolation])
|
161 |
+
pad_top = (dst_height - resize_h) // 2
|
162 |
+
pad_bottom = (dst_height - resize_h) - pad_top
|
163 |
+
pad_left = (dst_width - resize_w) // 2
|
164 |
+
pad_right = (dst_width - resize_w) - pad_left
|
165 |
+
padded_image = cv2.copyMakeBorder(resized_image, pad_top, pad_bottom, pad_left, pad_right,
|
166 |
+
cv2.BORDER_CONSTANT, value=border_value)
|
167 |
+
if not return_scale:
|
168 |
+
return padded_image
|
169 |
+
else:
|
170 |
+
return padded_image, scale, pad_left, pad_top
|
171 |
+
|
172 |
+
|
173 |
+
def letterbox_resize_image(image, dst_width, dst_height, border_value=0,
|
174 |
+
return_scale=False, interpolation='bilinear'):
|
175 |
+
warnings.warn('letterbox_resize_image will be deprecated, use letterbox_image instead!')
|
176 |
+
return letterbox_image(image, dst_width, dst_height, border_value,
|
177 |
+
return_scale, interpolation)
|
khandy/image/rotate.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import khandy
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def get_2d_rotation_matrix(angle, cx=0, cy=0, scale=1,
|
7 |
+
degrees=True, dtype=np.float32):
|
8 |
+
"""
|
9 |
+
References:
|
10 |
+
`cv2.getRotationMatrix2D` in OpenCV
|
11 |
+
"""
|
12 |
+
if degrees:
|
13 |
+
angle = np.deg2rad(angle)
|
14 |
+
c = scale * np.cos(angle)
|
15 |
+
s = scale * np.sin(angle)
|
16 |
+
|
17 |
+
tx = cx - cx * c + cy * s
|
18 |
+
ty = cy - cx * s - cy * c
|
19 |
+
return np.array([[ c, -s, tx],
|
20 |
+
[ s, c, ty],
|
21 |
+
[ 0, 0, 1]], dtype=dtype)
|
22 |
+
|
23 |
+
|
24 |
+
def rotate_image(image, angle, scale=1.0, center=None,
|
25 |
+
degrees=True, border_value=0, auto_bound=False):
|
26 |
+
"""Rotate an image.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
image : ndarray
|
30 |
+
Image to be rotated.
|
31 |
+
angle : float
|
32 |
+
Rotation angle in degrees, positive values mean clockwise rotation.
|
33 |
+
center : tuple
|
34 |
+
Center of the rotation in the source image, by default
|
35 |
+
it is the center of the image.
|
36 |
+
scale : float
|
37 |
+
Isotropic scale factor.
|
38 |
+
degrees : bool
|
39 |
+
border_value : int
|
40 |
+
Border value.
|
41 |
+
auto_bound : bool
|
42 |
+
Whether to adjust the image size to cover the whole rotated image.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
ndarray: The rotated image.
|
46 |
+
|
47 |
+
References:
|
48 |
+
mmcv.imrotate
|
49 |
+
"""
|
50 |
+
assert khandy.is_numpy_image(image)
|
51 |
+
image_height, image_width = image.shape[:2]
|
52 |
+
if auto_bound:
|
53 |
+
center = None
|
54 |
+
if center is None:
|
55 |
+
center = ((image_width - 1) * 0.5, (image_height - 1) * 0.5)
|
56 |
+
assert isinstance(center, tuple)
|
57 |
+
|
58 |
+
rotation_matrix = get_2d_rotation_matrix(angle, center[0], center[1], scale, degrees)
|
59 |
+
if auto_bound:
|
60 |
+
scale_cos = np.abs(rotation_matrix[0, 0])
|
61 |
+
scale_sin = np.abs(rotation_matrix[0, 1])
|
62 |
+
new_width = image_width * scale_cos + image_height * scale_sin
|
63 |
+
new_height = image_width * scale_sin + image_height * scale_cos
|
64 |
+
|
65 |
+
rotation_matrix[0, 2] += (new_width - image_width) * 0.5
|
66 |
+
rotation_matrix[1, 2] += (new_height - image_height) * 0.5
|
67 |
+
|
68 |
+
image_width = int(np.round(new_width))
|
69 |
+
image_height = int(np.round(new_height))
|
70 |
+
rotated = cv2.warpAffine(image, rotation_matrix[:2,:], (image_width, image_height),
|
71 |
+
borderValue=border_value)
|
72 |
+
return rotated
|
khandy/image/translate.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numbers
|
2 |
+
|
3 |
+
import khandy
|
4 |
+
|
5 |
+
|
6 |
+
def translate_image(image, x_shift, y_shift, border_value=0):
|
7 |
+
"""Translate an image.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
image (ndarray): Image to be translated with format (h, w) or (h, w, c).
|
11 |
+
x_shift (int): The offset used for translate in horizontal
|
12 |
+
direction. right is the positive direction.
|
13 |
+
y_shift (int): The offset used for translate in vertical
|
14 |
+
direction. down is the positive direction.
|
15 |
+
border_value (int | tuple[int]): Value used in case of a
|
16 |
+
constant border.
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
ndarray: The translated image.
|
20 |
+
|
21 |
+
See Also:
|
22 |
+
crop_or_pad
|
23 |
+
"""
|
24 |
+
assert khandy.is_numpy_image(image)
|
25 |
+
assert isinstance(x_shift, numbers.Integral)
|
26 |
+
assert isinstance(y_shift, numbers.Integral)
|
27 |
+
image_height, image_width = image.shape[:2]
|
28 |
+
channels = 1 if image.ndim == 2 else image.shape[2]
|
29 |
+
|
30 |
+
if isinstance(border_value, (tuple, list)):
|
31 |
+
assert len(border_value) == channels, \
|
32 |
+
'Expected the num of elements in tuple equals the channels ' \
|
33 |
+
'of input image. Found {} vs {}'.format(
|
34 |
+
len(border_value), channels)
|
35 |
+
else:
|
36 |
+
border_value = (border_value,) * channels
|
37 |
+
dst_image = khandy.create_solid_color_image(
|
38 |
+
image_height, image_width, border_value, dtype=image.dtype)
|
39 |
+
|
40 |
+
if (abs(x_shift) >= image_width) or (abs(y_shift) >= image_height):
|
41 |
+
return dst_image
|
42 |
+
|
43 |
+
src_x_begin = max(-x_shift, 0)
|
44 |
+
src_x_end = min(image_width - x_shift, image_width)
|
45 |
+
dst_x_begin = max(x_shift, 0)
|
46 |
+
dst_x_end = min(image_width + x_shift, image_width)
|
47 |
+
|
48 |
+
src_y_begin = max(-y_shift, 0)
|
49 |
+
src_y_end = min(image_height - y_shift, image_height)
|
50 |
+
dst_y_begin = max(y_shift, 0)
|
51 |
+
dst_y_end = min(image_height + y_shift, image_height)
|
52 |
+
|
53 |
+
dst_image[dst_y_begin:dst_y_end, dst_x_begin:dst_x_end] = \
|
54 |
+
image[src_y_begin:src_y_end, src_x_begin:src_x_end]
|
55 |
+
return dst_image
|
56 |
+
|
57 |
+
|
khandy/label/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .detect import *
|
2 |
+
|
khandy/label/detect.py
ADDED
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import copy
|
3 |
+
import json
|
4 |
+
import dataclasses
|
5 |
+
from dataclasses import dataclass, field
|
6 |
+
from collections import OrderedDict
|
7 |
+
from typing import Optional, List
|
8 |
+
import xml.etree.ElementTree as ET
|
9 |
+
|
10 |
+
import khandy
|
11 |
+
import lxml
|
12 |
+
import lxml.builder
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
|
16 |
+
__all__ = ['DetectIrObject', 'DetectIrRecord', 'load_detect',
|
17 |
+
'save_detect', 'convert_detect', 'replace_detect_label',
|
18 |
+
'load_coco_class_names']
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class DetectIrObject:
|
23 |
+
"""Intermediate Representation Format of Object
|
24 |
+
"""
|
25 |
+
label: str
|
26 |
+
x_min: float
|
27 |
+
y_min: float
|
28 |
+
x_max: float
|
29 |
+
y_max: float
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class DetectIrRecord:
|
34 |
+
"""Intermediate Representation Format of Record
|
35 |
+
"""
|
36 |
+
filename: str
|
37 |
+
width: int
|
38 |
+
height: int
|
39 |
+
objects: List[DetectIrObject] = field(default_factory=list)
|
40 |
+
|
41 |
+
|
42 |
+
@dataclass
|
43 |
+
class PascalVocSource:
|
44 |
+
database: str = ''
|
45 |
+
annotation: str = ''
|
46 |
+
image: str = ''
|
47 |
+
|
48 |
+
|
49 |
+
@dataclass
|
50 |
+
class PascalVocSize:
|
51 |
+
height: int
|
52 |
+
width: int
|
53 |
+
depth: int
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass
|
57 |
+
class PascalVocBndbox:
|
58 |
+
xmin: float
|
59 |
+
ymin: float
|
60 |
+
xmax: float
|
61 |
+
ymax: float
|
62 |
+
|
63 |
+
|
64 |
+
@dataclass
|
65 |
+
class PascalVocObject:
|
66 |
+
name: str
|
67 |
+
pose: str = 'Unspecified'
|
68 |
+
truncated: int = 0
|
69 |
+
difficult: int = 0
|
70 |
+
bndbox: Optional[PascalVocBndbox] = None
|
71 |
+
|
72 |
+
|
73 |
+
@dataclass
|
74 |
+
class PascalVocRecord:
|
75 |
+
folder: str = ''
|
76 |
+
filename: str = ''
|
77 |
+
path: str = ''
|
78 |
+
source: PascalVocSource = PascalVocSource()
|
79 |
+
size: Optional[PascalVocSize] = None
|
80 |
+
segmented: int = 0
|
81 |
+
objects: List[PascalVocObject] = field(default_factory=list)
|
82 |
+
|
83 |
+
|
84 |
+
class PascalVocHandler:
|
85 |
+
@staticmethod
|
86 |
+
def load(filename, **kwargs) -> PascalVocRecord:
|
87 |
+
pascal_voc_record = PascalVocRecord()
|
88 |
+
|
89 |
+
xml_tree = ET.parse(filename)
|
90 |
+
pascal_voc_record.folder = xml_tree.find('folder').text
|
91 |
+
pascal_voc_record.filename = xml_tree.find('filename').text
|
92 |
+
pascal_voc_record.path = xml_tree.find('path').text
|
93 |
+
pascal_voc_record.segmented = xml_tree.find('segmented').text
|
94 |
+
|
95 |
+
source_tag = xml_tree.find('source')
|
96 |
+
pascal_voc_record.source = PascalVocSource(
|
97 |
+
database=source_tag.find('database').text,
|
98 |
+
# annotation=source_tag.find('annotation').text,
|
99 |
+
# image=source_tag.find('image').text
|
100 |
+
)
|
101 |
+
|
102 |
+
size_tag = xml_tree.find('size')
|
103 |
+
pascal_voc_record.size = PascalVocSize(
|
104 |
+
width=int(size_tag.find('width').text),
|
105 |
+
height=int(size_tag.find('height').text),
|
106 |
+
depth=int(size_tag.find('depth').text)
|
107 |
+
)
|
108 |
+
|
109 |
+
object_tags = xml_tree.findall('object')
|
110 |
+
for index, object_tag in enumerate(object_tags):
|
111 |
+
bndbox_tag = object_tag.find('bndbox')
|
112 |
+
bndbox = PascalVocBndbox(
|
113 |
+
xmin=float(bndbox_tag.find('xmin').text) - 1,
|
114 |
+
ymin=float(bndbox_tag.find('ymin').text) - 1,
|
115 |
+
xmax=float(bndbox_tag.find('xmax').text) - 1,
|
116 |
+
ymax=float(bndbox_tag.find('ymax').text) - 1
|
117 |
+
)
|
118 |
+
pascal_voc_object = PascalVocObject(
|
119 |
+
name=object_tag.find('name').text,
|
120 |
+
pose=object_tag.find('pose').text,
|
121 |
+
truncated=object_tag.find('truncated').text,
|
122 |
+
difficult=object_tag.find('difficult').text,
|
123 |
+
bndbox=bndbox
|
124 |
+
)
|
125 |
+
pascal_voc_record.objects.append(pascal_voc_object)
|
126 |
+
return pascal_voc_record
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def save(filename, pascal_voc_record: PascalVocRecord):
|
130 |
+
maker = lxml.builder.ElementMaker()
|
131 |
+
xml = maker.annotation(
|
132 |
+
maker.folder(pascal_voc_record.folder),
|
133 |
+
maker.filename(pascal_voc_record.filename),
|
134 |
+
maker.path(pascal_voc_record.path),
|
135 |
+
maker.source(
|
136 |
+
maker.database(pascal_voc_record.source.database),
|
137 |
+
),
|
138 |
+
maker.size(
|
139 |
+
maker.width(str(pascal_voc_record.size.width)),
|
140 |
+
maker.height(str(pascal_voc_record.size.height)),
|
141 |
+
maker.depth(str(pascal_voc_record.size.depth)),
|
142 |
+
),
|
143 |
+
maker.segmented(str(pascal_voc_record.segmented)),
|
144 |
+
)
|
145 |
+
|
146 |
+
for pascal_voc_object in pascal_voc_record.objects:
|
147 |
+
object_tag = maker.object(
|
148 |
+
maker.name(pascal_voc_object.name),
|
149 |
+
maker.pose(pascal_voc_object.pose),
|
150 |
+
maker.truncated(str(pascal_voc_object.truncated)),
|
151 |
+
maker.difficult(str(pascal_voc_object.difficult)),
|
152 |
+
maker.bndbox(
|
153 |
+
maker.xmin(str(float(pascal_voc_object.bndbox.xmin))),
|
154 |
+
maker.ymin(str(float(pascal_voc_object.bndbox.ymin))),
|
155 |
+
maker.xmax(str(float(pascal_voc_object.bndbox.xmax))),
|
156 |
+
maker.ymax(str(float(pascal_voc_object.bndbox.ymax))),
|
157 |
+
),
|
158 |
+
)
|
159 |
+
xml.append(object_tag)
|
160 |
+
|
161 |
+
if not filename.endswith('.xml'):
|
162 |
+
filename = filename + '.xml'
|
163 |
+
with open(filename, 'wb') as f:
|
164 |
+
f.write(lxml.etree.tostring(
|
165 |
+
xml, pretty_print=True, encoding='utf-8'))
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
|
169 |
+
ir_record = DetectIrRecord(
|
170 |
+
filename=pascal_voc_record.filename,
|
171 |
+
width=pascal_voc_record.size.width,
|
172 |
+
height=pascal_voc_record.size.height
|
173 |
+
)
|
174 |
+
for pascal_voc_object in pascal_voc_record.objects:
|
175 |
+
ir_object = DetectIrObject(
|
176 |
+
label=pascal_voc_object.name,
|
177 |
+
x_min=pascal_voc_object.bndbox.xmin,
|
178 |
+
y_min=pascal_voc_object.bndbox.ymin,
|
179 |
+
x_max=pascal_voc_object.bndbox.xmax,
|
180 |
+
y_max=pascal_voc_object.bndbox.ymax
|
181 |
+
)
|
182 |
+
ir_record.objects.append(ir_object)
|
183 |
+
return ir_record
|
184 |
+
|
185 |
+
@staticmethod
|
186 |
+
def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
|
187 |
+
pascal_voc_record = PascalVocRecord(
|
188 |
+
filename=ir_record.filename,
|
189 |
+
size=PascalVocSize(
|
190 |
+
width=ir_record.width,
|
191 |
+
height=ir_record.height,
|
192 |
+
depth=3
|
193 |
+
)
|
194 |
+
)
|
195 |
+
for ir_object in ir_record.objects:
|
196 |
+
pascal_voc_object = PascalVocObject(
|
197 |
+
name=ir_object.label,
|
198 |
+
bndbox=PascalVocBndbox(
|
199 |
+
xmin=ir_object.x_min,
|
200 |
+
ymin=ir_object.y_min,
|
201 |
+
xmax=ir_object.x_max,
|
202 |
+
ymax=ir_object.y_max,
|
203 |
+
)
|
204 |
+
)
|
205 |
+
pascal_voc_record.objects.append(pascal_voc_object)
|
206 |
+
return pascal_voc_record
|
207 |
+
|
208 |
+
|
209 |
+
class _NumpyEncoder(json.JSONEncoder):
|
210 |
+
""" Special json encoder for numpy types """
|
211 |
+
|
212 |
+
def default(self, obj):
|
213 |
+
if isinstance(obj, (np.bool_,)):
|
214 |
+
return bool(obj)
|
215 |
+
elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
|
216 |
+
np.int16, np.int32, np.int64, np.uint8,
|
217 |
+
np.uint16, np.uint32, np.uint64)):
|
218 |
+
return int(obj)
|
219 |
+
elif isinstance(obj, (np.float_, np.float16, np.float32,
|
220 |
+
np.float64)):
|
221 |
+
return float(obj)
|
222 |
+
elif isinstance(obj, (np.ndarray,)):
|
223 |
+
return obj.tolist()
|
224 |
+
return json.JSONEncoder.default(self, obj)
|
225 |
+
|
226 |
+
|
227 |
+
@dataclass
|
228 |
+
class LabelmeShape:
|
229 |
+
label: str
|
230 |
+
points: np.ndarray
|
231 |
+
shape_type: str
|
232 |
+
flags: dict = field(default_factory=dict)
|
233 |
+
group_id: Optional[int] = None
|
234 |
+
|
235 |
+
def __post_init__(self):
|
236 |
+
self.points = np.asarray(self.points)
|
237 |
+
|
238 |
+
|
239 |
+
@dataclass
|
240 |
+
class LabelmeRecord:
|
241 |
+
version: str = '4.5.6'
|
242 |
+
flags: dict = field(default_factory=dict)
|
243 |
+
shapes: List[LabelmeShape] = field(default_factory=list)
|
244 |
+
imagePath: Optional[str] = None
|
245 |
+
imageData: Optional[str] = None
|
246 |
+
imageHeight: Optional[int] = None
|
247 |
+
imageWidth: Optional[int] = None
|
248 |
+
|
249 |
+
def __post_init__(self):
|
250 |
+
for k, shape in enumerate(self.shapes):
|
251 |
+
self.shapes[k] = LabelmeShape(**shape)
|
252 |
+
|
253 |
+
|
254 |
+
class LabelmeHandler:
|
255 |
+
@staticmethod
|
256 |
+
def load(filename, **kwargs) -> LabelmeRecord:
|
257 |
+
json_content = khandy.load_json(filename)
|
258 |
+
return LabelmeRecord(**json_content)
|
259 |
+
|
260 |
+
@staticmethod
|
261 |
+
def save(filename, labelme_record: LabelmeRecord):
|
262 |
+
json_content = dataclasses.asdict(labelme_record)
|
263 |
+
khandy.save_json(filename, json_content, cls=_NumpyEncoder)
|
264 |
+
|
265 |
+
@staticmethod
|
266 |
+
def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
|
267 |
+
ir_record = DetectIrRecord(
|
268 |
+
filename=labelme_record.imagePath,
|
269 |
+
width=labelme_record.imageWidth,
|
270 |
+
height=labelme_record.imageHeight
|
271 |
+
)
|
272 |
+
for labelme_shape in labelme_record.shapes:
|
273 |
+
if labelme_shape.shape_type != 'rectangle':
|
274 |
+
continue
|
275 |
+
ir_object = DetectIrObject(
|
276 |
+
label=labelme_shape.label,
|
277 |
+
x_min=labelme_shape.points[0][0],
|
278 |
+
y_min=labelme_shape.points[0][1],
|
279 |
+
x_max=labelme_shape.points[1][0],
|
280 |
+
y_max=labelme_shape.points[1][1],
|
281 |
+
)
|
282 |
+
ir_record.objects.append(ir_object)
|
283 |
+
return ir_record
|
284 |
+
|
285 |
+
@staticmethod
|
286 |
+
def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
|
287 |
+
labelme_record = LabelmeRecord(
|
288 |
+
imagePath=ir_record.filename,
|
289 |
+
imageWidth=ir_record.width,
|
290 |
+
imageHeight=ir_record.height
|
291 |
+
)
|
292 |
+
for ir_object in ir_record.objects:
|
293 |
+
labelme_shape = LabelmeShape(
|
294 |
+
label=ir_object.label,
|
295 |
+
shape_type='rectangle',
|
296 |
+
points=[[ir_object.x_min, ir_object.y_min],
|
297 |
+
[ir_object.x_max, ir_object.y_max]]
|
298 |
+
)
|
299 |
+
labelme_record.shapes.append(labelme_shape)
|
300 |
+
return labelme_record
|
301 |
+
|
302 |
+
|
303 |
+
@dataclass
|
304 |
+
class YoloObject:
|
305 |
+
label: str
|
306 |
+
x_center: float
|
307 |
+
y_center: float
|
308 |
+
width: float
|
309 |
+
height: float
|
310 |
+
|
311 |
+
|
312 |
+
@dataclass
|
313 |
+
class YoloRecord:
|
314 |
+
filename: Optional[str] = None
|
315 |
+
width: Optional[int] = None
|
316 |
+
height: Optional[int] = None
|
317 |
+
objects: List[YoloObject] = field(default_factory=list)
|
318 |
+
|
319 |
+
|
320 |
+
class YoloHandler:
|
321 |
+
@staticmethod
|
322 |
+
def load(filename, **kwargs) -> YoloRecord:
|
323 |
+
assert 'image_filename' in kwargs
|
324 |
+
assert 'width' in kwargs and 'height' in kwargs
|
325 |
+
|
326 |
+
records = khandy.load_list(filename)
|
327 |
+
yolo_record = YoloRecord(
|
328 |
+
filename=kwargs.get('image_filename'),
|
329 |
+
width=kwargs.get('width'),
|
330 |
+
height=kwargs.get('height'))
|
331 |
+
for record in records:
|
332 |
+
record_parts = record.split()
|
333 |
+
yolo_record.objects.append(YoloObject(
|
334 |
+
label=record_parts[0],
|
335 |
+
x_center=float(record_parts[1]),
|
336 |
+
y_center=float(record_parts[2]),
|
337 |
+
width=float(record_parts[3]),
|
338 |
+
height=float(record_parts[4]),
|
339 |
+
))
|
340 |
+
return yolo_record
|
341 |
+
|
342 |
+
@staticmethod
|
343 |
+
def save(filename, yolo_record: YoloRecord):
|
344 |
+
records = []
|
345 |
+
for object in yolo_record.objects:
|
346 |
+
records.append(
|
347 |
+
f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
|
348 |
+
if not filename.endswith('.txt'):
|
349 |
+
filename = filename + '.txt'
|
350 |
+
khandy.save_list(filename, records)
|
351 |
+
|
352 |
+
@staticmethod
|
353 |
+
def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
|
354 |
+
ir_record = DetectIrRecord(
|
355 |
+
filename=yolo_record.filename,
|
356 |
+
width=yolo_record.width,
|
357 |
+
height=yolo_record.height
|
358 |
+
)
|
359 |
+
for yolo_object in yolo_record.objects:
|
360 |
+
x_min = (yolo_object.x_center - 0.5 *
|
361 |
+
yolo_object.width) * yolo_record.width
|
362 |
+
y_min = (yolo_object.y_center - 0.5 *
|
363 |
+
yolo_object.height) * yolo_record.height
|
364 |
+
x_max = (yolo_object.x_center + 0.5 *
|
365 |
+
yolo_object.width) * yolo_record.width
|
366 |
+
y_max = (yolo_object.y_center + 0.5 *
|
367 |
+
yolo_object.height) * yolo_record.height
|
368 |
+
ir_object = DetectIrObject(
|
369 |
+
label=yolo_object.label,
|
370 |
+
x_min=x_min,
|
371 |
+
y_min=y_min,
|
372 |
+
x_max=x_max,
|
373 |
+
y_max=y_max
|
374 |
+
)
|
375 |
+
ir_record.objects.append(ir_object)
|
376 |
+
return ir_record
|
377 |
+
|
378 |
+
@staticmethod
|
379 |
+
def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
|
380 |
+
yolo_record = YoloRecord(
|
381 |
+
filename=ir_record.filename,
|
382 |
+
width=ir_record.width,
|
383 |
+
height=ir_record.height
|
384 |
+
)
|
385 |
+
for ir_object in ir_record.objects:
|
386 |
+
x_center = (ir_object.x_max + ir_object.x_min) / \
|
387 |
+
(2 * ir_record.width)
|
388 |
+
y_center = (ir_object.y_max + ir_object.y_min) / \
|
389 |
+
(2 * ir_record.height)
|
390 |
+
width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
|
391 |
+
height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
|
392 |
+
yolo_object = YoloObject(
|
393 |
+
label=ir_object.label,
|
394 |
+
x_center=x_center,
|
395 |
+
y_center=y_center,
|
396 |
+
width=width,
|
397 |
+
height=height,
|
398 |
+
)
|
399 |
+
yolo_record.objects.append(yolo_object)
|
400 |
+
return yolo_record
|
401 |
+
|
402 |
+
|
403 |
+
@dataclass
|
404 |
+
class CocoObject:
|
405 |
+
label: str
|
406 |
+
x_min: float
|
407 |
+
y_min: float
|
408 |
+
width: float
|
409 |
+
height: float
|
410 |
+
|
411 |
+
|
412 |
+
@dataclass
|
413 |
+
class CocoRecord:
|
414 |
+
filename: str
|
415 |
+
width: int
|
416 |
+
height: int
|
417 |
+
objects: List[CocoObject] = field(default_factory=list)
|
418 |
+
|
419 |
+
|
420 |
+
class CocoHandler:
|
421 |
+
@staticmethod
|
422 |
+
def load(filename, **kwargs) -> List[CocoRecord]:
|
423 |
+
json_data = khandy.load_json(filename)
|
424 |
+
|
425 |
+
images = json_data['images']
|
426 |
+
annotations = json_data['annotations']
|
427 |
+
categories = json_data['categories']
|
428 |
+
|
429 |
+
label_map = {}
|
430 |
+
for cat_item in categories:
|
431 |
+
label_map[cat_item['id']] = cat_item['name']
|
432 |
+
|
433 |
+
coco_records = OrderedDict()
|
434 |
+
for image_item in images:
|
435 |
+
coco_records[image_item['id']] = CocoRecord(
|
436 |
+
filename=image_item['file_name'],
|
437 |
+
width=image_item['width'],
|
438 |
+
height=image_item['height'],
|
439 |
+
objects=[])
|
440 |
+
|
441 |
+
for annotation_item in annotations:
|
442 |
+
coco_object = CocoObject(
|
443 |
+
label=label_map[annotation_item['category_id']],
|
444 |
+
x_min=annotation_item['bbox'][0],
|
445 |
+
y_min=annotation_item['bbox'][1],
|
446 |
+
width=annotation_item['bbox'][2],
|
447 |
+
height=annotation_item['bbox'][3])
|
448 |
+
coco_records[annotation_item['image_id']
|
449 |
+
].objects.append(coco_object)
|
450 |
+
return list(coco_records.values())
|
451 |
+
|
452 |
+
@staticmethod
|
453 |
+
def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
|
454 |
+
ir_record = DetectIrRecord(
|
455 |
+
filename=coco_record.filename,
|
456 |
+
width=coco_record.width,
|
457 |
+
height=coco_record.height,
|
458 |
+
)
|
459 |
+
for coco_object in coco_record.objects:
|
460 |
+
ir_object = DetectIrObject(
|
461 |
+
label=coco_object.label,
|
462 |
+
x_min=coco_object.x_min,
|
463 |
+
y_min=coco_object.y_min,
|
464 |
+
x_max=coco_object.x_min + coco_object.width,
|
465 |
+
y_max=coco_object.y_min + coco_object.height
|
466 |
+
)
|
467 |
+
ir_record.objects.append(ir_object)
|
468 |
+
return ir_record
|
469 |
+
|
470 |
+
@staticmethod
|
471 |
+
def from_ir(ir_record: DetectIrRecord) -> CocoRecord:
|
472 |
+
coco_record = CocoRecord(
|
473 |
+
filename=ir_record.filename,
|
474 |
+
width=ir_record.width,
|
475 |
+
height=ir_record.height
|
476 |
+
)
|
477 |
+
for ir_object in ir_record.objects:
|
478 |
+
coco_object = CocoObject(
|
479 |
+
label=ir_object.label,
|
480 |
+
x_min=ir_object.x_min,
|
481 |
+
y_min=ir_object.y_min,
|
482 |
+
width=ir_object.x_max - ir_object.x_min,
|
483 |
+
height=ir_object.y_max - ir_object.y_min
|
484 |
+
)
|
485 |
+
coco_record.objects.append(coco_object)
|
486 |
+
return coco_record
|
487 |
+
|
488 |
+
|
489 |
+
def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
|
490 |
+
if fmt == 'labelme':
|
491 |
+
labelme_record = LabelmeHandler.load(filename, **kwargs)
|
492 |
+
ir_record = LabelmeHandler.to_ir(labelme_record)
|
493 |
+
elif fmt == 'yolo':
|
494 |
+
yolo_record = YoloHandler.load(filename, **kwargs)
|
495 |
+
ir_record = YoloHandler.to_ir(yolo_record)
|
496 |
+
elif fmt in ('voc', 'pascal', 'pascal_voc'):
|
497 |
+
pascal_voc_record = PascalVocHandler.load(filename, **kwargs)
|
498 |
+
ir_record = PascalVocHandler.to_ir(pascal_voc_record)
|
499 |
+
elif fmt == 'coco':
|
500 |
+
coco_records = CocoHandler.load(filename, **kwargs)
|
501 |
+
ir_record = [CocoHandler.to_ir(coco_record)
|
502 |
+
for coco_record in coco_records]
|
503 |
+
else:
|
504 |
+
raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
|
505 |
+
return ir_record
|
506 |
+
|
507 |
+
|
508 |
+
def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
|
509 |
+
os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
|
510 |
+
if out_fmt == 'labelme':
|
511 |
+
labelme_record = LabelmeHandler.from_ir(ir_record)
|
512 |
+
LabelmeHandler.save(filename, labelme_record)
|
513 |
+
elif out_fmt == 'yolo':
|
514 |
+
yolo_record = YoloHandler.from_ir(ir_record)
|
515 |
+
YoloHandler.save(filename, yolo_record)
|
516 |
+
elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
|
517 |
+
pascal_voc_record = PascalVocHandler.from_ir(ir_record)
|
518 |
+
PascalVocHandler.save(filename, pascal_voc_record)
|
519 |
+
elif out_fmt == 'coco':
|
520 |
+
raise ValueError("Unsupported for `coco` now!")
|
521 |
+
else:
|
522 |
+
raise ValueError(f"Unsupported detect label fmt. Got {out_fmt}")
|
523 |
+
|
524 |
+
|
525 |
+
def _get_format(record):
|
526 |
+
if isinstance(record, LabelmeRecord):
|
527 |
+
return ('labelme',)
|
528 |
+
elif isinstance(record, YoloRecord):
|
529 |
+
return ('yolo',)
|
530 |
+
elif isinstance(record, PascalVocRecord):
|
531 |
+
return ('voc', 'pascal', 'pascal_voc')
|
532 |
+
elif isinstance(record, CocoRecord):
|
533 |
+
return ('coco',)
|
534 |
+
elif isinstance(record, DetectIrRecord):
|
535 |
+
return ('ir', 'detect_ir')
|
536 |
+
else:
|
537 |
+
return ()
|
538 |
+
|
539 |
+
|
540 |
+
def convert_detect(record, out_fmt):
|
541 |
+
allowed_fmts = ('labelme', 'yolo', 'voc', 'coco',
|
542 |
+
'pascal', 'pascal_voc', 'ir', 'detect_ir')
|
543 |
+
if out_fmt not in allowed_fmts:
|
544 |
+
raise ValueError(
|
545 |
+
"Unsupported label format conversions for given out_fmt")
|
546 |
+
if out_fmt in _get_format(record):
|
547 |
+
return record
|
548 |
+
|
549 |
+
if isinstance(record, LabelmeRecord):
|
550 |
+
ir_record = LabelmeHandler.to_ir(record)
|
551 |
+
elif isinstance(record, YoloRecord):
|
552 |
+
ir_record = YoloHandler.to_ir(record)
|
553 |
+
elif isinstance(record, PascalVocRecord):
|
554 |
+
ir_record = PascalVocHandler.to_ir(record)
|
555 |
+
elif isinstance(record, CocoRecord):
|
556 |
+
ir_record = CocoHandler.to_ir(record)
|
557 |
+
elif isinstance(record, DetectIrRecord):
|
558 |
+
ir_record = record
|
559 |
+
else:
|
560 |
+
raise TypeError('Unsupported type for record')
|
561 |
+
|
562 |
+
if out_fmt in ('ir', 'detect_ir'):
|
563 |
+
dst_record = ir_record
|
564 |
+
elif out_fmt == 'labelme':
|
565 |
+
dst_record = LabelmeHandler.from_ir(ir_record)
|
566 |
+
elif out_fmt == 'yolo':
|
567 |
+
dst_record = YoloHandler.from_ir(ir_record)
|
568 |
+
elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
|
569 |
+
dst_record = PascalVocHandler.from_ir(ir_record)
|
570 |
+
elif out_fmt == 'coco':
|
571 |
+
dst_record = CocoHandler.from_ir(ir_record)
|
572 |
+
return dst_record
|
573 |
+
|
574 |
+
|
575 |
+
def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
|
576 |
+
dst_record = copy.deepcopy(record)
|
577 |
+
dst_objects = []
|
578 |
+
for ir_object in dst_record.objects:
|
579 |
+
if not ignore:
|
580 |
+
if ir_object.label in label_map:
|
581 |
+
ir_object.label = label_map[ir_object.label]
|
582 |
+
dst_objects.append(ir_object)
|
583 |
+
else:
|
584 |
+
if ir_object.label in label_map:
|
585 |
+
ir_object.label = label_map[ir_object.label]
|
586 |
+
dst_objects.append(ir_object)
|
587 |
+
dst_record.objects = dst_objects
|
588 |
+
return dst_record
|
589 |
+
|
590 |
+
|
591 |
+
def load_coco_class_names(filename):
|
592 |
+
json_data = khandy.load_json(filename)
|
593 |
+
categories = json_data['categories']
|
594 |
+
return [cat_item['name'] for cat_item in categories]
|
khandy/list_utils.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import itertools
|
3 |
+
|
4 |
+
|
5 |
+
def to_list(obj):
|
6 |
+
if obj is None:
|
7 |
+
return None
|
8 |
+
elif hasattr(obj, '__iter__') and not isinstance(obj, str):
|
9 |
+
try:
|
10 |
+
return list(obj)
|
11 |
+
except:
|
12 |
+
return [obj]
|
13 |
+
else:
|
14 |
+
return [obj]
|
15 |
+
|
16 |
+
|
17 |
+
def convert_lists_to_record(*list_objs, delimiter=None):
|
18 |
+
assert len(list_objs) >= 1, 'list_objs length must >= 1.'
|
19 |
+
delimiter = delimiter or ','
|
20 |
+
|
21 |
+
assert isinstance(list_objs[0], (tuple, list))
|
22 |
+
number = len(list_objs[0])
|
23 |
+
for item in list_objs[1:]:
|
24 |
+
assert isinstance(item, (tuple, list))
|
25 |
+
assert len(item) == number, '{} != {}'.format(len(item), number)
|
26 |
+
|
27 |
+
records = []
|
28 |
+
record_list = zip(*list_objs)
|
29 |
+
for record in record_list:
|
30 |
+
record_str = [str(item) for item in record]
|
31 |
+
records.append(delimiter.join(record_str))
|
32 |
+
return records
|
33 |
+
|
34 |
+
|
35 |
+
def shuffle_table(*table):
|
36 |
+
"""
|
37 |
+
Notes:
|
38 |
+
table can be seen as list of list which have equal items.
|
39 |
+
"""
|
40 |
+
shuffled_list = list(zip(*table))
|
41 |
+
random.shuffle(shuffled_list)
|
42 |
+
tuple_list = zip(*shuffled_list)
|
43 |
+
return [list(item) for item in tuple_list]
|
44 |
+
|
45 |
+
|
46 |
+
def transpose_table(table):
|
47 |
+
"""
|
48 |
+
Notes:
|
49 |
+
table can be seen as list of list which have equal items.
|
50 |
+
"""
|
51 |
+
m, n = len(table), len(table[0])
|
52 |
+
return [[table[i][j] for i in range(m)] for j in range(n)]
|
53 |
+
|
54 |
+
|
55 |
+
def concat_list(in_list):
|
56 |
+
"""Concatenate a list of list into a single list.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
in_list (list): The list of list to be merged.
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
list: The concatenated flat list.
|
63 |
+
|
64 |
+
References:
|
65 |
+
mmcv.concat_list
|
66 |
+
"""
|
67 |
+
return list(itertools.chain(*in_list))
|
68 |
+
|
khandy/misc.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import socket
|
3 |
+
import logging
|
4 |
+
import argparse
|
5 |
+
import warnings
|
6 |
+
from enum import Enum
|
7 |
+
|
8 |
+
import requests
|
9 |
+
|
10 |
+
|
11 |
+
def all_of(iterable, pred):
|
12 |
+
"""Returns whether all elements in the iterable satisfy the predicate.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
iterable (Iterable): An iterable to check.
|
16 |
+
pred (callable): A predicate to apply to each element.
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
bool: True if all elements satisfy the predicate, False otherwise.
|
20 |
+
|
21 |
+
References:
|
22 |
+
https://en.cppreference.com/w/cpp/algorithm/all_any_none_of
|
23 |
+
"""
|
24 |
+
return all(pred(element) for element in iterable)
|
25 |
+
|
26 |
+
|
27 |
+
def any_of(iterable, pred):
|
28 |
+
"""Returns whether any element in the iterable satisfies the predicate.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
iterable (Iterable): An iterable to check.
|
32 |
+
pred (callable): A predicate to apply to each element.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
bool: True if any element satisfies the predicate, False otherwise.
|
36 |
+
|
37 |
+
References:
|
38 |
+
https://en.cppreference.com/w/cpp/algorithm/all_any_none_of
|
39 |
+
"""
|
40 |
+
return any(pred(element) for element in iterable)
|
41 |
+
|
42 |
+
|
43 |
+
def none_of(iterable, pred):
|
44 |
+
"""Returns whether no elements in the iterable satisfy the predicate.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
iterable (Iterable): An iterable to check.
|
48 |
+
pred (callable): A predicate to apply to each element.
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
bool: True if no elements satisfy the predicate, False otherwise.
|
52 |
+
|
53 |
+
References:
|
54 |
+
https://en.cppreference.com/w/cpp/algorithm/all_any_none_of
|
55 |
+
"""
|
56 |
+
return not any(pred(element) for element in iterable)
|
57 |
+
|
58 |
+
|
59 |
+
def print_with_no(obj):
|
60 |
+
if hasattr(obj, '__len__'):
|
61 |
+
for k, item in enumerate(obj):
|
62 |
+
print('[{}/{}] {}'.format(k+1, len(obj), item))
|
63 |
+
elif hasattr(obj, '__iter__'):
|
64 |
+
for k, item in enumerate(obj):
|
65 |
+
print('[{}] {}'.format(k+1, item))
|
66 |
+
else:
|
67 |
+
print('[1] {}'.format(obj))
|
68 |
+
|
69 |
+
|
70 |
+
def get_file_line_count(filename, encoding='utf-8'):
|
71 |
+
line_count = 0
|
72 |
+
buffer_size = 1024 * 1024 * 8
|
73 |
+
with open(filename, 'r', encoding=encoding) as f:
|
74 |
+
while True:
|
75 |
+
data = f.read(buffer_size)
|
76 |
+
if not data:
|
77 |
+
break
|
78 |
+
line_count += data.count('\n')
|
79 |
+
return line_count
|
80 |
+
|
81 |
+
|
82 |
+
def get_host_ip():
|
83 |
+
try:
|
84 |
+
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
85 |
+
s.connect(('8.8.8.8', 80))
|
86 |
+
ip = s.getsockname()[0]
|
87 |
+
finally:
|
88 |
+
s.close()
|
89 |
+
return ip
|
90 |
+
|
91 |
+
|
92 |
+
def set_logger(filename, level=logging.INFO, logger_name=None, formatter=None, with_print=True):
|
93 |
+
logger = logging.getLogger(logger_name)
|
94 |
+
logger.setLevel(level)
|
95 |
+
|
96 |
+
if formatter is None:
|
97 |
+
formatter = logging.Formatter('%(message)s')
|
98 |
+
|
99 |
+
# Never mutate (insert/remove elements) the list you're currently iterating on.
|
100 |
+
# If you need, make a copy.
|
101 |
+
for handler in logger.handlers[:]:
|
102 |
+
if isinstance(handler, logging.FileHandler):
|
103 |
+
logger.removeHandler(handler)
|
104 |
+
# FileHandler is subclass of StreamHandler, so isinstance(handler,
|
105 |
+
# logging.StreamHandler) is True even if handler is FileHandler.
|
106 |
+
# if (type(handler) == logging.StreamHandler) and (handler.stream == sys.stderr):
|
107 |
+
elif type(handler) == logging.StreamHandler:
|
108 |
+
logger.removeHandler(handler)
|
109 |
+
|
110 |
+
file_handler = logging.FileHandler(filename, encoding='utf-8')
|
111 |
+
file_handler.setFormatter(formatter)
|
112 |
+
logger.addHandler(file_handler)
|
113 |
+
if with_print:
|
114 |
+
console_handler = logging.StreamHandler()
|
115 |
+
console_handler.setFormatter(formatter)
|
116 |
+
logger.addHandler(console_handler)
|
117 |
+
return logger
|
118 |
+
|
119 |
+
|
120 |
+
def print_arguments(args):
|
121 |
+
assert isinstance(args, argparse.Namespace)
|
122 |
+
arg_list = sorted(vars(args).items())
|
123 |
+
for key, value in arg_list:
|
124 |
+
print('{}: {}'.format(key, value))
|
125 |
+
|
126 |
+
|
127 |
+
def save_arguments(filename, args, sort=True):
|
128 |
+
assert isinstance(args, argparse.Namespace)
|
129 |
+
args = vars(args)
|
130 |
+
with open(filename, 'w') as f:
|
131 |
+
json.dump(args, f, indent=4, sort_keys=sort)
|
132 |
+
|
133 |
+
|
134 |
+
class DownloadStatusCode(Enum):
|
135 |
+
FILE_SIZE_TOO_LARGE = (-100, 'the size of file from url is too large')
|
136 |
+
FILE_SIZE_TOO_SMALL = (-101, 'the size of file from url is too small')
|
137 |
+
FILE_SIZE_IS_ZERO = (-102, 'the size of file from url is zero')
|
138 |
+
URL_IS_NOT_IMAGE = (-103, 'URL is not an image')
|
139 |
+
|
140 |
+
@property
|
141 |
+
def code(self):
|
142 |
+
return self.value[0]
|
143 |
+
|
144 |
+
@property
|
145 |
+
def message(self):
|
146 |
+
return self.value[1]
|
147 |
+
|
148 |
+
|
149 |
+
class DownloadError(Exception):
|
150 |
+
def __init__(self, status_code: DownloadStatusCode, extra_str: str=None):
|
151 |
+
self.name = status_code.name
|
152 |
+
self.code = status_code.code
|
153 |
+
if extra_str is None:
|
154 |
+
self.message = status_code.message
|
155 |
+
else:
|
156 |
+
self.message = f'{status_code.message}: {extra_str}'
|
157 |
+
Exception.__init__(self)
|
158 |
+
|
159 |
+
def __repr__(self):
|
160 |
+
return f'[{self.__class__.__name__} {self.code}] {self.message}'
|
161 |
+
|
162 |
+
__str__ = __repr__
|
163 |
+
|
164 |
+
|
165 |
+
def download_image(image_url, min_filesize=0, max_filesize=100*1024*1024,
|
166 |
+
params=None, **kwargs) -> bytes:
|
167 |
+
"""
|
168 |
+
References:
|
169 |
+
https://httpwg.org/specs/rfc9110.html#field.content-length
|
170 |
+
https://requests.readthedocs.io/en/latest/user/advanced/#body-content-workflow
|
171 |
+
"""
|
172 |
+
stream = kwargs.pop('stream', True)
|
173 |
+
|
174 |
+
with requests.get(image_url, stream=stream, params=params, **kwargs) as response:
|
175 |
+
response.raise_for_status()
|
176 |
+
|
177 |
+
content_type = response.headers.get('content-type')
|
178 |
+
if content_type is None:
|
179 |
+
warnings.warn('No Content-Type!')
|
180 |
+
else:
|
181 |
+
if not content_type.startswith(('image/', 'application/octet-stream')):
|
182 |
+
raise DownloadError(DownloadStatusCode.URL_IS_NOT_IMAGE)
|
183 |
+
|
184 |
+
# when Transfer-Encoding == chunked, Content-Length does not exist.
|
185 |
+
content_length = response.headers.get('content-length')
|
186 |
+
if content_length is None:
|
187 |
+
warnings.warn('No Content-Length!')
|
188 |
+
else:
|
189 |
+
content_length = int(content_length)
|
190 |
+
if content_length > max_filesize:
|
191 |
+
raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
|
192 |
+
if content_length < min_filesize:
|
193 |
+
raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
|
194 |
+
|
195 |
+
filesize = 0
|
196 |
+
chunks = []
|
197 |
+
for chunk in response.iter_content(chunk_size=10*1024):
|
198 |
+
chunks.append(chunk)
|
199 |
+
filesize += len(chunk)
|
200 |
+
if filesize > max_filesize:
|
201 |
+
raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
|
202 |
+
if filesize < min_filesize:
|
203 |
+
raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
|
204 |
+
image_bytes = b''.join(chunks)
|
205 |
+
|
206 |
+
return image_bytes
|
207 |
+
|
208 |
+
|
209 |
+
def download_file(url, min_filesize=0, max_filesize=100*1024*1024,
|
210 |
+
params=None, **kwargs) -> bytes:
|
211 |
+
"""
|
212 |
+
References:
|
213 |
+
https://httpwg.org/specs/rfc9110.html#field.content-length
|
214 |
+
https://requests.readthedocs.io/en/latest/user/advanced/#body-content-workflow
|
215 |
+
"""
|
216 |
+
stream = kwargs.pop('stream', True)
|
217 |
+
|
218 |
+
with requests.get(url, stream=stream, params=params, **kwargs) as response:
|
219 |
+
response.raise_for_status()
|
220 |
+
|
221 |
+
# when Transfer-Encoding == chunked, Content-Length does not exist.
|
222 |
+
content_length = response.headers.get('content-length')
|
223 |
+
if content_length is None:
|
224 |
+
warnings.warn('No Content-Length!')
|
225 |
+
else:
|
226 |
+
content_length = int(content_length)
|
227 |
+
if content_length > max_filesize:
|
228 |
+
raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
|
229 |
+
if content_length < min_filesize:
|
230 |
+
raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
|
231 |
+
|
232 |
+
filesize = 0
|
233 |
+
chunks = []
|
234 |
+
for chunk in response.iter_content(chunk_size=10*1024):
|
235 |
+
chunks.append(chunk)
|
236 |
+
filesize += len(chunk)
|
237 |
+
if filesize > max_filesize:
|
238 |
+
raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
|
239 |
+
if filesize < min_filesize:
|
240 |
+
raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
|
241 |
+
file_bytes = b''.join(chunks)
|
242 |
+
|
243 |
+
return file_bytes
|
244 |
+
|
245 |
+
|
khandy/numpy_utils.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def sigmoid(x):
|
5 |
+
return 1. / (1 + np.exp(-x))
|
6 |
+
|
7 |
+
|
8 |
+
def softmax(x, axis=-1, copy=True):
|
9 |
+
"""
|
10 |
+
Args:
|
11 |
+
copy: Copy x or not.
|
12 |
+
|
13 |
+
Referneces:
|
14 |
+
`from sklearn.utils.extmath import softmax`
|
15 |
+
"""
|
16 |
+
if copy:
|
17 |
+
x = np.copy(x)
|
18 |
+
max_val = np.max(x, axis=axis, keepdims=True)
|
19 |
+
x -= max_val
|
20 |
+
np.exp(x, x)
|
21 |
+
sum_exp = np.sum(x, axis=axis, keepdims=True)
|
22 |
+
x /= sum_exp
|
23 |
+
return x
|
24 |
+
|
25 |
+
|
26 |
+
def log_sum_exp(x, axis=-1, keepdims=False):
|
27 |
+
"""
|
28 |
+
References:
|
29 |
+
numpy.logaddexp
|
30 |
+
numpy.logaddexp2
|
31 |
+
scipy.misc.logsumexp
|
32 |
+
"""
|
33 |
+
max_val = np.max(x, axis=axis, keepdims=True)
|
34 |
+
x -= max_val
|
35 |
+
np.exp(x, x)
|
36 |
+
sum_exp = np.sum(x, axis=axis, keepdims=keepdims)
|
37 |
+
lse = np.log(sum_exp, sum_exp)
|
38 |
+
if not keepdims:
|
39 |
+
max_val = np.squeeze(max_val, axis=axis)
|
40 |
+
return max_val + lse
|
41 |
+
|
42 |
+
|
43 |
+
def l2_normalize(x, axis=None, epsilon=1e-12, copy=True):
|
44 |
+
"""L2 normalize an array along an axis.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
x : array_like of floats
|
48 |
+
Input data.
|
49 |
+
axis : None or int or tuple of ints, optional
|
50 |
+
Axis or axes along which to operate.
|
51 |
+
epsilon: float, optional
|
52 |
+
A small value such as to avoid division by zero.
|
53 |
+
copy : bool, optional
|
54 |
+
Copy x or not.
|
55 |
+
"""
|
56 |
+
if copy:
|
57 |
+
x = np.copy(x)
|
58 |
+
x /= np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), epsilon)
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
def minmax_normalize(x, axis=None, epsilon=1e-12, copy=True):
|
63 |
+
"""minmax normalize an array along a given axis.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
x : array_like of floats
|
67 |
+
Input data.
|
68 |
+
axis : None or int or tuple of ints, optional
|
69 |
+
Axis or axes along which to operate.
|
70 |
+
epsilon: float, optional
|
71 |
+
A small value such as to avoid division by zero.
|
72 |
+
copy : bool, optional
|
73 |
+
Copy x or not.
|
74 |
+
"""
|
75 |
+
if copy:
|
76 |
+
x = np.copy(x)
|
77 |
+
|
78 |
+
minval = np.min(x, axis=axis, keepdims=True)
|
79 |
+
maxval = np.max(x, axis=axis, keepdims=True)
|
80 |
+
maxval -= minval
|
81 |
+
maxval = np.maximum(maxval, epsilon)
|
82 |
+
|
83 |
+
x -= minval
|
84 |
+
x /= maxval
|
85 |
+
return x
|
86 |
+
|
87 |
+
|
88 |
+
def zscore_normalize(x, mean=None, std=None, axis=None, epsilon=1e-12, copy=True):
|
89 |
+
"""z-score normalize an array along a given axis.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
x : array_like of floats
|
93 |
+
Input data.
|
94 |
+
mean: array_like of floats, optional
|
95 |
+
mean for z-score
|
96 |
+
std: array_like of floats, optional
|
97 |
+
std for z-score
|
98 |
+
axis : None or int or tuple of ints, optional
|
99 |
+
Axis or axes along which to operate.
|
100 |
+
epsilon: float, optional
|
101 |
+
A small value such as to avoid division by zero.
|
102 |
+
copy : bool, optional
|
103 |
+
Copy x or not.
|
104 |
+
"""
|
105 |
+
if copy:
|
106 |
+
x = np.copy(x)
|
107 |
+
if mean is None:
|
108 |
+
mean = np.mean(x, axis=axis, keepdims=True)
|
109 |
+
if std is None:
|
110 |
+
std = np.std(x, axis=axis, keepdims=True)
|
111 |
+
mean = np.asarray(mean, dtype=x.dtype)
|
112 |
+
std = np.asarray(std, dtype=x.dtype)
|
113 |
+
std = np.maximum(std, epsilon)
|
114 |
+
|
115 |
+
x -= mean
|
116 |
+
x /= std
|
117 |
+
return x
|
118 |
+
|
119 |
+
|
120 |
+
def get_order_of_magnitude(number):
|
121 |
+
number = np.where(number == 0, 1, number)
|
122 |
+
oom = np.floor(np.log10(np.abs(number)))
|
123 |
+
return oom.astype(np.int32)
|
124 |
+
|
125 |
+
|
126 |
+
def top_k(x, k, axis=-1, largest=True, sorted=True):
|
127 |
+
"""Finds values and indices of the k largest/smallest
|
128 |
+
elements along a given axis.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
x: numpy ndarray
|
132 |
+
1-D or higher with given axis at least k.
|
133 |
+
k: int
|
134 |
+
Number of top elements to look for along the given axis.
|
135 |
+
axis: int
|
136 |
+
The axis to sort along.
|
137 |
+
largest: bool
|
138 |
+
Controls whether to return largest or smallest elements
|
139 |
+
sorted: bool
|
140 |
+
If true the resulting k elements will be sorted by the values.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
topk_values:
|
144 |
+
The k largest/smallest elements along the given axis.
|
145 |
+
topk_indices:
|
146 |
+
The indices of the k largest/smallest elements along the given axis.
|
147 |
+
"""
|
148 |
+
if axis is None:
|
149 |
+
axis_size = x.size
|
150 |
+
else:
|
151 |
+
axis_size = x.shape[axis]
|
152 |
+
assert 1 <= k <= axis_size
|
153 |
+
|
154 |
+
x = np.asanyarray(x)
|
155 |
+
if largest:
|
156 |
+
index_array = np.argpartition(x, axis_size-k, axis=axis)
|
157 |
+
topk_indices = np.take(index_array, -np.arange(k)-1, axis=axis)
|
158 |
+
else:
|
159 |
+
index_array = np.argpartition(x, k-1, axis=axis)
|
160 |
+
topk_indices = np.take(index_array, np.arange(k), axis=axis)
|
161 |
+
topk_values = np.take_along_axis(x, topk_indices, axis=axis)
|
162 |
+
if sorted:
|
163 |
+
sorted_indices_in_topk = np.argsort(topk_values, axis=axis)
|
164 |
+
if largest:
|
165 |
+
sorted_indices_in_topk = np.flip(sorted_indices_in_topk, axis=axis)
|
166 |
+
sorted_topk_values = np.take_along_axis(
|
167 |
+
topk_values, sorted_indices_in_topk, axis=axis)
|
168 |
+
sorted_topk_indices = np.take_along_axis(
|
169 |
+
topk_indices, sorted_indices_in_topk, axis=axis)
|
170 |
+
return sorted_topk_values, sorted_topk_indices
|
171 |
+
return topk_values, topk_indices
|
172 |
+
|
173 |
+
|
khandy/points/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .pts_letterbox import *
|
2 |
+
from .pts_transform_scale import *
|
khandy/points/pts_letterbox.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__all__ = ['letterbox_2d_points', 'unletterbox_2d_points']
|
2 |
+
|
3 |
+
|
4 |
+
def letterbox_2d_points(points, scale=1.0, pad_left=0, pad_top=0, copy=True):
|
5 |
+
if copy:
|
6 |
+
points = points.copy()
|
7 |
+
points[..., 0::2] = points[..., 0::2] * scale + pad_left
|
8 |
+
points[..., 1::2] = points[..., 1::2] * scale + pad_top
|
9 |
+
return points
|
10 |
+
|
11 |
+
|
12 |
+
def unletterbox_2d_points(points, scale=1.0, pad_left=0, pad_top=0, copy=True):
|
13 |
+
if copy:
|
14 |
+
points = points.copy()
|
15 |
+
|
16 |
+
points[..., 0::2] = (points[..., 0::2] - pad_left) / scale
|
17 |
+
points[..., 1::2] = (points[..., 1::2] - pad_top) / scale
|
18 |
+
return points
|
19 |
+
|
khandy/points/pts_transform_scale.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
__all__ = ['scale_2d_points']
|
4 |
+
|
5 |
+
|
6 |
+
def scale_2d_points(points, x_scale=1, y_scale=1, x_center=0, y_center=0, copy=True):
|
7 |
+
"""Scale 2d points.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
points: (..., 2N)
|
11 |
+
x_scale: scale factor in x dimension
|
12 |
+
y_scale: scale factor in y dimension
|
13 |
+
x_center: scale center in x dimension
|
14 |
+
y_center: scale center in y dimension
|
15 |
+
"""
|
16 |
+
points = np.array(points, dtype=np.float32, copy=copy)
|
17 |
+
x_scale = np.asarray(x_scale, np.float32)
|
18 |
+
y_scale = np.asarray(y_scale, np.float32)
|
19 |
+
x_center = np.asarray(x_center, np.float32)
|
20 |
+
y_center = np.asarray(y_center, np.float32)
|
21 |
+
|
22 |
+
x_shift = 1 - x_scale
|
23 |
+
y_shift = 1 - y_scale
|
24 |
+
x_shift *= x_center
|
25 |
+
y_shift *= y_center
|
26 |
+
|
27 |
+
points[..., 0::2] *= x_scale
|
28 |
+
points[..., 1::2] *= y_scale
|
29 |
+
points[..., 0::2] += x_shift
|
30 |
+
points[..., 1::2] += y_shift
|
31 |
+
return points
|
32 |
+
|
33 |
+
|
khandy/split_utils.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numbers
|
2 |
+
from collections.abc import Sequence
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def split_by_num(x, num_splits, strict=True):
|
8 |
+
"""
|
9 |
+
Args:
|
10 |
+
num_splits: an integer indicating the number of splits
|
11 |
+
|
12 |
+
References:
|
13 |
+
numpy.split and numpy.array_split
|
14 |
+
"""
|
15 |
+
# NB: np.ndarray is not Sequence
|
16 |
+
assert isinstance(x, (Sequence, np.ndarray))
|
17 |
+
assert isinstance(num_splits, numbers.Integral)
|
18 |
+
|
19 |
+
if strict:
|
20 |
+
assert len(x) % num_splits == 0
|
21 |
+
split_size = (len(x) + num_splits - 1) // num_splits
|
22 |
+
out_list = []
|
23 |
+
for i in range(0, len(x), split_size):
|
24 |
+
out_list.append(x[i: i + split_size])
|
25 |
+
return out_list
|
26 |
+
|
27 |
+
|
28 |
+
def split_by_size(x, sizes):
|
29 |
+
"""
|
30 |
+
References:
|
31 |
+
tf.split
|
32 |
+
https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/misc.py
|
33 |
+
"""
|
34 |
+
# NB: np.ndarray is not Sequence
|
35 |
+
assert isinstance(x, (Sequence, np.ndarray))
|
36 |
+
assert isinstance(sizes, (list, tuple))
|
37 |
+
|
38 |
+
assert sum(sizes) == len(x)
|
39 |
+
out_list = []
|
40 |
+
start_index = 0
|
41 |
+
for size in sizes:
|
42 |
+
out_list.append(x[start_index: start_index + size])
|
43 |
+
start_index += size
|
44 |
+
return out_list
|
45 |
+
|
46 |
+
|
47 |
+
def split_by_slice(x, slices):
|
48 |
+
"""
|
49 |
+
References:
|
50 |
+
SliceLayer in Caffe, and numpy.split
|
51 |
+
"""
|
52 |
+
# NB: np.ndarray is not Sequence
|
53 |
+
assert isinstance(x, (Sequence, np.ndarray))
|
54 |
+
assert isinstance(slices, (list, tuple))
|
55 |
+
|
56 |
+
out_list = []
|
57 |
+
indices = [0] + list(slices) + [len(x)]
|
58 |
+
for i in range(len(slices) + 1):
|
59 |
+
out_list.append(x[indices[i]: indices[i + 1]])
|
60 |
+
return out_list
|
61 |
+
|
62 |
+
|
63 |
+
def split_by_ratio(x, ratios):
|
64 |
+
# NB: np.ndarray is not Sequence
|
65 |
+
assert isinstance(x, (Sequence, np.ndarray))
|
66 |
+
assert isinstance(ratios, (list, tuple))
|
67 |
+
|
68 |
+
pdf = [k / sum(ratios) for k in ratios]
|
69 |
+
cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
|
70 |
+
indices = [int(round(len(x) * k)) for k in cdf]
|
71 |
+
return [x[indices[i]: indices[i + 1]] for i in range(len(ratios))]
|
khandy/text_utils.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
|
4 |
+
def strip_content_in_paren(string):
|
5 |
+
"""
|
6 |
+
Notes:
|
7 |
+
strip_content_in_paren cannot process nested paren correctly
|
8 |
+
"""
|
9 |
+
return re.sub(r"\([^)]*\)|([^)]*)", "", string)
|
10 |
+
|
11 |
+
|
12 |
+
def is_chinese_char(uchar: str) -> bool:
|
13 |
+
"""Whether the input char is a Chinese character.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
uchar: input char in unicode
|
17 |
+
|
18 |
+
References:
|
19 |
+
`is_chinese_char` in https://github.com/thunlp/OpenNRE/
|
20 |
+
"""
|
21 |
+
codepoint = ord(uchar)
|
22 |
+
if ((0x4E00 <= codepoint <= 0x9FFF) or # CJK Unified Ideographs
|
23 |
+
(0x3400 <= codepoint <= 0x4DBF) or # CJK Unified Ideographs Extension A
|
24 |
+
(0xF900 <= codepoint <= 0xFAFF) or # CJK Compatibility Ideographs
|
25 |
+
(0x20000 <= codepoint <= 0x2A6DF) or # CJK Unified Ideographs Extension B
|
26 |
+
(0x2A700 <= codepoint <= 0x2B73F) or
|
27 |
+
(0x2B740 <= codepoint <= 0x2B81F) or
|
28 |
+
(0x2B820 <= codepoint <= 0x2CEAF) or
|
29 |
+
(0x2F800 <= codepoint <= 0x2FA1F)): # CJK Compatibility Supplement
|
30 |
+
return True
|
31 |
+
return False
|
32 |
+
|
33 |
+
|
khandy/time_utils.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import logging
|
3 |
+
import numbers
|
4 |
+
import datetime
|
5 |
+
|
6 |
+
|
7 |
+
def _to_timestamp(val, multiplier=1, rounded=False):
|
8 |
+
if val is None:
|
9 |
+
timestamp = time.time()
|
10 |
+
elif isinstance(val, numbers.Real):
|
11 |
+
timestamp = float(val)
|
12 |
+
elif isinstance(val, time.struct_time):
|
13 |
+
timestamp = time.mktime(val)
|
14 |
+
elif isinstance(val, datetime.datetime):
|
15 |
+
timestamp = val.timestamp()
|
16 |
+
elif isinstance(val, datetime.date):
|
17 |
+
dt = datetime.datetime.combine(val, datetime.time())
|
18 |
+
timestamp = dt.timestamp()
|
19 |
+
elif isinstance(val, str):
|
20 |
+
try:
|
21 |
+
# The full format looks like 'YYYY-MM-DD HH:MM:SS.mmmmmm'.
|
22 |
+
dt = datetime.datetime.fromisoformat(val)
|
23 |
+
timestamp = dt.timestamp()
|
24 |
+
except:
|
25 |
+
raise TypeError('when argument is str, it should conform to isoformat')
|
26 |
+
else:
|
27 |
+
raise TypeError('unsupported type!')
|
28 |
+
timestamp = timestamp * multiplier
|
29 |
+
if rounded:
|
30 |
+
# The return value is an integer if ndigits is omitted or None.
|
31 |
+
timestamp = round(timestamp)
|
32 |
+
return timestamp
|
33 |
+
|
34 |
+
|
35 |
+
def get_timestamp(time_val=None, rounded=True):
|
36 |
+
"""timestamp in seconds.
|
37 |
+
"""
|
38 |
+
return _to_timestamp(time_val, multiplier=1, rounded=rounded)
|
39 |
+
|
40 |
+
|
41 |
+
def get_timestamp_ms(time_val=None, rounded=True):
|
42 |
+
"""timestamp in milliseconds.
|
43 |
+
"""
|
44 |
+
return _to_timestamp(time_val, multiplier=1000, rounded=rounded)
|
45 |
+
|
46 |
+
|
47 |
+
def get_timestamp_us(time_val=None, rounded=True):
|
48 |
+
"""timestamp in microseconds.
|
49 |
+
"""
|
50 |
+
return _to_timestamp(time_val, multiplier=1000000, rounded=rounded)
|
51 |
+
|
52 |
+
|
53 |
+
def get_utc8now() -> datetime.datetime:
|
54 |
+
"""get current UTC-8 time or Beijing time
|
55 |
+
"""
|
56 |
+
tz = datetime.timezone(datetime.timedelta(hours=8))
|
57 |
+
utc8now = datetime.datetime.now(tz)
|
58 |
+
return utc8now
|
59 |
+
|
60 |
+
|
61 |
+
class ContextTimer(object):
|
62 |
+
"""
|
63 |
+
References:
|
64 |
+
WithTimer in https://github.com/uber/ludwig/blob/master/ludwig/utils/time_utils.py
|
65 |
+
"""
|
66 |
+
def __init__(self, name=None, use_log=False, quiet=False):
|
67 |
+
self.use_log = use_log
|
68 |
+
self.quiet = quiet
|
69 |
+
if name is None:
|
70 |
+
self.name = ''
|
71 |
+
else:
|
72 |
+
self.name = '{}, '.format(name.rstrip())
|
73 |
+
|
74 |
+
def __enter__(self):
|
75 |
+
self.start_time = time.time()
|
76 |
+
if not self.quiet:
|
77 |
+
self._print_or_log('{}{} starts'.format(self.name, self._now_time_str))
|
78 |
+
return self
|
79 |
+
|
80 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
81 |
+
if not self.quiet:
|
82 |
+
self._print_or_log('{}elapsed_time = {:.5}s'.format(self.name, self.get_eplased_time()))
|
83 |
+
self._print_or_log('{}{} ends'.format(self.name, self._now_time_str))
|
84 |
+
|
85 |
+
@property
|
86 |
+
def _now_time_str(self):
|
87 |
+
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
|
88 |
+
|
89 |
+
def _print_or_log(self, output_str):
|
90 |
+
if self.use_log:
|
91 |
+
logging.info(output_str)
|
92 |
+
else:
|
93 |
+
print(output_str)
|
94 |
+
|
95 |
+
def get_eplased_time(self):
|
96 |
+
return time.time() - self.start_time
|
97 |
+
|
98 |
+
def enter(self):
|
99 |
+
"""Manually trigger enter"""
|
100 |
+
self.__enter__()
|
101 |
+
|
khandy/version.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = '0.1.8'
|
2 |
+
|
3 |
+
__all__ = ['__version__']
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv-python>=4.5
|
2 |
+
numpy>=1.11.1
|
3 |
+
lxml
|
4 |
+
requests
|
5 |
+
onnxruntime
|
6 |
+
Pillow
|
7 |
+
modelscope==1.15
|