HERIUN committed
Commit 6a07cb2 • 1 Parent(s): 85dbfc9
DocScanner-L.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d907965aa5d8e99ea8d0891fb66d13bc4f23838547bac6f568d01d480ff8c8a
+size 29328510
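The file above is a Git LFS pointer, not the ~29 MB weight file itself; `git lfs pull` (or the Hub's resolve endpoint) swaps the pointer for the real blob. A minimal sketch for verifying a downloaded copy against the digest recorded in the pointer (the local filename is an assumption):

```python
import hashlib

def sha256_of(path: str) -> str:
    # Stream the file in 1 MiB chunks so large weights never need to fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "1d907965aa5d8e99ea8d0891fb66d13bc4f23838547bac6f568d01d480ff8c8a"
assert sha256_of("DocScanner-L.pth") == expected  # assumes the file sits in CWD
```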
config.py ADDED
@@ -0,0 +1,28 @@
+import os
+
+
+class Config:
+    def __init__(self):
+        self.current_dir = os.path.dirname(os.path.abspath(__file__))
+        self.seg_model_path = os.path.join(self.current_dir, "pretrained", "seg.pth")
+        self.rec_model_path = os.path.join(
+            self.current_dir, "pretrained", "DocScanner-L.pth"
+        )
+        self.geotr_model_path = os.path.join(self.current_dir, "pretrained", "model.pt")
+        self.save_path = os.path.join(self.current_dir, "output")
+
+    @property
+    def get_seg_model_path(self):
+        return self.seg_model_path
+
+    @property
+    def get_rec_model_path(self):
+        return self.rec_model_path
+
+    @property
+    def get_geotr_model_path(self):
+        return self.geotr_model_path
+
+    @property
+    def get_save_path(self):
+        return self.save_path
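A minimal usage sketch. Paths are resolved relative to config.py, so the `pretrained` and `output` directories are assumed to sit next to that file; note the `get_*` members are properties and are read without parentheses:

```python
from config import Config

cfg = Config()
print(cfg.get_rec_model_path)  # .../pretrained/DocScanner-L.pth
print(cfg.get_save_path)       # .../output
```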
data_utils/.gitignore ADDED
@@ -0,0 +1,3 @@
+__pycache__
+sRGB Color Space Profile.icm
+USWebCoatedSWOP.icc
data_utils/__init__.py ADDED
@@ -0,0 +1,233 @@
+import argparse
+import numpy as np
+import pandas as pd
+import time
+from datetime import datetime, timedelta
+from pytz import timezone
+import re
+import json
+import config
+from data_utils.image_utils import (
+    load_image,
+    resize_coordinates_and_image_to_fit_to_maximum_pixel_counts,
+)
+
+import torch
+import os
+
+from functools import wraps
+import threading
+
+lock = threading.Lock()
+
+
+def check_gpu():
+    if torch.cuda.is_available():
+        current_device = torch.cuda.current_device()
+        device_name = torch.cuda.get_device_name(current_device)
+        print(f"Using GPU Device: {current_device} - {device_name}")
+    else:
+        print("CUDA is not available.")
+
+
+def record_and_save_gpu_memory_usage(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        torch.cuda.memory._record_memory_history(enabled=True)
+
+        result = func(*args, **kwargs)
+
+        torch.cuda.memory._record_memory_history(enabled=False)
+
+        torch.cuda.memory._save_segment_usage(filename="snapshot/segment_usage.svg")
+        torch.cuda.memory._save_memory_usage(filename="snapshot/memory_usage.svg")
+
+        return result
+
+    return wrapper
+
+
+def measure_gpu_time_and_memory(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        cuda = kwargs.get("cuda", True)  # Default to True if 'cuda' is not provided
+
+        start_memory = (
+            torch.cuda.memory_reserved() if cuda else 0
+        )  # Record initial memory
+        result = func(*args, **kwargs)
+        end_memory = torch.cuda.memory_reserved() if cuda else 0  # Record final memory
+
+        if cuda:
+            print(
+                f"{func.__name__} Initial CUDA memory reserved: {start_memory / (1024 ** 3):.2f} GB"
+            )
+            print(
+                f"{func.__name__} Final CUDA memory reserved: {end_memory / (1024 ** 3):.2f} GB"
+            )
+            print(
+                f"{func.__name__} CUDA memory change: {(end_memory - start_memory) / (1024 ** 3):.2f} GB"
+            )
+
+        return result
+
+    return wrapper
+
+
+def timeit(func):
+    @wraps(func)
+    def timeit_wrapper(*args, **kwargs):
+        start_time = time.perf_counter()
+        result = func(*args, **kwargs)
+        end_time = time.perf_counter()
+        total_time = end_time - start_time
+        if kwargs.get("debug", False):
+            print(f"{func.__name__} : {total_time:.4f} sec..")
+        # print(f'Function {func.__name__} {args} {kwargs} Took {total_time:.4f} seconds')
+        return result
+
+    return timeit_wrapper
+
+
+def async_timeit(func):
+    @wraps(func)
+    async def timeit_wrapper(*args, **kwargs):
+        start_time = time.perf_counter()
+        result = await func(*args, **kwargs)
+        end_time = time.perf_counter()
+        total_time = end_time - start_time
+        if kwargs.get("debug", False):
+            print(f"{func.__name__} : {total_time:.4f} sec..")
+        # print(f'Function {func.__name__} {args} {kwargs} Took {total_time:.4f} seconds')
+        return result
+
+    return timeit_wrapper
+
+
+def thread_func(func):
+    @wraps(func)
+    def thread_func_wrapper(*args, **kwargs):
+        # Hold the lock for the duration of the call and release it even if
+        # `func` raises, then free cached GPU memory.
+        with lock:
+            result = func(*args, **kwargs)
+        torch.cuda.empty_cache()
+        return result
+
+    return thread_func_wrapper
+
+
+def get_arguments():
+    parser = argparse.ArgumentParser(description="text_remover")
+
+    parser.add_argument("--image")
+    parser.add_argument("--dir")
+    parser.add_argument("--json")
+    parser.add_argument("--refine", action="store_true", default=False)
+    parser.add_argument("--preserve_resolution", action="store_true", default=False)
+    parser.add_argument("--pixel_thresh", type=int)
+    # Evaluate text stroke mask
+    parser.add_argument("--prepare_kaist", action="store_true", default=False)
+    parser.add_argument("--kaist_all_zip")
+    parser.add_argument("--data_dir")
+
+    args = parser.parse_args()
+    return args
+
+
+def get_elapsed_time(start_time):
+    return timedelta(seconds=round(time.time() - start_time))
+
+
+def get_current_time():
+    return str(datetime.now(timezone("Asia/Seoul"))).replace(" ", "-").rsplit(".", 1)[0]
+
+
+def parse_csv_file(path_csv, resize=False):
+    df = pd.read_csv(path_csv)
+
+    ls_rows = list()
+    for coor, content in df[["coordinates", "content"]].values:
+        coor = re.sub(pattern=r"\(|\)", repl="", string=coor)
+        coor = coor.split(",")
+
+        rect = list(map(int, coor))
+        ls_rows.append((rect[2], rect[3], rect[0], rect[1], content))
+    bboxes = pd.DataFrame(
+        ls_rows, columns=["xmin", "ymin", "xmax", "ymax", "transcript"]
+    )
+
+    bboxes["area"] = bboxes.apply(
+        lambda x: (x["xmax"] - x["xmin"]) * (x["ymax"] - x["ymin"]), axis=1
+    )
+    bboxes.sort_values(["area"], inplace=True)
+    bboxes.drop(["area"], axis=1, inplace=True)
+
+    img_url = df["image_url"].values[0]
+    img = load_image(img_url)
+
+    if resize:
+        bboxes, img = resize_coordinates_and_image_to_fit_to_maximum_pixel_counts(
+            ha_bboxs=bboxes, img=img
+        )
+    return bboxes, img, img_url
+
+
+def parse_json_file(json_path):
+    with open(json_path, mode="r") as f:
+        req = json.load(f)
+
+    img_url = req["data"]["data"]["req"]["image_url"]
+    img = load_image(img_url)
+
+    coors = req["data"]["data"]["req"]["coordinates"]
+    bboxes = pd.DataFrame(coors, columns=["xmin", "ymin", "xmax", "ymax"])
+    return bboxes, img, img_url
+
+
+def parse_transcription_df(csv_path, index=0):
+    df = pd.read_csv(csv_path)
+    ls_rows = list()
+    for idx, (img_url, df_groupby) in enumerate(df.groupby("image_url")):
+        if idx != index:
+            continue
+        img = load_image(img_url)
+
+        # for img_url, coor, ori_content, tr_content in df_groupby.values:
+        for item_org_id, img_url, coor, ori_content, tr_content in df_groupby.values:
+            coor = re.sub(pattern=r"\(|\)|\.0", repl="", string=coor)
+            coor = coor.split(",")
+            rect = list(map(int, coor))
+            # ls_rows.append((rect[2], rect[3], rect[0], rect[1], ori_content, tr_content))
+            ls_rows.append(
+                (
+                    item_org_id,
+                    rect[2],
+                    rect[3],
+                    rect[0],
+                    rect[1],
+                    ori_content,
+                    tr_content,
+                )
+            )
+    bboxes = pd.DataFrame(
+        # ls_rows, columns=["xmin", "ymin", "xmax", "ymax", "ori_content", "tr_content"]
+        ls_rows,
+        columns=[
+            "item_org_id",
+            "xmin",
+            "ymin",
+            "xmax",
+            "ymax",
+            "ori_content",
+            "tr_content",
+        ],
+    )
+    return bboxes, img, img_url
+
+
+if __name__ == "__main__":
+    pass
+    # font = ImageFont.truetype(
+    #     font="/Users/jongbeomkim/Desktop/workspace/image_processing_server/fonts/NotoSansThai-ExtraBold.ttf",
+    #     size=round(30),
+    # )
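A small usage sketch for the decorators above (the wrapped function here is a hypothetical stand-in): `timeit` reports wall-clock time only when the call passes `debug=True`, which the wrapper forwards to the function unchanged.

```python
from data_utils import timeit

@timeit
def resize_step(size, debug=False):  # hypothetical example function
    return size * 2

resize_step(512, debug=True)   # prints e.g. "resize_step : 0.0000 sec.."
resize_step(512)               # runs silently
```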
data_utils/alarm.py ADDED
@@ -0,0 +1,68 @@
+import requests
+from slack_sdk import WebClient
+
+
+class Alarm:
+    def __init__(self, slack):
+        self.url = slack.url
+        self.username = slack.username
+        self.icon_emoji = slack.icon_emoji
+        self.channel_id = slack.channel_id
+        self.bot_token = slack.bot_token
+        self.client = WebClient(self.bot_token)
+
+    def _get_color(self, level) -> str:
+        if level == "ignore":
+            color = "#36A64F"  # Green
+        elif level == "warning":
+            color = "#F08080"  # Light coral
+        else:
+            # Fall back to the warning color so unknown levels (e.g. "major")
+            # do not leave `color` unbound.
+            color = "#F08080"
+        return color
+
+    def send(self, level, text):
+        color = self._get_color(level)
+        message = {
+            "attachments": [{"text": text, "color": color}],
+            "username": self.username,
+            "icon_emoji": self.icon_emoji,
+        }
+
+        requests.post(self.url, json=message)
+
+    def send_sdk(self, level, text):
+        color = self._get_color(level)
+
+        re = self.client.chat_postMessage(
+            channel=self.channel_id,
+            attachments=[{"fallback": "fallback", "text": text, "color": color}],
+            icon_emoji=self.icon_emoji,
+            username=self.username,
+        )
+
+        return re.data["ts"]
+
+    def post_reply_to_thread(self, level, thread_ts, text):
+        color = self._get_color(level)
+
+        self.client.chat_postMessage(
+            channel=self.channel_id,
+            attachments=[{"fallback": "fallback", "text": text, "color": color}],
+            icon_emoji=self.icon_emoji,
+            thread_ts=thread_ts,
+            username=self.username,
+        )
+
+    def post_reaction(self, thread_ts, emoji_name):
+        # emoji_name examples: "x", "μ™„λ£Œ"
+        self.client.reactions_add(
+            channel=self.channel_id, name=emoji_name, timestamp=thread_ts
+        )
+
+
+class AlertLevel:
+    IGNORE = "ignore"
+    WARNING = "warning"
+    MAJOR = "major"
+
+    @classmethod
+    def get_levels(cls):
+        return [cls.IGNORE, cls.WARNING, cls.MAJOR]
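A usage sketch with placeholder credentials; the webhook URL, token, and channel IDs below are hypothetical and would come from the Slack config object in practice.

```python
from types import SimpleNamespace
from data_utils.alarm import Alarm, AlertLevel

slack = SimpleNamespace(
    url="https://hooks.slack.com/services/XXX/YYY/ZZZ",  # placeholder webhook
    username="doc-scanner-bot",                          # placeholder
    icon_emoji=":robot_face:",
    channel_id="C0123456789",                            # placeholder
    bot_token="xoxb-...",                                # placeholder
)
alarm = Alarm(slack)
ts = alarm.send_sdk(AlertLevel.WARNING, "inference failed")
alarm.post_reply_to_thread(AlertLevel.IGNORE, ts, "retried successfully")
```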
data_utils/awss3.py ADDED
@@ -0,0 +1,50 @@
+import boto3
+import cv2
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
+AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
+
+
+class AWSS3:
+    def __init__(self):
+        self.__s3 = boto3.client(
+            "s3",
+            aws_access_key_id=AWS_ACCESS_KEY_ID,
+            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+        )
+
+    def load_image(self, bucket, path, local_path):
+        file = self.__s3.get_object(Bucket=bucket, Key=path)
+        img_content = file["Body"].read()
+
+        with open(local_path, "wb") as f:
+            f.write(img_content)
+
+        img = cv2.imread(local_path, cv2.IMREAD_COLOR)
+        img = cv2.cvtColor(src=img, code=cv2.COLOR_BGR2RGB)
+
+        return img
+
+    def save_image(self, bucket, path, local_path) -> bool:
+        with open(local_path, "rb") as f:
+            image_content = f.read()
+
+        if image_content:
+            content_type = "image/" + local_path.rsplit(".", 1)[-1].lower().replace(
+                "jpg", "jpeg"
+            )
+            self.__s3.put_object(
+                Bucket=bucket,
+                Key=path,
+                Body=image_content,
+                ACL="public-read",
+                ContentDisposition="inline",
+                ContentType=content_type,
+            )
+            return True
+        else:
+            return False
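A usage sketch assuming AWS credentials in a local `.env`; the bucket name and keys below are hypothetical. `load_image` downloads an object to disk and returns it as an RGB array; `save_image` uploads a local file back.

```python
from data_utils.awss3 import AWSS3

s3 = AWSS3()
img = s3.load_image(
    bucket="my-bucket",       # hypothetical
    path="inputs/scan.jpg",   # hypothetical
    local_path="/tmp/scan.jpg",
)
ok = s3.save_image(bucket="my-bucket", path="outputs/scan.jpg", local_path="/tmp/scan.jpg")
```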
data_utils/box_utils.py ADDED
@@ -0,0 +1,483 @@
+import cv2
+import numpy as np
+import pandas as pd
+import pkg_resources as pkg
+import torch
+import math
+from typing import Tuple
+from data_utils.image_utils import _get_width_and_height
+
+
+def points_to_xyxy(coords: np.ndarray) -> list:
+    x_coords = [coord[0] for coord in coords]
+    y_coords = [coord[1] for coord in coords]
+    x1 = min(x_coords)
+    y1 = min(y_coords)
+    x2 = max(x_coords)
+    y2 = max(y_coords)
+    return [x1, y1, x2, y2]
+
+
+def xyxy2xywh(x):
+    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
+    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
+    y[..., 2] = x[..., 2] - x[..., 0]  # width
+    y[..., 3] = x[..., 3] - x[..., 1]  # height
+    return y
+
+
+def xywh2xyxy(x):
+    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
+    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
+    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
+    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
+    return y
+
+
+def is_abox_in_bbox(abox_coords, bbox_coords):
+    # Checks whether abox lies entirely inside bbox. Coordinate format: (x1, y1, x2, y2).
+    if (
+        bbox_coords[0] <= abox_coords[0]
+        and bbox_coords[1] <= abox_coords[1]
+        and abox_coords[2] <= bbox_coords[2]
+        and abox_coords[3] <= bbox_coords[3]
+    ):
+        return True
+    else:
+        return False
+
+
+def calculate_aspect_ratio(box):
+    width = box[2] - box[0]
+    height = box[3] - box[1]
+    aspect_ratio = width / (height + 1e-8)
+    return aspect_ratio
+
+
+def get_box_shape(box, threshold=0.1):
+    """
+    Check if a box is close to a square.
+    - threshold (float): The threshold for considering the box as close to a square.
+      Default is 0.1.
+    Returns:
+    - str: "square" or "horizontal" or "vertical"
+    """
+    aspect_ratio = calculate_aspect_ratio(box)
+    if abs(1 - aspect_ratio) < threshold:
+        return "square"
+    elif aspect_ratio > 1:
+        return "horizontal"
+    else:
+        return "vertical"
+
+
+def calculate_aspect_ratio_loss(predicted_box, gt_box):
+    """Returns the difference between the aspect ratios of predicted_box and
+    gt_box, in the range 0-1; larger means more different."""
+    gt_aspect_ratio = calculate_aspect_ratio(gt_box)
+    pred_aspect_ratio = calculate_aspect_ratio(predicted_box)
+
+    ratio_difference = abs(gt_aspect_ratio - pred_aspect_ratio)
+
+    loss = 2 * math.atan(ratio_difference) / math.pi
+
+    return loss
+
+
+def clip_boxes(boxes, shape):
+    # Clip boxes (xyxy) to image shape (height, width)
+    if isinstance(boxes, torch.Tensor):  # faster individually
+        boxes[..., 0].clamp_(0, shape[1])  # x1
+        boxes[..., 1].clamp_(0, shape[0])  # y1
+        boxes[..., 2].clamp_(0, shape[1])  # x2
+        boxes[..., 3].clamp_(0, shape[0])  # y2
+    else:  # np.array (faster grouped)
+        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
+        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
+
+
+def is_box_overlap(box1, box2):
+    # Box overlap checking logic
+    if box1[0] > box2[2] or box1[2] < box2[0] or box1[1] > box2[3] or box1[3] < box2[1]:
+        return False
+    else:
+        return True
+
+
+def intersection_area(box1, box2):
+    """
+    Calculate the intersection area between two bounding boxes.
+
+    Parameters:
+    - box1, box2: Tuple or list representing the bounding box in the format (x1, y1, x2, y2).
+
+    Returns:
+    - area: Intersection area between the two boxes.
+    """
+    x1_box1, y1_box1, x2_box1, y2_box1 = box1
+    x1_box2, y1_box2, x2_box2, y2_box2 = box2
+
+    # Calculate intersection coordinates
+    x_intersection = max(x1_box1, x1_box2)
+    y_intersection = max(y1_box1, y1_box2)
+    x_intersection_end = min(x2_box1, x2_box2)
+    y_intersection_end = min(y2_box1, y2_box2)
+
+    # Calculate intersection area
+    width_intersection = max(0, x_intersection_end - x_intersection)
+    height_intersection = max(0, y_intersection_end - y_intersection)
+    area = width_intersection * height_intersection
+
+    return area
+
+
+def bbox_iou(box1, box2, GIoU=False, DIoU=False, CIoU=False, CIoU2=False, eps=1e-7):
+    """
+    Calculate IoUs (GIoU, DIoU, CIoU, CIoU2).
+
+    Parameters:
+    - box1, box2: Tuple or list representing the bounding box in the format (x1, y1, x2, y2).
+
+    Returns:
+    - IoU or GIoU or DIoU or CIoU or CIoU2
+    """
+    # Get the coordinates of bounding boxes
+    b1_x1, b1_y1, b1_x2, b1_y2 = box1
+    b2_x1, b2_y1, b2_x2, b2_y2 = box2
+    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+
+    # Intersection area
+    inter = intersection_area(box1, box2)
+
+    # Union Area
+    union = w1 * h1 + w2 * h2 - inter + eps
+
+    iou = inter / union
+
+    if CIoU or DIoU or GIoU or CIoU2:
+        cw = max(b1_x2, b2_x2) - min(
+            b1_x1, b2_x1
+        )  # convex (smallest enclosing box) width
+        ch = max(b1_y2, b2_y2) - min(b1_y1, b2_y1)  # convex height
+        c_area = cw * ch + eps  # convex area
+        giou_penalty = (c_area - union) / c_area
+        if GIoU:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
+            return round(iou - giou_penalty, 4)  # GIoU
+        elif (
+            DIoU or CIoU or CIoU2
+        ):  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+            rho2 = (
+                (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
+                + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
+            ) / 4  # center dist ** 2
+            c2 = cw**2 + ch**2 + eps  # convex diagonal squared
+            diou_penalty = rho2 / c2
+            if DIoU:
+                return round(iou - diou_penalty, 4)  # DIoU
+            if CIoU or CIoU2:
+                v = (4 / math.pi**2) * (
+                    (np.arctan(w2 / h2) - np.arctan(w1 / h1)) ** 2
+                )
+                alpha = v / (v - iou + (1 + eps))
+                ciou_penalty = diou_penalty + alpha * v
+                if CIoU2:
+                    ciou2_penalty = giou_penalty + diou_penalty + alpha * v
+                    return round(iou - ciou2_penalty, 4)  # CIoU2
+                return round(iou - ciou_penalty, 4)  # CIoU
+
+    return round(iou, 4)  # IoU
+
+
+def rotate_around_point(x, y, pivot_x, pivot_y, degrees) -> Tuple[int, int]:
+    """Rotate the point (x, y) counterclockwise around the pivot
+    (pivot_x, pivot_y). Returns (new_x, new_y)."""
+
+    # Convert the angle to radians
+    angle_radians = np.radians(degrees)
+
+    # Apply the rotation transform
+    x_new = (
+        pivot_x
+        + np.cos(angle_radians) * (x - pivot_x)
+        - np.sin(angle_radians) * (y - pivot_y)
+    )
+    y_new = (
+        pivot_y
+        + np.sin(angle_radians) * (x - pivot_x)
+        + np.cos(angle_radians) * (y - pivot_y)
+    )
+
+    return int(x_new), int(y_new)
+
+
+def rotate_box_coordinates_on_pivot(x1, y1, x2, y2, degrees, pivot_x, pivot_y):
+    """Rotate the box (x1, y1, x2, y2) clockwise around the pivot (pivot_x, pivot_y)."""
+    radians = np.radians(degrees)
+    rotation_matrix = np.array(
+        [[np.cos(radians), -np.sin(radians)], [np.sin(radians), np.cos(radians)]]
+    )
+
+    # Express the box corners relative to the pivot, then rotate them
+    box_coordinates = np.array(
+        [
+            [x1 - pivot_x, y1 - pivot_y],
+            [x2 - pivot_x, y1 - pivot_y],
+            [x2 - pivot_x, y2 - pivot_y],
+            [x1 - pivot_x, y2 - pivot_y],
+        ]
+    )
+
+    rotated_box_coordinates = np.dot(box_coordinates, rotation_matrix.T)
+
+    # Add the pivot back to return to the original (x, y) coordinate frame
+    rotated_box_coordinates += np.array([pivot_x, pivot_y])
+
+    # Return the axis-aligned box around the rotated corners
+    new_x1, new_y1 = rotated_box_coordinates.min(axis=0)
+    new_x2, new_y2 = rotated_box_coordinates.max(axis=0)
+
+    return int(new_x1), int(new_y1), int(new_x2), int(new_y2)
+
+
+def bbox_iou_torch(
+    box1, box2, xywh=False, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
+):
+    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
+
+    # Get the coordinates of bounding boxes
+    if xywh:  # transform from xywh to xyxy
+        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
+        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
+        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
+        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
+    else:  # x1, y1, x2, y2 = box1
+        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
+        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
+        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
+        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)
+
+    # Intersection area
+    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * (
+        b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
+    ).clamp(0)
+
+    # Union Area
+    union = w1 * h1 + w2 * h2 - inter + eps
+
+    # IoU
+    iou = inter / union
+    if CIoU or DIoU or GIoU:
+        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(
+            b2_x1
+        )  # convex (smallest enclosing box) width
+        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
+        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+            c2 = cw**2 + ch**2 + eps  # convex diagonal squared
+            rho2 = (
+                (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
+                + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
+            ) / 4  # center dist ** 2
+            if (
+                CIoU
+            ):  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
+                v = (4 / math.pi**2) * (
+                    torch.atan(w2 / h2) - torch.atan(w1 / h1)
+                ).pow(2)
+                with torch.no_grad():
+                    alpha = v / (v - iou + (1 + eps))
+                return iou - (rho2 / c2 + v * alpha)  # CIoU
+            return iou - rho2 / c2  # DIoU
+        c_area = cw * ch + eps  # convex area
+        return (
+            iou - (c_area - union) / c_area
+        )  # GIoU https://arxiv.org/pdf/1902.09630.pdf
+    return iou  # IoU
+
+
+def generate_random_box(width_range, height_range):
+    """
+    Generate random bounding box coordinates (x1, y1, x2, y2) with random width and height.
+
+    Parameters:
+    - width_range: Tuple representing the range of width values (min_width, max_width).
+    - height_range: Tuple representing the range of height values (min_height, max_height).
+
+    Returns:
+    - box: Tuple representing the bounding box in the format (x1, y1, x2, y2).
+    """
+    min_width, max_width = width_range
+    min_height, max_height = height_range
+
+    width = np.random.randint(min_width, max_width)
+    height = np.random.randint(min_height, max_height)
+
+    # Clamp the upper bound so randint stays valid even when the box is
+    # larger than the assumed 100-pixel canvas.
+    x1 = np.random.randint(0, max(1, 100 - width))
+    y1 = np.random.randint(0, max(1, 100 - height))
+    x2 = x1 + width
+    y2 = y1 + height
+
+    return x1, y1, x2, y2
+
+
+def mask_to_bboxes(mask, margin_rate=2, pixel_thresh=300) -> pd.DataFrame:
+    nlabels, segmap, stats, centroids = cv2.connectedComponentsWithStats(
+        image=mask, connectivity=4
+    )
+    bboxes = pd.DataFrame(
+        stats[1:, :], columns=["bbox_x1", "bbox_y1", "width", "height", "pixel_count"]
+    )
+    img_width, img_height = _get_width_and_height(mask)
+
+    bboxes = bboxes[bboxes["pixel_count"].ge(pixel_thresh)]
+
+    bboxes["bbox_x2"] = bboxes["bbox_x1"] + bboxes["width"]
+    bboxes["bbox_y2"] = bboxes["bbox_y1"] + bboxes["height"]
+
+    bboxes["margin"] = bboxes.apply(
+        lambda x: int(
+            math.sqrt(
+                x["pixel_count"]
+                * min(x["width"], x["height"])
+                / (x["width"] * x["height"])
+            )
+            * margin_rate
+        ),
+        axis=1,
+    )
+    bboxes["bbox_x1"] = bboxes.apply(
+        lambda x: max(0, x["bbox_x1"] - x["margin"]), axis=1
+    )
+    bboxes["bbox_y1"] = bboxes.apply(
+        lambda x: max(0, x["bbox_y1"] - x["margin"]), axis=1
+    )
+    bboxes["bbox_x2"] = bboxes.apply(
+        lambda x: min(img_width, x["bbox_x2"] + x["margin"]), axis=1
+    )
+    bboxes["bbox_y2"] = bboxes.apply(
+        lambda x: min(img_height, x["bbox_y2"] + x["margin"]), axis=1
+    )
+    bboxes = bboxes[["bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2"]]
+    if img_width >= img_height:
+        bboxes.sort_values(by=["bbox_x1", "bbox_y1"], inplace=True)
+    else:
+        bboxes.sort_values(by=["bbox_y1", "bbox_x1"], inplace=True)
+
+    return bboxes
+
+
+def bbox_to_mask(bboxes: list, mask_size):
+    """
+    Creates a mask image based on bounding box coordinates.
+
+    Args:
+    - bboxes: list of (x_min, y_min, x_max, y_max) bounding box coordinates.
+    - mask_size: Tuple (height, width) representing the size of the mask image to be created.
+
+    Returns:
+    - Mask image with the specified bounding box areas filled with white.
+    """
+    # Initialize a black mask image with the specified size
+    mask = np.zeros(mask_size, dtype=np.uint8)
+
+    for bbox in bboxes:
+        # Extract bounding box coordinates
+        x_min, y_min, x_max, y_max = bbox
+
+        # Ensure bbox coordinates are within mask bounds
+        x_min = max(0, x_min)
+        y_min = max(0, y_min)
+        x_max = min(mask_size[1], x_max)
+        y_max = min(mask_size[0], y_max)
+
+        # Fill the bounding box area with white in the mask image
+        mask[y_min:y_max, x_min:x_max] = 255
+
+    return mask
+
+
+def move_box_a_to_center_of_box_b(A, B):
+    # A and B are (l, t, r, b) coordinates
+    lA, tA, rA, bA = A
+    lB, tB, rB, bB = B
+
+    # Width and height of box A
+    width_A = rA - lA
+    height_A = bA - tA
+
+    # Center of box B
+    center_x_B = (lB + rB) / 2
+    center_y_B = (tB + bB) / 2
+
+    # New coordinates for box A (its center moved onto B's center)
+    new_lA = center_x_B - width_A / 2
+    new_tA = center_y_B - height_A / 2
+    new_rA = center_x_B + width_A / 2
+    new_bA = center_y_B + height_A / 2
+
+    # Return the new coordinates of box A
+    return (new_lA, new_tA, new_rA, new_bA)
+
+
+def scale_bboxes(bboxes, max_x, max_y, x_scale_factor=1.2, y_scale_factor=1.05):
+    # Compute each box's center, width, and height from the original coordinates
+    bboxes["cx"] = (bboxes["bbox_x1"] + bboxes["bbox_x2"]) / 2
+    bboxes["cy"] = (bboxes["bbox_y1"] + bboxes["bbox_y2"]) / 2
+    bboxes["width"] = bboxes["bbox_x2"] - bboxes["bbox_x1"]
+    bboxes["height"] = bboxes["bbox_y2"] - bboxes["bbox_y1"]
+
+    # Scale each box by the given factors
+    bboxes["new_width"] = bboxes["width"] * x_scale_factor
+    bboxes["new_height"] = bboxes["height"] * y_scale_factor
+
+    # Compute the new coordinates
+    bboxes["new_x1"] = bboxes["cx"] - bboxes["new_width"] / 2
+    bboxes["new_y1"] = bboxes["cy"] - bboxes["new_height"] / 2
+    bboxes["new_x2"] = bboxes["cx"] + bboxes["new_width"] / 2
+    bboxes["new_y2"] = bboxes["cy"] + bboxes["new_height"] / 2
+
+    # Clip the boxes to the image bounds
+    bboxes["new_x1"] = bboxes["new_x1"].clip(lower=0).astype(int)
+    bboxes["new_y1"] = bboxes["new_y1"].clip(lower=0).astype(int)
+    bboxes["new_x2"] = bboxes["new_x2"].clip(upper=max_x).astype(int)
+    bboxes["new_y2"] = bboxes["new_y2"].clip(upper=max_y).astype(int)
+
+    # Build the result DataFrame
+    new_bboxes = bboxes[
+        ["ori_content", "new_x1", "new_y1", "new_x2", "new_y2", "predicted_lang"]
+    ].copy()
+    new_bboxes.columns = [
+        "ori_content",
+        "bbox_x1",
+        "bbox_y1",
+        "bbox_x2",
+        "bbox_y2",
+        "predicted_lang",
+    ]
+
+    return new_bboxes
+
+
+if __name__ == "__main__":
+    w_range = (100, 200)
+    h_range = (100, 200)
+
+    box1 = generate_random_box(w_range, h_range)
+    box2 = generate_random_box(w_range, h_range)
+
+    print(f"box1 coors : {box1}")
+    print(f"box2 coors : {box2}")
+
+    print(f"intersection area : {intersection_area(box1, box2)}")
+    iou = bbox_iou(box1, box2)
+    giou = bbox_iou(box1, box2, GIoU=True)
+    diou = bbox_iou(box1, box2, DIoU=True)
+    ciou = bbox_iou(box1, box2, CIoU=True)
+    print(iou, giou, diou, ciou)
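A quick sanity check for the converters above: `xyxy2xywh` and `xywh2xyxy` are mutual inverses on float arrays, so a round trip is an identity (pure numpy, no model weights needed).

```python
import numpy as np
from data_utils.box_utils import xyxy2xywh, xywh2xyxy

boxes = np.array([[10.0, 20.0, 50.0, 80.0]])
assert np.allclose(xywh2xyxy(xyxy2xywh(boxes)), boxes)
```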
data_utils/color_utils.py ADDED
@@ -0,0 +1,490 @@
+import sys
+import numpy as np
+import cv2
+import convcolors
+import matplotlib.pyplot as plt
+from colormap import rgb2hex
+from matplotlib.offsetbox import OffsetImage, AnnotationBbox
+import extcolors
+from skimage.color import deltaE_cie76
+import math
+
+from data_utils import timeit
+
+from data_utils.image_utils import (
+    _to_pil,
+    _get_pseudo_image,
+    _mask_image,
+    _get_width_and_height,
+    _resize_image,
+    _figure_to_array,
+    load_image,
+    _to_2d,
+)
+
+np.set_printoptions(precision=3, edgeitems=20, linewidth=sys.maxsize, suppress=False)
+
+
+def _to_tuple(color):
+    if isinstance(color, tuple):
+        return color
+    elif isinstance(color, str):
+        if color[:3] == "rgb":
+            return eval(color.replace("rgb", ""))
+        elif color[:3] == "lab":
+            return eval(color.replace("lab", ""))
+    elif isinstance(color, np.ndarray):
+        return tuple(color)
+
+
+def _to_str(color, color_space):
+    if isinstance(color, str):
+        return color
+    elif isinstance(color, tuple):
+        if color_space == "rgb":
+            return f"""rgb{color}"""
+        elif color_space == "lab":
+            return f"""lab{color}"""
+
+
+def _to_rgb(color):
+    if isinstance(color, str):
+        if color[:3] == "rgb":
+            color = eval(color.replace("rgb", ""))
+            return color
+        elif color[:3] == "lab":
+            color = eval(color.replace("lab", ""))
+            color = convcolors.lab_to_rgb(color)
+            color = tuple([round(i) for i in color])
+            return _to_str(color, color_space="rgb")
+
+
+def _to_lab(color):
+    if isinstance(color, str):
+        if color[:3] == "rgb":
+            color = eval(color.replace("rgb", ""))
+            color = convcolors.rgb_to_lab(color)
+            color = tuple([round(i) for i in color])
+            return _to_str(color, color_space="lab")
+        elif color[:3] == "lab":
+            return color
+
+
+def _extract_colors(img, mask=None, invert=False, tolerance=10, limit=4):
+    # img: (H, W, 3), mask: (H, W)
+    if mask is None or not np.any(mask):
+        pseudo_outer = img
+    else:
+        pseudo_outer = _get_pseudo_image(img=img, mask=mask, invert=invert)
+
+    colors = extcolors.extract_from_image(
+        img=_to_pil(pseudo_outer), tolerance=tolerance, limit=limit
+    )[0]
+    sum_freqs = sum([i[1] for i in colors])
+
+    return [
+        {
+            "rgb": rgb,
+            "hex_code": rgb2hex(*rgb),
+            "percentage": round(freq / sum_freqs, 3),
+        }
+        for rgb, freq in colors
+    ]
+
+
+def get_palette(colors, img=None, mask=None, invert=False, index=None, zoom=4):
+    rgbs = [i["rgb"] for i in colors]
+    pers = [i["percentage"] for i in colors]
+    hex_codes = [i["hex_code"] for i in colors]
+
+    labels = [
+        f"""{str(rgb)}\n{str(round(per * 100, 1))}%""" for rgb, per in zip(rgbs, pers)
+    ]
+    explode = [0] * len(rgbs)
+    if index is not None:
+        explode[index] = 0.05
+
+    fig, ax = plt.subplots(figsize=(30, 20), dpi=15)
+    wedges, _ = ax.pie(
+        x=pers,
+        labels=labels,
+        labeldistance=1.07,
+        colors=hex_codes,
+        textprops={"fontsize": 50, "color": "black"},
+        wedgeprops={"edgecolor": "black", "linewidth": 7},
+        startangle=90,
+        radius=1,
+        counterclock=False,
+        explode=explode,
+    )
+    plt.setp(wedges, width=0.26)
+
+    ax.set_aspect("equal")
+
+    if img is not None:
+        if mask is not None:
+            img = _mask_image(img=img, mask=mask, invert=invert)
+        w, h = _get_width_and_height(img)
+        if w >= h:
+            resized_img = _resize_image(img=img, w=400, h=int(400 * h / w))
+        else:
+            resized_img = _resize_image(img=img, w=int(400 * w / h), h=400)
+        offset_img = OffsetImage(resized_img.astype("float32") / 255, zoom=zoom)
+        annot_box = AnnotationBbox(offsetbox=offset_img, xy=(0, 0))
+        ax.add_artist(annot_box)
+    fig.tight_layout()
+    palette = _figure_to_array(fig)
+
+    plt.close()
+    return palette
+
+
+def _get_complementary_color(color):
+    if isinstance(color, str):
+        color = _to_tuple(color)
+        return f"""rgb{tuple([255 - rgb for rgb in color])}"""
+    if isinstance(color, tuple):
+        return tuple([255 - rgb for rgb in color])
+
+
+def _linearize(x):
+    if isinstance(x, np.ndarray):
+        x = x.astype("float64")
+        x /= 255
+        return np.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
+    elif isinstance(x, (int, np.uint8)):
+        x /= 255
+        if x <= 0.04045:
+            return x / 12.92
+        else:
+            return ((x + 0.055) / 1.055) ** 2.4
+
+
+def _get_relative_luminance(x):
+    if isinstance(x, str):
+        return _get_relative_luminance(_to_tuple(_to_rgb(x)))
+    elif isinstance(x, np.ndarray):
+        x = _linearize(x)
+        return np.round(0.2126 * x[..., 0] + 0.7152 * x[..., 1] + 0.0722 * x[..., 2], 3)
+    elif isinstance(x, tuple):
+        assert len(x) == 3, "If the argument `x` is a tuple, it should have 3 elements."
+
+        return round(
+            0.2126 * _linearize(x[0])
+            + 0.7152 * _linearize(x[1])
+            + 0.0722 * _linearize(x[2]),
+            3,
+        )
+
+
+def rgb_to_lab(rgb: tuple):
+    rgb = np.uint8([[list(rgb)]])
+    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
+    return tuple(lab[0][0])
+
+
+def lab_to_rgb(lab):
+    lab = np.uint8([[list(lab)]])
+    rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
+    return tuple(rgb[0][0])
+
+
+def get_contrast(x, y):
+    l1 = _get_relative_luminance(x)
+    l2 = _get_relative_luminance(y)
+
+    if isinstance(l1, float) and isinstance(l2, float):
+        return (
+            round((l1 + 0.05) / (l2 + 0.05), 1)
+            if l1 > l2
+            else round((l2 + 0.05) / (l1 + 0.05), 1)
+        )
+    elif isinstance(l1, np.ndarray) or isinstance(l2, np.ndarray):
+        return np.where(
+            l1 > l2,
+            np.round((l1 + 0.05) / (l2 + 0.05), 1),
+            np.round((l2 + 0.05) / (l1 + 0.05), 1),
+        )
+
+
+def adjust_luminance_for_contrast(color1, color2, th=4.5):
+    """Adjust color2 (the text color) so that its contrast against color1
+    (the background color) reaches at least `th`, changing only its lightness.
+    Working in Lab keeps the hue shift of color2 minimal: if color1 is dark,
+    color2 is brightened; if color1 is bright, color2 is darkened.
+
+    Args:
+        color1 (tuple): Reference color; it is never modified and serves as
+            the baseline for the contrast measurement, e.g. the background color.
+        color2 (tuple): Color to adjust, e.g. the text color.
+        th (float, optional): Contrast threshold the pair should reach.
+            Defaults to 4.5.
+
+    Returns:
+        tuple: new color2
+    """
+    initial_cont = get_contrast(color1, color2)
+    if initial_cont >= th:
+        return color2
+    lab1 = rgb_to_lab(color1)
+    lab2 = rgb_to_lab(color2)
+
+    plus_cont, minus_cont = initial_cont, initial_cont
+    plus_l, minus_l = lab2, lab2
+    max_iterations = 100
+    plus_iteration = 0
+    minus_iteration = 0
+    step = 3
+
+    if lab1[0] >= 127:
+        while minus_iteration < max_iterations:  # bright background: darken color2
+            minus_l = (max(minus_l[0] - step, 0), minus_l[1], minus_l[2])
+            minus_cont = get_contrast(lab_to_rgb(lab1), lab_to_rgb(minus_l))
+
+            if minus_cont >= th:
+                return lab_to_rgb(minus_l)
+            minus_iteration += 1
+    else:
+        while plus_iteration < max_iterations:  # dark background: brighten color2
+            plus_l = (min(plus_l[0] + step, 255), plus_l[1], plus_l[2])
+            plus_cont = get_contrast(lab_to_rgb(lab1), lab_to_rgb(plus_l))
+
+            if plus_cont >= th:
+                return lab_to_rgb(plus_l)
+            plus_iteration += 1
+    return color2
+
+
+def get_readability(color, bg, contrast_thresh=2.5):
+    contrast = get_contrast(_to_tuple(color), bg)
+    below_thresh = contrast[contrast < contrast_thresh]
+    if below_thresh.size == 0:
+        return 21
+    else:
+        return below_thresh.mean()
+
+
+def _blend_two_colors(color1, color2, ratio=0.5):
+    blended = np.array(_to_tuple(_to_lab(color1))) * ratio + np.array(
+        _to_tuple(_to_lab(color2))
+    ) * (1 - ratio)
+    blended = _to_rgb(_to_str(_to_tuple(blended), color_space="lab"))
+    return blended
+
+
+def get_colorfulness(img):
+    try:
+        r, g, b = cv2.split(img.astype("float"))
+        rg = np.absolute(r - g)
+        yb = np.absolute((r + g) / 2 - b)
+        rg_mean, rg_std = np.mean(rg), np.std(rg)
+        yb_mean, yb_std = np.mean(yb), np.std(yb)
+        std_root = np.sqrt((rg_std**2) + (yb_std**2))
+        mean_root = np.sqrt((rg_mean**2) + (yb_mean**2))
+        colorfulness = std_root + (0.3 * mean_root)
+    except ValueError:
+        colorfulness = 0
+    return colorfulness
+
+
+def get_colorfulness_by_extracting_colors(img, limit=20):
+    colors = _extract_colors(img=img, tolerance=10, limit=limit)
+    pers = [i["percentage"] for i in colors]
+    colorfulness = (np.array(pers).cumsum() < 0.98).sum()
+    return colorfulness
+
+
+def _colors_to_pseudo_image(colors):
+    pseudo_img = np.array(colors, dtype="uint8")[None, ...]
+    return pseudo_img
+
+
+def _pick_most_colors(colors, tolerance):
+    pseudo_img = _colors_to_pseudo_image(colors)
+    most_colors = extcolors.extract_from_image(
+        img=_to_pil(pseudo_img), tolerance=tolerance, limit=len(colors)
+    )[0]
+    most_colors = [i[0] for i in most_colors]
+    return most_colors
+
+
+def _get_euclidean_distance(color1, color2):
+    return deltaE_cie76(
+        np.array(convcolors.rgb_to_lab(color1))[None, None, ...],
+        np.array(convcolors.rgb_to_lab(color2))[None, None, ...],
+    )[0][0]
+
+
+def is_similar_black_or_gray(color) -> str:
+    # color = (R, G, B)
+    black_distance = _get_euclidean_distance(color, (0, 0, 0))
+    gray_distance = _get_euclidean_distance(color, (128, 128, 128))
+    if black_distance < gray_distance:
+        return "black"
+    else:
+        return "gray"
+
+
+def is_similar_white_or_gray(color) -> str:
+    # color = (R, G, B)
+    gray_distance = _get_euclidean_distance(color, (128, 128, 128))
+    white_distance = _get_euclidean_distance(color, (255, 255, 255))
+    if white_distance < gray_distance:
+        return "white"
+    else:
+        return "gray"
+
+
+def is_similar_white_or_black(color) -> str:
+    # color = (R, G, B)
+    black_distance = _get_euclidean_distance(color, (0, 0, 0))
+    white_distance = _get_euclidean_distance(color, (255, 255, 255))
+    if white_distance < black_distance:
+        return "white"
+    else:
+        return "black"
+
+
+def view_hist(img):
+    color = ("r", "g", "b")
+    for i, col in enumerate(color):
+        hist = cv2.calcHist([img], [i], None, [256], [0, 256])
+        plt.plot(hist, color=col)
+    plt.savefig("calc_hist.png")
+
+
+def normalize_image(img):
+    img_norm = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
+    return img_norm
+
+
+def equalize_hist(img):
+    gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+    # hist = cv2.calcHist([gray_img], [0], None, [256], [0, 256])
+    # ycrcb_img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
+    # ycrcb_img[:, :, 0] = cv2.equalizeHist(ycrcb_img[:, :, 0])
+
+    # equalized_img = cv2.cvtColor(ycrcb_img, cv2.COLOR_YCrCb2RGB)
+
+    # Contrast-limited adaptive histogram equalization (CLAHE)
+    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+    equalized_img = clahe.apply(gray_img)
+
+    return equalized_img
+
+
+def is_gray(color, threshold=30):
+    r, g, b = map(int, color)  # color is np.uint8; cast to int to prevent overflow
+    if abs(r - g) < threshold and abs(r - b) < threshold and abs(g - b) < threshold:
+        return True
+    return False
+
+
+def merge_similar_colors(colors, tolerance=10):
+    most_colors = _pick_most_colors(colors, tolerance=tolerance)
+
+    new_colors = list()
+    for color in colors:
+        minim = math.inf
+        for most_color in most_colors:
+            dist = _get_euclidean_distance(color, most_color)
+            if dist < minim:
+                picked = most_color
+                minim = dist
+        new_colors.append(picked)
+    return new_colors
+
+
+def merge_colors(colors, tolerance=10):
+    temp = [eval(i[3:]) for i in colors]
+    # print(len(set(temp)))
+    pseudo_img = np.array([temp], dtype="uint8")
+    # _to_pil(pseudo_img).show()
+
+    extracted_colors = _extract_colors(
+        pseudo_img,
+        mask=None,
+        invert=False,
+        tolerance=tolerance,
+        limit=len(colors) // 2,
+    )
+    # print(len(extracted_colors))
+
+    new_colors = list()
+    for i in temp:
+        min_dist = math.inf
+        for c in extracted_colors:
+            dist = _get_euclidean_distance(c["rgb"], i)
+            if dist < min_dist:
+                min_dist = dist
+                best = c["rgb"]
+        new_colors.append(best)
+    return [_to_str(i, color_space="rgb") for i in new_colors]
+
+
+def get_most_color(img, mask=None, min_count=10, get_full=False):
+    # img: (H, W, 3) in 0-255, mask: (H, W, 3) of 0 or 255
+    if mask is None:
+        img_pixels = img.reshape(-1, 3)
+    else:
+        img_pixels = img[_to_2d(mask) == 255]
+
+    colors, colors_counts = np.unique(img_pixels, axis=0, return_counts=True)
+
+    if colors_counts.max() <= min_count:
+        # No single color is frequent enough; fall back to the count-weighted mean color.
+        most_color = tuple(
+            (
+                (colors_counts[:, np.newaxis] * colors).sum(axis=0)
+                / colors_counts.sum()
+            ).astype(np.uint8)
+        )
+    else:
+        most_color = tuple(colors[np.argmax(colors_counts)])
+
+    if get_full:
+        return colors, colors_counts
+
+    return most_color, colors_counts.max()  # (R, G, B), color count
+
+
+if __name__ == "__main__":
+    img = load_image(
+        "/Users/jongbeomkim/Desktop/Screen Shot 2023-11-07 at 10.41.23 AM.png"
+    )
+    contrast = get_contrast("rgb(10, 100, 100)", img)
+    below_thresh = contrast[contrast < 0]
+    if below_thresh.size == 0:
+        21
+    else:
+        below_thresh.mean()
+
+    colors = [
+        "rgb(0, 0, 0)",
+        "rgb(95, 95, 95)",
+        "rgb(184, 137, 91)",
+        "rgb(0, 0, 0)",
+        "rgb(0, 0, 0)",
+        "rgb(93, 93, 93)",
+        "rgb(182, 142, 93)",
+        "rgb(0, 0, 0)",
+        "rgb(0, 0, 0)",
+        "rgb(99, 99, 99)",
+        "rgb(0, 0, 0)",
+        "rgb(0, 0, 0)",
+        "rgb(90, 90, 90)",
+        "rgb(184, 141, 90)",
+        "rgb(0, 0, 0)",
+        "rgb(0, 0, 0)",
+        "rgb(93, 93, 93)",
+        "rgb(14, 14, 14)",
+        "rgb(17, 17, 17)",
+        "rgb(101, 101, 101)",
+        "rgb(97, 97, 97)",
+        "rgb(193, 193, 193)",
+        "rgb(122, 122, 122)",
+        "rgb(122, 122, 122)",
+        "rgb(0, 0, 0)",
+        "rgb(118, 118, 118)",
+    ]
+    new_colors = merge_colors(colors, tolerance=30)
+    new_colors
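A small sketch of the WCAG-style contrast helpers above: black on white yields the maximum ratio of 21, and `adjust_luminance_for_contrast` nudges only the Lab lightness of the second color until the pair reaches the threshold (the color values here are arbitrary examples).

```python
from data_utils.color_utils import get_contrast, adjust_luminance_for_contrast

print(get_contrast((0, 0, 0), (255, 255, 255)))  # 21.0
# Brighten near-black text until it reaches >= 4.5 contrast on a dark background.
new_text = adjust_luminance_for_contrast((20, 20, 20), (40, 40, 40), th=4.5)
```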
data_utils/conf.py ADDED
@@ -0,0 +1,170 @@
+from dataclasses import dataclass
+from typing import Optional
+from dataclasses_json import dataclass_json
+import json
+from pathlib import Path
+
+
+@dataclass
+class Slack:
+    url: str
+    channel: str
+    username: str
+    icon_emoji: str
+    channel_id: Optional[str] = None
+    bot_token: Optional[str] = None
+
+
+class Config(object):
+    def __init__(self, env="dev"):
+        if env == "dev":
+            config_path = Path(__file__).parent.parent / "config/config.dev.json"
+        else:
+            config_path = Path(__file__).parent.parent / "config/config.prod.json"
+        config = self._read_config(config_path)
+        self.awss3_bucket, self.awss3_path = self.__parse_awss3_storage(
+            config.get("storage", {}).get("awss3", {})
+        )
+
+        (
+            self.requested_img_path,
+            self.text_removed_img_path,
+        ) = self.__parse_local_storage(config.get("storage", {}).get("local", {}))
+
+        self.mq_url = self._parse_amqp_server(config["mq_server"])
+
+        (
+            self.req_queue,
+            self.req_pattern,
+            self.resp_queue,
+            self.success_resp_pattern,
+            self.failure_resp_pattern,
+        ) = self._parse_queue(config["queue"])
+
+        self.slack = self._parse_slack(config["alarm"]["slack"])
+
+    def _read_config(self, config_path) -> dict:
+        with open(config_path, mode="r") as f:
+            config = json.load(f)
+        return config
+
+    def _parse_amqp_server(self, amqp_server) -> str:
+        username = amqp_server["username"]
+        password = amqp_server["password"]
+        url = amqp_server["url"]
+        port = amqp_server["port"]
+        amqp_url = f"amqps://{username}:{password}@{url}:{port}"
+        return amqp_url
+
+    def _parse_queue(self, queue) -> tuple:
+        req_queue = queue["request_name"]
+        req_pattern = queue["request_pattern"]
+        resp_queue = queue.get("response_name")
+        success_resp_pattern = queue["success_response_pattern"]
+        failure_resp_pattern = queue["failure_response_pattern"]
+        return (
+            req_queue,
+            req_pattern,
+            resp_queue,
+            success_resp_pattern,
+            failure_resp_pattern,
+        )
+
+    def __parse_awss3_storage(self, awss3_storage) -> tuple:
+        awss3_bucket = awss3_storage.get("default_bucket")
+        awss3_path = awss3_storage.get("default_path")
+
+        return awss3_bucket, awss3_path
+
+    def __parse_local_storage(self, local_storage) -> tuple:
+        requested_img_path = local_storage.get("requested")
+        text_removed_img_path = local_storage.get("text_removed")
+
+        return requested_img_path, text_removed_img_path
+
+    def _parse_slack(self, slack) -> Slack:
+        url = slack["url"]
+        channel = slack["channel"]
+        username = slack["username"]
+        icon_emoji = slack["icon_emoji"]
+        channel_id = slack.get("channel_id", None)
+        bot_token = slack.get("bot_token", None)
+
+        return Slack(url, channel, username, icon_emoji, channel_id, bot_token)
+
+
+class ImageTrConfig(object):
+    def __init__(self, env="dev"):
+        if env == "dev":
+            config_path = Path(__file__).parent.parent / "config/config.dev.json"
+        else:
+            config_path = Path(__file__).parent.parent / "config/config.prod.json"
+        config = self._read_config(config_path)
+
+        (
+            self.awss3_bucket,
+            self.awss3_inpainting_path,
+            self.awss3_translation_path,
+        ) = self.__parse_awss3_storage(config.get("storage", {}).get("awss3", {}))
+
+        self.mq_url = self._parse_amqp_server(config["mq_server"])
+
+        (
+            self.req_queue,
+            self.req_pattern,
+            self.resp_queue,
+            self.success_resp_pattern,
+            self.failure_resp_pattern,
+        ) = self._parse_queue(config["queue"])
+
+        self.slack = self._parse_slack(config["alarm"]["slack"])
+
+    def _read_config(self, config_path) -> dict:
+        with open(config_path, mode="r") as f:
+            config = json.load(f)
+        return config
+
+    def _parse_amqp_server(self, amqp_server) -> str:
+        username = amqp_server["username"]
+        password = amqp_server["password"]
+        url = amqp_server["url"]
+        port = amqp_server["port"]
+        amqp_url = f"amqps://{username}:{password}@{url}:{port}"
+        return amqp_url
+
+    def _parse_queue(self, queue) -> tuple:
+        req_queue = queue["request_name"]
+        req_pattern = queue["request_pattern"]
+        resp_queue = queue.get("response_name")
+        success_resp_pattern = queue["success_response_pattern"]
+        failure_resp_pattern = queue["failure_response_pattern"]
+        return (
+            req_queue,
+            req_pattern,
+            resp_queue,
+            success_resp_pattern,
+            failure_resp_pattern,
+        )
+
+    def __parse_awss3_storage(self, awss3_storage) -> tuple:
+        awss3_bucket = awss3_storage.get("default_bucket")
+        awss3_inpainting_path = awss3_storage.get("inpainting_path")
+        awss3_translation_path = awss3_storage.get("translation_path")
+
+        return awss3_bucket, awss3_inpainting_path, awss3_translation_path
+
+    def __parse_local_storage(self, local_storage) -> tuple:
+        requested_img_path = local_storage.get("requested")
+        text_removed_img_path = local_storage.get("text_removed")
+
+        return requested_img_path, text_removed_img_path
+
+    def _parse_slack(self, slack) -> Slack:
+        url = slack["url"]
+        channel = slack["channel"]
+        username = slack["username"]
+        icon_emoji = slack["icon_emoji"]
+        channel_id = slack.get("channel_id", None)
+        bot_token = slack.get("bot_token", None)
+
+        return Slack(url, channel, username, icon_emoji, channel_id, bot_token)
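A sketch of the JSON shape `Config(env="dev")` expects at `config/config.dev.json`, inferred from the parser above; every value below is a hypothetical placeholder, not the project's actual configuration.

```python
example = {
    "storage": {
        "awss3": {"default_bucket": "my-bucket", "default_path": "images/"},
        "local": {"requested": "/data/requested", "text_removed": "/data/text_removed"},
    },
    "mq_server": {"username": "user", "password": "pass", "url": "mq.example.com", "port": 5671},
    "queue": {
        "request_name": "req",
        "request_pattern": "req.#",
        "response_name": "resp",
        "success_response_pattern": "resp.success",
        "failure_response_pattern": "resp.failure",
    },
    "alarm": {
        "slack": {
            "url": "https://hooks.slack.com/services/XXX/YYY/ZZZ",
            "channel": "#alerts",
            "username": "bot",
            "icon_emoji": ":robot_face:",
        }
    },
}
```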
data_utils/image_utils.py ADDED
@@ -0,0 +1,1364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References
2
+ # https://sashamaps.net/docs/resources/20-colors/
3
+
4
+ import numpy as np
5
+ import cv2
6
+ from scipy import ndimage as ndi
7
+ from PIL import Image, ImageDraw, ImageCms, ExifTags, ImageEnhance
8
+ import requests
9
+ from pathlib import Path
10
+ import pandas as pd
11
+ from scipy.sparse import coo_matrix
12
+ from skimage.feature import peak_local_max
13
+ from skimage.morphology import local_maxima
14
+ from skimage.segmentation import watershed
15
+ from moviepy.video.io.bindings import mplfig_to_npimage
16
+ import io
17
+ import os
18
+ from enum import Enum
19
+
20
+
21
+ COLORS = (
22
+ (230, 25, 75),
23
+ (60, 180, 75),
24
+ (255, 255, 25),
25
+ (0, 130, 200),
26
+ (245, 130, 48),
27
+ (145, 30, 180),
28
+ (70, 240, 250),
29
+ (240, 50, 230),
30
+ (210, 255, 60),
31
+ (250, 190, 212),
32
+ (0, 128, 128),
33
+ (220, 190, 255),
34
+ (170, 110, 40),
35
+ (255, 250, 200),
36
+ (128, 0, 0),
37
+ (170, 255, 195),
38
+ (128, 128, 0),
39
+ (255, 215, 180),
40
+ (0, 0, 128),
41
+ (128, 128, 128),
42
+ )
43
+
44
+
45
+ class PC_TYPE(Enum):
46
+ HARRIS = 1
47
+ EDGES_CONTOURS = 2
48
+ GFTT = 3
49
+ FAST = 4
50
+ KAZE = 5
51
+
52
+
53
+ def _to_2d(img):
54
+ # it use just first channel. if you want rgb2gray, use _to_grayscale
55
+ if img.ndim == 3:
56
+ return img[:, :, 0]
57
+ else:
58
+ return img
59
+
60
+
61
+ def _to_3d(img):
62
+ if img.ndim == 2:
63
+ return np.dstack([img, img, img])
64
+ else:
65
+ return img
66
+
67
+
68
+ def _to_byte(img: Image, format) -> bytes:
69
+ # BytesIO is a file-like buffer stored in memory
70
+ imgByteArr = io.BytesIO()
71
+ # image.save expects a file-like as a argument
72
+ img.save(imgByteArr, format=format)
73
+ # Turn the BytesIO object back into a bytes object
74
+ imgByteArr = imgByteArr.getvalue()
75
+ return imgByteArr
76
+
77
+
78
+ def _get_width_and_height(img):
79
+ if img.ndim == 2:
80
+ h, w = img.shape
81
+ else:
82
+ h, w, _ = img.shape
83
+ return w, h
84
+
85
+
86
+ def _get_resolution(img):
87
+ w, h = _get_width_and_height(img)
88
+ res = w * h
89
+ return res
90
+
91
+
92
+ def _to_pil(img):
93
+ if not isinstance(img, Image.Image):
94
+ img = Image.fromarray(img, mode="RGB")
95
+ return img
96
+
97
+
98
+ def _to_array(img):
99
+ img = np.array(img)
100
+ return img
101
+
102
+
103
+ def _bool_to_uint8(img):
104
+ uint8 = img.astype("uint8")
105
+ if (
106
+ np.array_equal(np.unique(uint8), np.array([0, 1]))
107
+ or np.array_equal(np.unique(uint8), np.array([0]))
108
+ or np.array_equal(np.unique(uint8), np.array([1]))
109
+ ):
110
+ return uint8 * 255
111
+ else:
112
+ return uint8
113
+
114
+
115
+ def _figure_to_array(fig):
116
+ arr = mplfig_to_npimage(fig)
117
+ return arr
118
+
119
+
120
+ def _preprocess_image(img):
121
+ if img.dtype == "int32":
122
+ img = _repaint_segmentation_map(img)
123
+
124
+ if img.dtype == "bool":
125
+ img = img.astype("uint8") * 255
126
+
127
+ if img.ndim == 2:
128
+ if (
129
+ np.array_equal(np.unique(img), np.array([0, 255]))
130
+ or np.array_equal(np.unique(img), np.array([0]))
131
+ or np.array_equal(np.unique(img), np.array([255]))
132
+ ):
133
+ img = _to_3d(img)
134
+ else:
135
+ img = _apply_jet_colormap(img)
136
+ return img
137
+
138
+
139
+ def _blend_two_images(img1, img2, alpha=0.5):
140
+ img1 = _to_pil(img1)
141
+ img2 = _to_pil(img2)
142
+ img_blended = Image.blend(im1=img1, im2=img2, alpha=alpha)
143
+ return _to_array(img_blended)
144
+
145
+
146
+ def _repaint_segmentation_map(seg_map):
147
+ canvas_r = _get_canvas_same_size_as_image(seg_map, black=True)
148
+ canvas_g = _get_canvas_same_size_as_image(seg_map, black=True)
149
+ canvas_b = _get_canvas_same_size_as_image(seg_map, black=True)
150
+
151
+ remainder_map = seg_map % len(COLORS) + 1
152
+ for remainder, (r, g, b) in enumerate(COLORS, start=1):
153
+ canvas_r[remainder_map == remainder] = r
154
+ canvas_g[remainder_map == remainder] = g
155
+ canvas_b[remainder_map == remainder] = b
156
+ canvas_r[seg_map == 0] = 0
157
+ canvas_g[seg_map == 0] = 0
158
+ canvas_b[seg_map == 0] = 0
159
+
160
+ dstacked = np.dstack([canvas_r, canvas_g, canvas_b])
161
+ return dstacked
162
+
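+ # Usage sketch (hypothetical labels): colors repeat every len(COLORS) labels,
+ # and the background label 0 stays black.
+ # seg = np.array([[0, 1], [2, 21]], dtype="int32")
+ # rgb = _repaint_segmentation_map(seg)  # labels 1 and 21 get the same color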
163
+
164
+ def _get_canvas_same_size_as_image(img, black=False):
165
+ if black:
166
+ return np.zeros_like(img).astype("uint8")
167
+ else:
168
+ return (np.ones_like(img) * 255).astype("uint8")
169
+
170
+
171
+ def _get_canvas(w, h, black=False):
172
+ if black:
173
+ return np.zeros((h, w, 3)).astype("uint8")
174
+ else:
175
+ return (np.ones((h, w, 3)) * 255).astype("uint8")
176
+
177
+
178
+ def _invert_image(mask):
179
+ return cv2.bitwise_not(mask.astype("uint8"))
180
+
181
+
182
+ def _to_grayscale(img):
183
+ gray_img = cv2.cvtColor(src=img, code=cv2.COLOR_RGB2GRAY)
184
+ return gray_img
185
+
186
+
187
+ def _erode_mask(mask, kernel_size=3):
188
+ kernel = cv2.getStructuringElement(
189
+ shape=cv2.MORPH_RECT, ksize=(kernel_size, kernel_size)
190
+ )
191
+ if mask.dtype == "bool":
192
+ mask = mask.astype("uint8") * 255
193
+ mask = cv2.erode(src=mask, kernel=kernel)
194
+ return mask
195
+
196
+
197
+ def _dilate_mask(mask, kernel_size=3):
198
+ if kernel_size == 0:
199
+ return mask
200
+ kernel = cv2.getStructuringElement(
201
+ shape=cv2.MORPH_RECT, ksize=(kernel_size, kernel_size)
202
+ )
203
+ if mask.dtype == "bool":
204
+ mask = mask.astype("uint8") * 255
205
+ mask = cv2.dilate(src=mask, kernel=kernel)
206
+ return mask
207
+
208
+
209
+ def _gaussian_blur_mask(mask, kernel_size=5):
210
+ blurred_mask = cv2.GaussianBlur(
211
+ src=mask, ksize=(kernel_size, kernel_size), sigmaX=0
212
+ )
213
+ # mask = (blurred_mask >= 32).astype("uint8") * 255
214
+ mask = (blurred_mask != 0).astype("uint8") * 255
215
+ return mask
216
+
217
+
218
+ def _blur(img, v=0.04):
219
+ w, h = _get_width_and_height(img)
220
+ kernel_size = round(min(w, h) * v)
221
+ bl = cv2.GaussianBlur(
222
+ src=img.copy(order="C"),
223
+ ksize=(kernel_size // 2 * 2 + 1, kernel_size // 2 * 2 + 1),
224
+ sigmaX=0,
225
+ )
226
+ return bl
227
+
228
+
229
+ def _get_adaptive_thresholded_image(img, invert=False, block_size=3):
230
+ gray_img = cv2.cvtColor(src=img, code=cv2.COLOR_RGB2GRAY)
231
+
232
+ thrsh_type = cv2.THRESH_BINARY if not invert else cv2.THRESH_BINARY_INV
233
+ img_thr = cv2.adaptiveThreshold(
234
+ src=gray_img,
235
+ maxValue=255,
236
+ adaptiveMethod=cv2.ADAPTIVE_THRESH_MEAN_C,
237
+ thresholdType=thrsh_type,
238
+ blockSize=block_size,
239
+ C=0,
240
+ )
241
+ return img_thr
242
+
243
+
244
+ def _make_segmentation_map_rectangle(seg_map):
245
+ seg_map_copied = seg_map.copy(order="C")
246
+ for idx in range(1, np.max(seg_map_copied) + 1):
247
+ seg_map_sub = seg_map_copied == idx
248
+ nonzero_x = np.where((seg_map_sub != 0).any(axis=0))[0]
249
+ nonzero_y = np.where((seg_map_sub != 0).any(axis=1))[0]
250
+ if nonzero_x.size != 0 and nonzero_y.size != 0:
251
+ seg_map_copied[
252
+ nonzero_y[0] : nonzero_y[-1], nonzero_x[0] : nonzero_x[-1]
253
+ ] = idx
254
+ return seg_map_copied
255
+
256
+
257
+ def _apply_jet_colormap(img):
258
+ img_jet = cv2.applyColorMap(src=(255 - img), colormap=cv2.COLORMAP_JET)
259
+ return img_jet
260
+
261
+
262
+ def _reverse_jet_colormap(img):
263
+ gray_values = np.arange(256, dtype=np.uint8)
264
+ color_values = list(map(tuple, _apply_jet_colormap(gray_values).reshape(256, 3)))
265
+ color_to_gray_map = dict(zip(color_values, gray_values))
266
+
267
+ out = np.apply_along_axis(
268
+ lambda bgr: color_to_gray_map[tuple(bgr)], axis=2, arr=img
269
+ )
270
+ return out
271
+
272
+
273
+ def _get_pixel_counts(arr, sort=False, include_zero=False):
274
+ unique, cnts = np.unique(arr, return_counts=True)
275
+ idx2cnt = dict(zip(unique, cnts))
276
+
277
+ if not include_zero:
278
+ if 0 in idx2cnt:
279
+ idx2cnt.pop(0)
280
+
281
+ if not sort:
282
+ return idx2cnt
283
+ else:
284
+ return dict(sorted(idx2cnt.items(), key=lambda x: x[1], reverse=True))
285
+
286
+
287
+ def _combine_masks(masks):
288
+ canvas = _get_canvas_same_size_as_image(img=masks[0], black=True)
289
+ for mask in masks:
290
+ canvas = np.maximum(_to_3d(canvas), _to_3d(mask))
291
+ return canvas
292
+
293
+
294
+ def _get_local_maxima_coordinates(region_score_map, region_seg_map=None, th=150):
295
+ # When src_lang="ja", a threshold of 150 works better.
296
+ if region_seg_map is None:
297
+ _, region_mask = cv2.threshold(
298
+ src=region_score_map, thresh=th, maxval=255, type=cv2.THRESH_BINARY
299
+ )
300
+ _, region_seg_map = cv2.connectedComponents(image=region_mask, connectivity=4)
301
+ local_max = peak_local_max(
302
+ image=region_score_map,
303
+ min_distance=5,
304
+ labels=region_seg_map,
305
+ num_peaks_per_label=24,
306
+ )
307
+ local_max = local_max[:, ::-1] # yx to xy
308
+ return local_max
309
+
310
+
311
+ def _get_local_maxima_array(region_score_map, region_seg_map=None, th=150):
312
+ local_max_coor = _get_local_maxima_coordinates(
313
+ region_score_map, region_seg_map=region_seg_map, th=th
314
+ )
315
+
316
+ _, h = _get_width_and_height(local_max_coor)
317
+ vals = np.array([1] * h)
318
+ rows = local_max_coor[:, 1]
319
+ cols = local_max_coor[:, 0]
320
+ local_max = (
321
+ coo_matrix((vals, (rows, cols)), shape=region_score_map.shape)
322
+ .toarray()
323
+ .astype("bool")
324
+ )
325
+ return local_max
326
+
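+ # Sketch of the scatter step above (assumed (x, y) coordinate rows):
+ # coords = np.array([[3, 1], [0, 2]])
+ # coo_matrix(([1, 1], (coords[:, 1], coords[:, 0])), shape=(4, 4)) puts True
+ # at (row=1, col=3) and (row=2, col=0) once converted with .toarray().astype("bool").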
327
+
328
+ def _mask_image(img, mask, invert=False):
329
+ """Extract only the region of img covered by the mask.
330
+
331
+ Args:
332
+ img (_PIL or np.ndarray_): input image
333
+ mask (_PIL or np.ndarray_): mask; an (H,W,C) mask is reduced to a single channel, or pass (H,W)
334
+ invert (bool, optional): whether to extract with the inverted mask instead.
335
+ Returns:
336
+ _np.ndarray_: resulting image
337
+ """
338
+ """
339
+ img = _to_array(img)
340
+ mask = _to_2d(_to_array(mask))
341
+ if invert:
342
+ mask = _invert_image(mask)
343
+ return cv2.bitwise_and(src1=img, src2=img, mask=mask.astype("uint8"))
344
+
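+ # Example: keep only the masked pixels of a flat gray image.
+ # img = np.full((4, 4, 3), 200, dtype="uint8")
+ # mask = np.zeros((4, 4), dtype="uint8"); mask[1:3, 1:3] = 255
+ # _mask_image(img, mask)               # nonzero only inside the 2x2 window
+ # _mask_image(img, mask, invert=True)  # the complementary region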
345
+
346
+ def _ignore_small_regions_in_mask(mask, area_thresh=10):
347
+ mask = _to_2d(mask)
348
+
349
+ _, seg_map, stats, _ = cv2.connectedComponentsWithStats(
350
+ mask.astype("uint8"), connectivity=4
351
+ )
352
+ keep = np.isin(seg_map, np.where(stats[:, cv2.CC_STAT_AREA] >= area_thresh)[0][1:])  # avoid shadowing the builtin bool
353
+ new_mask = keep.astype("uint8") * 255
354
+ new_mask = _to_3d(new_mask)
355
+ return new_mask
356
+
357
+
358
+ def _crop_image(img, l, t, r, b):
359
+ w, h = _get_width_and_height(img)
360
+ return img[
361
+ int(max(0, t)) : int(min(h, b)),
362
+ int(max(0, l)) : int(min(w, r)),
363
+ ...,
364
+ ]
365
+
366
+
367
+ def _bboxes_to_mask(img, bboxes):
368
+ canvas = _get_canvas_same_size_as_image(img=img, black=True)
369
+ for row in bboxes.itertuples():
370
+ canvas[row.bbox_y1 : row.bbox_y2, row.bbox_x1 : row.bbox_x2] = 255
371
+ return _to_3d(canvas)
372
+
373
+
374
+ def _apply_watershed(mask, region_score_map, th=150):
375
+ local_max_arr = _get_local_maxima_array(region_score_map, th=th)
376
+ _, markers = cv2.connectedComponents(
377
+ image=local_max_arr.astype("uint8"), connectivity=4
378
+ )
379
+ seg_map = watershed(image=-region_score_map, markers=markers, mask=_to_2d(mask))
380
+ return seg_map
381
+
382
+
383
+ def _perform_watershed(score_map, score_thresh=80):
384
+ trimmed_score_map = score_map.copy()
385
+ trimmed_score_map[trimmed_score_map < 190] = 0
386
+
387
+ markers = local_maxima(image=trimmed_score_map, allow_borders=False)
388
+ _, markers = cv2.connectedComponents(image=markers.astype("int8"), connectivity=8)
389
+
390
+ _, region_mask = cv2.threshold(
391
+ src=score_map, thresh=score_thresh, maxval=255, type=cv2.THRESH_BINARY
392
+ )
393
+ watersheded = watershed(image=-score_map, markers=markers, mask=_to_2d(region_mask))
394
+ return watersheded
395
+
396
+
397
+ def _get_region_segmentation_map(region_score_map, region_thresh=30):
398
+ _, region_mask = cv2.threshold(
399
+ src=region_score_map, thresh=region_thresh, maxval=255, type=cv2.THRESH_BINARY
400
+ )
401
+ region_seg_map = _apply_watershed(
402
+ region_score_map=region_score_map, mask=region_mask
403
+ )
404
+ return region_seg_map
405
+
406
+
407
+ def _combine_two_segmentation_maps(seg_map1, seg_map2):
408
+ seg_map = seg_map1 + _mask_image(
409
+ img=seg_map2 + len(np.unique(seg_map1)) - 1, mask=(seg_map2 != 0)
410
+ )
411
+ px_cnts = _get_pixel_counts(seg_map, sort=True, include_zero=True)
412
+ seg_map = _mask_image(img=seg_map, mask=(seg_map != list(px_cnts)[0]))
413
+ return seg_map
414
+
415
+
416
+ def _get_image_segmentation_map(img, region_score_map=None, block_size=3):
417
+ if region_score_map is not None:
418
+ _, region_mask = cv2.threshold(
419
+ src=region_score_map, thresh=20, maxval=255, type=cv2.THRESH_BINARY
420
+ )
421
+ region_mask = _dilate_mask(mask=region_mask, kernel_size=16)
422
+ img_masked = _mask_image(img=img, mask=region_mask)
423
+ else:
424
+ img_masked = img
425
+
426
+ img_thr1 = _get_adaptive_thresholded_image(
427
+ img=img_masked, invert=False, block_size=block_size
428
+ )
429
+ img_thr2 = _get_adaptive_thresholded_image(
430
+ img=img_masked, invert=True, block_size=block_size
431
+ )
432
+
433
+ _, seg_map1 = cv2.connectedComponents(image=img_thr1, connectivity=4)
434
+ _, seg_map2 = cv2.connectedComponents(image=img_thr2, connectivity=4)
435
+ seg_map = _combine_two_segmentation_maps(seg_map1=seg_map1, seg_map2=seg_map2)
436
+ return seg_map
437
+
438
+
439
+ def _get_segmentation_map_overlapping_mask(seg_map, mask, overlap_thresh=0.6):
440
+ img_pixel_counts = _get_pixel_counts(seg_map, sort=True, include_zero=False)
441
+
442
+ overlapping_seg_map = _mask_image(img=seg_map, mask=(mask != 0))
443
+ overlapping_counts = _get_pixel_counts(
444
+ overlapping_seg_map, sort=False, include_zero=False
445
+ )
446
+
447
+ df_counts = pd.DataFrame.from_dict(
448
+ img_pixel_counts, orient="index", columns=["total_pixel_count"]
449
+ )
450
+ df_counts["overlap_pixel_count"] = df_counts.apply(
451
+ lambda x: overlapping_counts.get(x.name, 0), axis=1
452
+ )
453
+ df_counts["ratio"] = (
454
+ df_counts["overlap_pixel_count"] / df_counts["total_pixel_count"]
455
+ )
456
+
457
+ region_is_inside = df_counts[df_counts["ratio"] > overlap_thresh].index.tolist()
458
+ mask = np.isin(seg_map, region_is_inside).astype("uint8")
459
+ mask = _to_3d(mask * 255)
460
+ return mask
461
+
462
+
463
+ def _split_segmentation_map(seg_map, pccs):
464
+ ls_idx = (
465
+ pccs[pccs["inside"]]
466
+ .apply(lambda x: seg_map[x["y"], x["x"]], axis=1)
467
+ .values.tolist()
468
+ )
469
+
470
+ seg_map1 = _mask_image(img=seg_map, mask=np.isin(seg_map, ls_idx))
471
+ seg_map2 = _mask_image(img=seg_map, mask=~np.isin(seg_map, ls_idx))
472
+ return seg_map1, seg_map2
473
+
474
+
475
+ def _segmentation_map_to_mask(seg_map):
476
+ return _to_3d((seg_map != 0).astype("uint8") * 255)
477
+
478
+
479
+ def _get_pseudo_character_centers_from_mask(mask, bboxes: pd.DataFrame = pd.DataFrame()):
480
+ """Compute the center coordinate of each label (character) from a mask image."""
481
+ center_coords = []
482
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
483
+ image=_to_2d(mask), connectivity=8
484
+ )
485
+ for i in range(1, num_labels):
486
+ center_coords.append((int(centroids[i][0]), int(centroids[i][1])))
487
+
488
+ pccs = pd.DataFrame(
489
+ center_coords,
490
+ columns=[
491
+ "x",
492
+ "y",
493
+ ],
494
+ )
495
+
496
+ if not bboxes.empty:
497
+ # Vectorized check of whether each point lies inside any bbox
498
+ pccs["inside"] = (
499
+ (pccs["x"].values[:, None] > bboxes["bbox_x1"].values) &
500
+ (pccs["x"].values[:, None] < bboxes["bbox_x2"].values) &
501
+ (pccs["y"].values[:, None] > bboxes["bbox_y1"].values) &
502
+ (pccs["y"].values[:, None] < bboxes["bbox_y2"].values)
503
+ ).any(axis=1)
504
+ else:
505
+ pccs["inside"] = True
506
+
507
+ return pccs
508
+
509
+
510
+ def _get_pseudo_character_centers(
511
+ region_score_map, region_seg_map=None, bboxes=pd.DataFrame()
512
+ ):
513
+ local_max_coor = _get_local_maxima_coordinates(
514
+ region_score_map, region_seg_map=region_seg_map
515
+ )
516
+ pccs = pd.DataFrame(local_max_coor, columns=["x", "y"])
517
+
518
+ if not bboxes.empty:
519
+ # Vectorized check of whether each point lies inside any bbox
520
+ pccs["inside"] = (
521
+ (pccs["x"].values[:, None] > bboxes["bbox_x1"].values) &
522
+ (pccs["x"].values[:, None] < bboxes["bbox_x2"].values) &
523
+ (pccs["y"].values[:, None] > bboxes["bbox_y1"].values) &
524
+ (pccs["y"].values[:, None] < bboxes["bbox_y2"].values)
525
+ ).any(axis=1)
526
+ else:
527
+ pccs["inside"] = True
528
+
529
+ return pccs
530
+
531
+
532
+ def _convert_region_score_map_to_region_mask(region_score_map, region_score_thresh=170):
533
+ _, region_mask = cv2.threshold(
534
+ src=region_score_map, thresh=30, maxval=255, type=cv2.THRESH_BINARY
535
+ )
536
+
537
+ new_mask = _get_canvas_same_size_as_image(img=region_mask, black=True)
538
+
539
+ n_labels, seg_map, _, _ = cv2.connectedComponentsWithStats(
540
+ image=_to_2d(region_mask), connectivity=4
541
+ )
542
+ for k in range(1, n_labels):
543
+ if np.max(region_score_map[seg_map == k]) < region_score_thresh:
544
+ continue
545
+
546
+ new_mask[seg_map == k] = 255
547
+ new_mask = _to_3d(new_mask)
548
+ return new_mask
549
+
550
+
551
+ def _split_mask(mask, region_score_map=None, bboxes=pd.DataFrame(), th=30):
552
+ """Split the mask into two kinds: the regions to erase during inpainting and the regions to restore.
554
+ mask1 and mask2 may overlap each other.
555
+ How it works: binarize region_score_map (or dst_mask_map when it is not given) at th and turn it into a segmentation map via connected components,
556
+ then run watershed with the local maximum points of each label region as markers and take the result as the segmentation map;
557
+ pccs are found with skimage's peak_local_max from the region score map and that segmentation map. The bbox information is also used to check whether each pcc falls inside a box,
558
+ and for the pccs inside boxes the map is split into the label regions they belong to (seg_map1) and the label regions they do not (seg_map2).
559
+
560
+ Args:
561
+ mask (_np.ndarray_): (H,W,3) mask with values 0 or 255
562
+ region_score_map (_np.ndarray_): region score map output by CRAFT; a heat map emphasizing character centers
563
+ bboxes (_pd.DataFrame_): DataFrame with box coordinates (bbox_x1, bbox_y1, bbox_x2, bbox_y2).
564
+ Returns:
565
+ _np.ndarray_: mask1, the part to erase, and mask2, the part to restore.
566
+ """
566
+
567
+ if region_score_map is None:
568
+ dst_mask_map = _to_2d(get_dst_mask(mask))
569
+ seg_map = _apply_watershed(mask=mask, region_score_map=dst_mask_map, th=th)
570
+ pccs = _get_pseudo_character_centers(
571
+ region_score_map=dst_mask_map, region_seg_map=seg_map, bboxes=bboxes
572
+ )
573
+ else:
574
+ seg_map = _apply_watershed(mask, region_score_map, th=th)
575
+ pccs = _get_pseudo_character_centers(
576
+ region_score_map=region_score_map, region_seg_map=seg_map, bboxes=bboxes
577
+ )
578
+
579
+ box_mask = _bboxes_to_mask(seg_map, bboxes)
580
+
581
+ seg_map1, seg_map2 = _split_segmentation_map(seg_map=seg_map, pccs=pccs)
582
+ mask1 = _segmentation_map_to_mask(seg_map1)
583
+ mask2 = _segmentation_map_to_mask(seg_map2)
584
+ mask3 = _to_3d(_mask_image(mask1, box_mask, invert=True))
585
+ mask2 = _combine_masks([mask2, mask3])
586
+ return mask1, mask2
587
+
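+ # Usage sketch (hypothetical inputs): given a text-stroke mask, a CRAFT region
+ # score map and word boxes, mask1 is erased and mask2 is restored later.
+ # boxes = pd.DataFrame([{"bbox_x1": 10, "bbox_y1": 10, "bbox_x2": 90, "bbox_y2": 40}])
+ # erase_mask, restore_mask = _split_mask(text_mask, region_score_map=score_map, bboxes=boxes, th=30)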
588
+
589
+ def get_word_segmentation_map(region_score_map, affinity_score_map):
590
+ _, region_mask = cv2.threshold(
591
+ src=region_score_map, thresh=70, maxval=255, type=cv2.THRESH_BINARY
592
+ )
593
+ _, affinity_mask = cv2.threshold(
594
+ src=affinity_score_map, thresh=70, maxval=255, type=cv2.THRESH_BINARY
595
+ )
596
+ word_mask = region_mask + affinity_mask
597
+
598
+ _, segmentation_map_word = cv2.connectedComponents(image=word_mask, connectivity=4)
599
+ return segmentation_map_word
600
+
601
+
602
+ def get_line_segmentation_map(line_score_map):
603
+ _, line_mask = cv2.threshold(
604
+ src=line_score_map, thresh=130, maxval=255, type=cv2.THRESH_BINARY
605
+ )
606
+ _, line_segmentation_map = cv2.connectedComponents(image=line_mask, connectivity=4)
607
+ return line_segmentation_map
608
+
609
+
610
+ def _get_3d_block_segmentation_map(img, bboxes):
611
+ segmentation_map_block = np.zeros(
612
+ shape=(img.shape[0], img.shape[1], len(bboxes) + 1)
613
+ )
614
+ for idx, (xmin, ymin, xmax, ymax) in enumerate(
615
+ bboxes[["xmin", "ymin", "xmax", "ymax"]].values, start=1
616
+ ):
617
+ segmentation_map_block[ymin:ymax, xmin:xmax, idx] = 255
618
+ return segmentation_map_block
619
+
620
+
621
+ def compare_images(img1, img2, flag=cv2.CMP_EQ):
622
+ # 255 where the two images satisfy the comparison, 0 elsewhere; see cv2.CMP_XX for flags (EQ: equal, NE: different).
623
+ return cv2.compare(img1, img2, flag)
624
+
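+ # Example: cv2.compare yields a uint8 map with 255 where the condition holds.
+ # a = np.array([[1, 2]], dtype="uint8"); b = np.array([[1, 3]], dtype="uint8")
+ # compare_images(a, b, cv2.CMP_EQ)  # -> [[255, 0]]
+ # compare_images(a, b, cv2.CMP_NE)  # -> [[0, 255]]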
625
+
626
+ def convert_webp_png_get_data(img: np.ndarray):
627
+ pil_img = _to_pil(img)
628
+ convert_pil_img = pil_img.convert("RGB")
629
+ convert_pil_img.save("temp.png")
630
+ _, byte, format = load_image("temp.png", with_byte=True, with_format=True)
631
+ os.remove("temp.png")
632
+
633
+ return byte
634
+
635
+
636
+ def add_water_mark(original_img, water_mark_img_path):
637
+ if isinstance(original_img, np.ndarray):
638
+ original_img = _to_pil(original_img)
639
+ return_np = True
640
+ else:
641
+ return_np = False
642
+ watermark = Image.open(water_mark_img_path).convert("RGBA")
643
+
644
+ width_o, height_o = original_img.size
645
+ width_wm, height_wm = watermark.size
646
+
647
+ position = ((width_o - width_wm) // 2, (height_o - height_wm) // 2)
648
+
649
+ # Scale the watermark down proportionally only when it is larger than the original image
650
+ if width_wm > width_o or height_wm > height_o:
651
+ # Compute the width and height ratios of the watermark to the original
652
+ ratio_w = width_o / width_wm
653
+ ratio_h = height_o / height_wm
654
+ # Pick the smaller ratio so the resized watermark fits inside the image
655
+ ratio = min(ratio_w, ratio_h)
656
+ new_width = int(width_wm * ratio)
657
+ new_height = int(height_wm * ratio)
658
+ watermark = watermark.resize((new_width, new_height), Image.Resampling.LANCZOS)
659
+ width_wm, height_wm = watermark.size
660
+
661
+ # Position recomputed for the resized watermark
662
+ position = ((width_o - width_wm) // 2, (height_o - height_wm) // 2)
663
+
664
+ original_img.paste(watermark, position, watermark)
665
+ rgb_image = original_img.convert("RGB")
666
+
667
+ if return_np:
668
+ return _to_array(rgb_image)
669
+ return rgb_image
670
+
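+ # Usage sketch (hypothetical watermark path): the watermark is pasted centered,
+ # downscaled first only if it exceeds the base image; numpy in -> numpy out.
+ # marked = add_water_mark(img, "resources/watermark.png")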
671
+
672
+ def load_image(url_or_path, with_byte=False, with_format=False):
673
+ if "http" in url_or_path:
674
+ url_or_path = str(url_or_path)
675
+ response = requests.get(url_or_path)
676
+ PIL_image = Image.open(io.BytesIO(response.content))
677
+ format = PIL_image.format
678
+ image_bytes = response.content
679
+ if format == "GIF":
680
+ img_exif = None
681
+ else:
682
+ img_exif = PIL_image._getexif()
683
+ if PIL_image.mode in ["L", "P", "PA", "RGBA"]:
684
+ PIL_image = Image.open(io.BytesIO(response.content)).convert("RGB")
685
+ if img_exif:
686
+ for k in img_exif.keys():
687
+ attr = ExifTags.TAGS.get(k, "no_key")
688
+ if attr != "no_key":
689
+ if ExifTags.TAGS[k] == "Orientation":
690
+ if img_exif[k] == 3:
691
+ PIL_image = PIL_image.rotate(180, expand=True)
692
+ elif img_exif[k] == 6:
693
+ PIL_image = PIL_image.rotate(270, expand=True)
694
+ elif img_exif[k] == 8:
695
+ PIL_image = PIL_image.rotate(90, expand=True)
696
+ break
697
+ if PIL_image.mode == "CMYK":
698
+ cmyk_profile = ImageCms.ImageCmsProfile("resources/USWebCoatedSWOP.icc")
699
+ srgb_profile = ImageCms.ImageCmsProfile(
700
+ "resources/sRGB Color Space Profile.icm"
701
+ )
702
+ PIL_image = ImageCms.profileToProfile(
703
+ PIL_image, cmyk_profile, srgb_profile, outputMode="RGB"
704
+ )
705
+ img = np.array(PIL_image)
706
+ else:
707
+ img = np.array(PIL_image)
708
+ else:
709
+ # img = cv2.imread(url_or_path, flags=cv2.IMREAD_COLOR)
710
+ # img = cv2.cvtColor(src=img, code=cv2.COLOR_BGR2RGB)
711
+ PIL_image = Image.open(url_or_path)
712
+ format = PIL_image.format
713
+ byte_arr = io.BytesIO()
714
+ if PIL_image.mode == "RGBA":
715
+ PIL_image = PIL_image.convert("RGB")
716
+ PIL_image.save(byte_arr, format="JPEG")
717
+ image_bytes = byte_arr.getvalue()
718
+ img = np.array(PIL_image)
719
+
720
+ # if "http" in url_or_path:
721
+ # img = cv2.imdecode(
722
+ # np.asarray(bytearray(requests.get(url_or_path).content), dtype="uint8"), flags=cv2.IMREAD_COLOR
723
+ # )
724
+ # else:
725
+ # img = cv2.imread(url_or_path, flags=cv2.IMREAD_COLOR)
726
+ # img = cv2.cvtColor(src=img, code=cv2.COLOR_BGR2RGB)
727
+ if with_byte:
728
+ if with_format:
729
+ return img, image_bytes, format
730
+ else:
731
+ return img, image_bytes
732
+
733
+ return img
734
+
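+ # Usage sketch: URLs get EXIF rotation and CMYK -> sRGB handling, local paths
+ # are read with PIL; with_byte/with_format add the raw bytes and the format.
+ # img, raw_bytes, fmt = load_image("input/test.jpg", with_byte=True, with_format=True)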
735
+
736
+ def save_image(img1, img2=None, alpha=0.5, path="") -> None:
737
+ copied_img1 = _preprocess_image(_to_array(img1.copy(order="C")))
738
+ if img2 is None:
739
+ img_arr = copied_img1
740
+ else:
741
+ copied_img2 = _to_array(_preprocess_image(_to_array(img2.copy(order="C"))))
742
+ img_arr = _to_array(
743
+ _blend_two_images(img1=copied_img1, img2=copied_img2, alpha=alpha)
744
+ )
745
+
746
+ path = Path(path)
747
+ path.parent.mkdir(parents=True, exist_ok=True)
748
+
749
+ if os.path.splitext(str(path))[1] == ".gif":
750
+ pil = _to_pil(img1)
751
+ pil.save(str(path))
752
+ return
753
+
754
+ if img_arr.ndim == 3:
755
+ cv2.imwrite(
756
+ filename=str(path),
757
+ img=img_arr[:, :, ::-1],
758
+ params=[cv2.IMWRITE_JPEG_QUALITY, 100],
759
+ )
760
+ elif img_arr.ndim == 2:
761
+ cv2.imwrite(
762
+ filename=str(path), img=img_arr, params=[cv2.IMWRITE_JPEG_QUALITY, 100]
763
+ )
764
+
765
+
766
+ def show_image(img1, img2=None, alpha=0.5):
767
+ img1 = _to_pil(_preprocess_image(_to_array(img1)))
768
+ if img2 is None:
769
+ img1.show()
770
+ else:
771
+ img2 = _to_pil(_preprocess_image(_to_array(img2)))
772
+ img_blended = Image.blend(im1=img1, im2=img2, alpha=alpha)
773
+ img_blended.show()
774
+
775
+
776
+ def draw_bboxes(img, bboxes: pd.DataFrame, index=False):
777
+ """Visualize bboxes on top of the image, given the original image (before attribute extraction) and the bboxes DataFrame."""
778
+ canvas = _to_pil(_get_canvas_same_size_as_image(img=img, black=True))
779
+ draw = ImageDraw.Draw(canvas)
780
+ dic = dict()
781
+ for row in bboxes.itertuples():
782
+ h = row.bbox_y2 - row.bbox_y1
783
+ w = row.bbox_x2 - row.bbox_x1
784
+ smaller = min(w, h)
785
+ thickness = max(1, smaller // 22)
786
+
787
+ dic[row.Index] = ((0, 255, 0), (0, 100, 0), thickness)
788
+
789
+ for row in bboxes.itertuples():
790
+ _, fill, thickness = dic[row.Index]
791
+ draw.rectangle(
792
+ xy=(row.bbox_x1, row.bbox_y1, row.bbox_x2, row.bbox_y2),
793
+ outline=None,
794
+ fill=fill,
795
+ width=thickness,
796
+ )
797
+ for row in bboxes.itertuples():
798
+ outline, _, thickness = dic[row.Index]
799
+ draw.rectangle(
800
+ xy=(row.bbox_x1, row.bbox_y1, row.bbox_x2, row.bbox_y2),
801
+ outline=outline,
802
+ fill=None,
803
+ width=thickness,
804
+ )
805
+
806
+ if index:
807
+ from data_utils.rendering_utils import _get_font
808
+
809
+ max_len = max(map(len, map(str, bboxes.index)))
810
+ for row in bboxes.itertuples():
811
+ h = row.bbox_y2 - row.bbox_y1
812
+ w = row.bbox_x2 - row.bbox_x1
813
+ smaller = min(w, h)
814
+ font_size = max(10, min(40, smaller // 4))
815
+
816
+ draw.text(
817
+ xy=(row.bbox_x1, row.bbox_y1 - 4),
818
+ text=str(row.Index).zfill(max_len),
819
+ fill="white",
820
+ stroke_fill="black",
821
+ stroke_width=2,
822
+ font=_get_font(lang="en", font_size=font_size),
823
+ anchor="ls",
824
+ )
825
+ return _blend_two_images(img1=canvas, img2=img, alpha=0.4)
826
+
827
+
828
+ def visualize_clusters(img, bboxes, index=False):
829
+ from data_utils.rendering_utils import _get_font
830
+
831
+ canvas = _to_pil(_get_canvas_same_size_as_image(img=img, black=True))
832
+ draw = ImageDraw.Draw(canvas)
833
+ dic = dict()
834
+ for row in bboxes.itertuples():
835
+ h = row.bbox_y2 - row.bbox_y1
836
+ w = row.bbox_x2 - row.bbox_x1
837
+ smaller = min(w, h)
838
+ thickness = max(1, smaller // 22)
839
+
840
+ dic[row.Index] = ((255, 255, 255), COLORS[row.cluster], thickness)
841
+
842
+ for row in bboxes.itertuples():
843
+ _, fill, thickness = dic[row.Index]
844
+ draw.rectangle(
845
+ xy=(row.bbox_x1, row.bbox_y1, row.bbox_x2, row.bbox_y2),
846
+ outline=None,
847
+ fill=fill,
848
+ width=1,
849
+ )
850
+ for row in bboxes.itertuples():
851
+ outline, _, thickness = dic[row.Index]
852
+ draw.rectangle(
853
+ xy=(row.bbox_x1, row.bbox_y1, row.bbox_x2, row.bbox_y2),
854
+ outline=outline,
855
+ fill=None,
856
+ width=1,
857
+ )
858
+
859
+ if index:
860
+ for row in bboxes.itertuples():
861
+ h = row.bbox_y2 - row.bbox_y1
862
+ w = row.bbox_x2 - row.bbox_x1
863
+ smaller = min(w, h)
864
+ font_size = int(max(14, min(40, smaller * 0.35)))
865
+
866
+ draw.text(
867
+ xy=(row.bbox_x1, row.bbox_y1 - 4),
868
+ text=str(row.cluster),
869
+ fill="white",
870
+ stroke_fill="black",
871
+ stroke_width=2,
872
+ font=_get_font(lang="en", font_size=font_size),
873
+ anchor="ls",
874
+ )
875
+ return _blend_two_images(img1=canvas, img2=img, alpha=0.25)
876
+
877
+
878
+ def draw_bboxes_and_textboxes(bboxes, img):
879
+ canvas = img.copy(order="C")
880
+ for row in bboxes.itertuples():
881
+ cv2.rectangle(
882
+ img=canvas,
883
+ pt1=(row.bbox_x1, row.bbox_y1),
884
+ pt2=(row.bbox_x2, row.bbox_y2),
885
+ color=(0, 255, 0),
886
+ thickness=4,
887
+ )
888
+ cv2.rectangle(
889
+ img=canvas,
890
+ pt1=(row.tbox_x1, row.tbox_y1),
891
+ pt2=(row.tbox_x2, row.tbox_y2),
892
+ color=(255, 0, 0),
893
+ thickness=2,
894
+ )
895
+ return canvas
896
+
897
+
898
+ def draw_pseudo_character_centers(img, pccs, margin=4):
899
+ canvas = _to_pil(_get_canvas_same_size_as_image(img=img, black=True))
900
+ draw = ImageDraw.Draw(canvas)
901
+ for row in pccs.itertuples():
902
+ draw.ellipse(
903
+ xy=(row.x - margin, row.y - margin, row.x + margin, row.y + margin),
904
+ outline=(255, 0, 0),
905
+ fill=(100, 0, 0),
906
+ )
907
+ return _blend_two_images(img1=canvas, img2=img, alpha=0.3)
908
+
909
+
910
+ def _resize_image(img, w, h):
911
+ ori_w, ori_h = _get_width_and_height(img)
912
+ if w < ori_w or h < ori_h:
913
+ interpolation = cv2.INTER_AREA
914
+ else:
915
+ interpolation = cv2.INTER_LANCZOS4
916
+ resized_img = cv2.resize(src=img, dsize=(w, h), interpolation=interpolation)
917
+ return resized_img
918
+
919
+
920
+ def _resize_image_using_shorter_side(img, img_size=1530):
921
+ ori_w, ori_h = _get_width_and_height(img)
922
+ shorter = min(ori_w, ori_h)
923
+ if shorter <= img_size:
924
+ return img
925
+ if ori_w < ori_h:
926
+ resized_img = cv2.resize(
927
+ src=img,
928
+ dsize=(img_size, round(ori_h * (img_size / ori_w))),
929
+ interpolation=cv2.INTER_AREA,
930
+ )
931
+ else:
932
+ resized_img = cv2.resize(
933
+ src=img,
934
+ dsize=(round(ori_w * (img_size / ori_h)), img_size),
935
+ interpolation=cv2.INTER_AREA,
936
+ )
937
+ return resized_img
938
+
939
+
940
+ def _resize_image_using_longer_side(img, img_size=2560):
941
+ ori_w, ori_h = _get_width_and_height(img)
942
+ longer = max(ori_w, ori_h)
943
+ if longer <= img_size:
944
+ return img
945
+ if ori_w < ori_h:
946
+ resized_img = cv2.resize(
947
+ src=img,
948
+ dsize=(round(ori_w * (img_size / ori_h)), img_size),
949
+ interpolation=cv2.INTER_AREA,
950
+ )
951
+ else:
952
+ resized_img = cv2.resize(
953
+ src=img,
954
+ dsize=(img_size, round(ori_h * (img_size / ori_w))),
955
+ interpolation=cv2.INTER_AREA,
956
+ )
957
+ return resized_img
958
+
959
+
960
+ def _split_image_3(img, verbose=False):
961
+ if img.ndim == 2:
962
+ is_2d = True
963
+ else:
964
+ is_2d = False
965
+
966
+ img = _to_3d(img)
967
+ w, h = _get_width_and_height(img)
968
+ if h >= w:
969
+ if verbose:
970
+ print(f"Resolution: {w}, {h} -> {w}, {h // 2}")
971
+ img1 = img[: h // 2, :, :]
972
+ img2 = img[h // 4 : h // 4 + h // 2, :, :]
973
+ img3 = img[-h // 2 :, :, :]
974
+ else:
975
+ if verbose:
976
+ print(f"Resolution: {w}, {h} -> {w // 2}, {h}")
977
+ img1 = img[:, : w // 2, :]
978
+ img2 = img[:, w // 2 // 2 : w // 2 // 2 + w // 2, :]
979
+ img3 = img[:, -w // 2 :, :]
980
+ if is_2d:
981
+ img1 = _to_2d(img1)
982
+ img2 = _to_2d(img2)
983
+ img3 = _to_2d(img3)
984
+ return img1, img2, img3
985
+
986
+
987
+ def _split_image_2(img, verbose=False):
988
+ if img.ndim == 2:
989
+ is_2d = True
990
+ else:
991
+ is_2d = False
992
+
993
+ img = _to_3d(img)
994
+ w, h = _get_width_and_height(img)
995
+ if h >= w:
996
+ if verbose:
997
+ print(f"Resolution: {w}, {h} -> {w}, {h // 2}")
998
+ img1 = img[: h // 2, :, :]
999
+ img3 = img[-h // 2 :, :, :]
1000
+ else:
1001
+ if verbose:
1002
+ print(f"Resolution: {w}, {h} -> {w // 2}, {h}")
1003
+ img1 = img[:, : w // 2, :]
1004
+ img3 = img[:, -w // 2 :, :]
1005
+ if is_2d:
1006
+ img1 = _to_2d(img1)
1007
+ img3 = _to_2d(img3)
1008
+ return img1, img3
1009
+
1010
+
1011
+ def _combine_images_3(img, img1, img2, img3):
1012
+ if (img1 is None) and (img2 is None) and (img3 is None):
1013
+ canvas = None
1014
+ else:
1015
+ img1 = _to_2d(img1)
1016
+ img2 = _to_2d(img2)
1017
+ img3 = _to_2d(img3)
1018
+
1019
+ canvas = _get_canvas_same_size_as_image(_to_2d(img), black=True)
1020
+
1021
+ w, h = _get_width_and_height(img)
1022
+ if h >= w:
1023
+ canvas[: h // 2, :] = img1
1024
+ canvas[h // 2 // 2 : h // 2 // 2 + h // 2, :] = np.maximum(
1025
+ canvas[h // 2 // 2 : h // 2 // 2 + h // 2, :], img2
1026
+ )
1027
+ canvas[-h // 2 :, :] = np.maximum(canvas[-h // 2 :, :], img3)
1028
+ else:
1029
+ canvas[:, : w // 2] = img1
1030
+ canvas[:, w // 2 // 2 : w // 2 // 2 + w // 2] = np.maximum(
1031
+ canvas[:, w // 2 // 2 : w // 2 // 2 + w // 2], img2
1032
+ )
1033
+ canvas[:, -w // 2 :] = np.maximum(canvas[:, -w // 2 :], img3)
1034
+ return canvas
1035
+
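+ # The crops from _split_image_3 overlap by half, so np.maximum above stitches
+ # their maps back without losing responses near the crop borders. Sketch:
+ # a, b, c = _split_image_3(score_map)              # assumed 2D score map
+ # merged = _combine_images_3(score_map, a, b, c)   # same shape as score_map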
1036
+
1037
+ def _combine_images_2(img, img1, img2):
1038
+ if (img1 is None) and (img2 is None):
1039
+ canvas = None
1040
+ else:
1041
+ canvas = _get_canvas_same_size_as_image(img, black=True)
1042
+
1043
+ w, h = _get_width_and_height(img)
1044
+ if h >= w:
1045
+ canvas[: h // 2, :] = img1
1046
+ canvas[-h // 2 :, :] = np.maximum(canvas[-h // 2 :, :], img2)
1047
+ else:
1048
+ canvas[:, : w // 2] = img1
1049
+ canvas[:, -w // 2 :] = np.maximum(canvas[:, -w // 2 :], img2)
1050
+ return canvas
1051
+
1052
+
1053
+ def _rotate_90_degrees(img, counterclockwise=False):
1054
+ return cv2.rotate(
1055
+ src=img,
1056
+ rotateCode=cv2.ROTATE_90_COUNTERCLOCKWISE
1057
+ if counterclockwise
1058
+ else cv2.ROTATE_90_CLOCKWISE,
1059
+ )
1060
+
1061
+
1062
+ def save_image_patches(img, bboxes, dir):
1063
+ for row in bboxes.itertuples():
1064
+ patch = _crop_image(
1065
+ img=img,
1066
+ l=row.bbox_x1,
1067
+ t=row.bbox_y1,
1068
+ r=row.bbox_x2,
1069
+ b=row.bbox_y2,
1070
+ )
1071
+ patch_w = row.bbox_x2 - row.bbox_x1
1072
+ patch_h = row.bbox_y2 - row.bbox_y1
1073
+ if patch_h > patch_w:
1074
+ patch = _rotate_90_degrees(patch, counterclockwise=False)
1075
+
1076
+ save_image(img1=patch, path=Path(dir) / f"{str(row.Index).zfill(4)}.jpg")
1077
+
1078
+
1079
+ def get_minimum_area_bounding_rectangle(mask):
1080
+ nonzero = _to_2d(mask.astype("uint8")) != 0  # avoid shadowing the builtin bool
1081
+ nonzero_x = np.where(nonzero.any(axis=0))[0]
1082
+ nonzero_y = np.where(nonzero.any(axis=1))[0]
1083
+ if len(nonzero_x) != 0 and len(nonzero_y) != 0:
1084
+ bbox_x1 = nonzero_x[0]
1085
+ bbox_x2 = nonzero_x[-1]
1086
+ bbox_y1 = nonzero_y[0]
1087
+ bbox_y2 = nonzero_y[-1]
1088
+ return int(bbox_x1), int(bbox_y1), int(bbox_x2), int(bbox_y2)
1089
+ else:
1090
+ return 0, 0, 0, 0
1091
+
1092
+
1093
+ def get_minimum_area_bounding_rectangle2(mask, l, t, r, b):
1094
+ nonzero = _to_2d(mask.astype("uint8")) != 0  # avoid shadowing the builtin bool
1095
+ nonzero_x = np.where(nonzero.any(axis=0))[0]
1096
+ nonzero_y = np.where(nonzero.any(axis=1))[0]
1097
+ try:
1098
+ new_l = nonzero_x[np.where(l < nonzero_x)][0]
1099
+ except Exception:
1100
+ new_l = l
1101
+ try:
1102
+ new_t = nonzero_y[np.where(t < nonzero_y)][0]
1103
+ except Exception:
1104
+ new_t = t
1105
+ try:
1106
+ new_r = nonzero_x[np.where(nonzero_x < r)][-1]
1107
+ except Exception:
1108
+ new_r = r
1109
+ try:
1110
+ new_b = nonzero_y[np.where(nonzero_y < b)][-1]
1111
+ except Exception:
1112
+ new_b = b
1113
+ return new_l, new_t, new_r, new_b
1114
+
1115
+
1116
+ def _downsample_image(img):
1117
+ ori_w, ori_h = _get_width_and_height(img)
1118
+ resized = _resize_image(img, w=ori_w // 2, h=ori_h // 2)
1119
+ return resized
1120
+
1121
+
1122
+ def _upsample_image(img):
1123
+ ori_w, ori_h = _get_width_and_height(img)
1124
+ resized = _resize_image(img, w=ori_w * 2, h=ori_h * 2)
1125
+ return resized
1126
+
1127
+
1128
+ def _get_pseudo_image(img, mask, invert=False):
1129
+ if invert:
1130
+ mask = _invert_image(mask)
1131
+ rows, cols = np.nonzero(_to_2d(mask))
1132
+ pseudo_outer = img[rows, cols, :].reshape((1, -1, 3))
1133
+ return pseudo_outer
1134
+
1135
+
1136
+ def resize_coordinates_and_image_to_fit_to_maximum_pixel_counts(
1137
+ bboxes, img, max_pixel_counts=1530
1138
+ ):
1139
+ w, h = _get_width_and_height(img)
1140
+ ratio = min(max_pixel_counts / h, max_pixel_counts / w)
1141
+ if ratio < 1:
1142
+ for col in ["xmin", "ymin", "xmax", "ymax"]:
1143
+ bboxes[col] = bboxes[col].apply(lambda x: int(x * ratio))
1144
+
1145
+ img = cv2.resize(
1146
+ src=img,
1147
+ dsize=(int(w * ratio), int(h * ratio)),
1148
+ interpolation=cv2.INTER_LANCZOS4,
1149
+ )
1150
+ return bboxes, img
1151
+
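+ # Example: a 3060x1530 image yields ratio 0.5, so box coordinates and pixels
+ # shrink together (hypothetical frame):
+ # boxes = pd.DataFrame([{"xmin": 100, "ymin": 40, "xmax": 300, "ymax": 80}])
+ # boxes, small = resize_coordinates_and_image_to_fit_to_maximum_pixel_counts(boxes, img)
+ # boxes.loc[0, "xmin"] == 50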
1152
+
1153
+ def get_image_patches_3(img, text_stroke_mask, mask1, mask2):
1154
+ splitting_mask = get_splitting_mask(text_stroke_mask)
1155
+
1156
+ _, _, stats, _ = cv2.connectedComponentsWithStats(
1157
+ image=_to_2d(splitting_mask), connectivity=4
1158
+ )
1159
+ ls_patches = list()
1160
+ for xmin, ymin, width, height, px_cnt in stats[1:, :]:
1161
+ xmax = xmin + width
1162
+ ymax = ymin + height
1163
+
1164
+ cropped_img = _crop_image(img=img, l=xmin, t=ymin, r=xmax, b=ymax)
1165
+ cropped_mask1 = _crop_image(img=mask1, l=xmin, t=ymin, r=xmax, b=ymax)
1166
+ cropped_mask2 = _crop_image(img=mask2, l=xmin, t=ymin, r=xmax, b=ymax)
1167
+ ls_patches.append(
1168
+ {
1169
+ "xmin": xmin,
1170
+ "ymin": ymin,
1171
+ "xmax": xmax,
1172
+ "ymax": ymax,
1173
+ "img": cropped_img,
1174
+ "mask1": cropped_mask1,
1175
+ "mask2": cropped_mask2,
1176
+ }
1177
+ )
1178
+ return ls_patches
1179
+
1180
+
1181
+ def get_image_patches_2(img, mask1, mask2):
1182
+ splitting_mask = get_splitting_mask(mask1)
1183
+
1184
+ _, _, stats, _ = cv2.connectedComponentsWithStats(
1185
+ image=_to_2d(splitting_mask), connectivity=4
1186
+ )
1187
+ ls_patches = list()
1188
+ for x1, y1, w, h, _ in stats[1:, :]:
1189
+ x2 = x1 + w
1190
+ y2 = y1 + h
1191
+
1192
+ cropped_img = _crop_image(img=img, l=x1, t=y1, r=x2, b=y2)
1193
+ cropped_mask1 = _crop_image(img=mask1, l=x1, t=y1, r=x2, b=y2)
1194
+ cropped_mask2 = _crop_image(img=mask2, l=x1, t=y1, r=x2, b=y2)
1195
+
1196
+ ls_patches.append(
1197
+ {
1198
+ "x1": x1,
1199
+ "y1": y1,
1200
+ "x2": x2,
1201
+ "y2": y2,
1202
+ "img": cropped_img,
1203
+ "mask1": cropped_mask1,
1204
+ "mask2": cropped_mask2,
1205
+ }
1206
+ )
1207
+ return ls_patches
1208
+
1209
+
1210
+ def get_splitting_mask(text_stroke_mask):
1211
+ splitting_mask = _dilate_mask(text_stroke_mask, kernel_size=200)
1212
+ return splitting_mask
1213
+
1214
+
1215
+ def enhance_sharpness(img):
1216
+ """Increase the sharpness of img. Three methods exist (sharpening filter, unsharp mask, PIL sharpening);
1217
+ of the three, PIL changes the original colors the least.
1218
+ Args:
1219
+ img (_np.ndarray_): input image
1220
+
1221
+ Returns:
1222
+ _np.ndarray_: resulting image
1223
+ """
1224
+ # sharpening_k = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
1225
+ # hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
1226
+ # sharpened_v = cv2.filter2D(hsv[..., 2], -1, sharpening_k)
1227
+ # hsv[..., 2] = sharpened_v
1228
+ # img_patch2 = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
1229
+
1230
+ # src_ycrcb = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
1231
+ # src_f = src_ycrcb[:, :, 0].astype(np.float32)
1232
+ # blr = cv2.GaussianBlur(src_f, (0, 0), 2.0)
1233
+ # src_ycrcb[:, :, 0] = np.clip(2. * src_f - blr, 0, 255).astype(np.uint8)
1234
+ # img_patch3 = cv2.cvtColor(src_ycrcb, cv2.COLOR_YCrCb2RGB)
1235
+
1236
+ pil_img = _to_pil(img)
1237
+ sharpness_img = ImageEnhance.Sharpness(pil_img).enhance(2)
1238
+ result_img = _to_array(sharpness_img)
1239
+
1240
+ return result_img
1241
+
1242
+
1243
+ def mask2point(mask):
1244
+ # mask (H,W,3) 0 or 255 -> (N,2)
1245
+ mask = _to_2d(mask)
1246
+ indices = np.argwhere(mask == 255)
1247
+ return indices
1248
+
1249
+
1250
+ def get_corner(corner_coords):
1251
+ # corner_coords (N,2) each point means (y,x)
1252
+ cy, cx = np.mean(corner_coords, axis=0)
1253
+ quadrant_1 = corner_coords[(corner_coords[:, 0] < cy) & (corner_coords[:, 1] >= cx)]
1254
+ rt = quadrant_1[:, 1].max(), quadrant_1[:, 0].min()
1255
+
1256
+ quadrant_2 = corner_coords[(corner_coords[:, 0] < cy) & (corner_coords[:, 1] < cx)]
1257
+ lt = quadrant_2[:, 1].min(), quadrant_2[:, 0].min()
1258
+
1259
+ quadrant_3 = corner_coords[(corner_coords[:, 0] >= cy) & (corner_coords[:, 1] < cx)]
1260
+ lb = quadrant_3[:, 1].min(), quadrant_3[:, 0].max()
1261
+
1262
+ quadrant_4 = corner_coords[
1263
+ (corner_coords[:, 0] >= cy) & (corner_coords[:, 1] >= cx)
1264
+ ]
1265
+ rb = quadrant_4[:, 1].max(), quadrant_4[:, 0].max()
1266
+
1267
+ return lt, rt, rb, lb
1268
+
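+ # Example: for corner candidates of an axis-aligned square (rows are (y, x)),
+ # the quadrant split around the centroid recovers the four (x, y) corners.
+ # pts = np.array([[0, 0], [0, 9], [9, 0], [9, 9]])
+ # get_corner(pts)  # -> (0, 0), (9, 0), (9, 9), (0, 9) as lt, rt, rb, lb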
1269
+
1270
+ def get_dst_mask(mask):
1271
+ mask = _to_2d(mask)
1272
+ dst = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
1273
+ # Normalize the distance values to the 0-255 range
1274
+ dist_transform_normalized = cv2.normalize(
1275
+ dst, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U
1276
+ )
1277
+ return _to_3d(dist_transform_normalized)
1278
+
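+ # Sketch: the distance transform peaks at stroke centers, so its normalized
+ # result can stand in for a region score map when no CRAFT output is available
+ # (this is how _split_mask uses it when region_score_map is None).
+ # centers = _to_2d(get_dst_mask(mask)) > 150  # assumed binary (0/255) mask input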
1279
+
1280
+ def unwarp(img, src, dst):
1281
+ h, w = img.shape[:2]
1282
+ # use cv2.getPerspectiveTransform() to get M, the perspective transform matrix (no inverse is needed here)
1283
+ M = cv2.getPerspectiveTransform(src, dst)
1284
+ # use cv2.warpPerspective() to warp your image to a top-down view
1285
+ warped = cv2.warpPerspective(img, M, (w, h), flags=cv2.INTER_LINEAR)
1286
+
1287
+ return warped, M
1288
+
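+ # Usage sketch: map a detected quadrilateral onto the full image rectangle.
+ # h, w = img.shape[:2]  # assumed RGB image
+ # src = np.float32([[10, 10], [200, 12], [205, 300], [8, 298]])  # lt, rt, rb, lb
+ # dst = np.float32([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]])
+ # flat, M = unwarp(img, src, dst)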
1289
+
1290
+ def perspective_correction(img, src=None, vis=False, method: PC_TYPE = PC_TYPE.HARRIS):
1291
+ # img (H,W,C) 0~255, src=[[ltx,lty],[rtx,rty],[rbx,rby],[lbx,lby]]
1292
+ if src is None:
1293
+ gray = _to_grayscale(img)
1294
+
1295
+ if not isinstance(method, PC_TYPE):
1296
+ raise ValueError(
1297
+ f"Invalid method: {method}. Expected one of {list(PC_TYPE)}."
1298
+ )
1299
+
1300
+ if method == PC_TYPE.HARRIS:
1301
+ corner = cv2.cornerHarris(gray, 5, 3, 0.04) # (H,W) value: corner score
1302
+ threshold = 0.005 * corner.max()
1303
+ corner_coords = np.argwhere(corner > threshold)
1304
+
1305
+ elif method == PC_TYPE.EDGES_CONTOURS:
1306
+ blurred = cv2.GaussianBlur(gray, (5, 5), 0)
1307
+ edges = cv2.Canny(blurred, 50, 150)
1308
+ contours, _ = cv2.findContours(
1309
+ edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
1310
+ )
1311
+ contour_points = []
1312
+ for cs in contours:
1313
+ contour_points.extend(cs)
1315
+ corner_coords = np.array(contour_points).reshape(-1, 2)[..., ::-1]
1316
+
1317
+ elif method == PC_TYPE.GFTT:
1318
+ corners = cv2.goodFeaturesToTrack(
1319
+ gray, 0, 0.01, 5, blockSize=3, useHarrisDetector=True, k=0.03
1320
+ )
1321
+ corner_coords = corners.reshape(corners.shape[0], 2)[..., ::-1]
1322
+
1323
+ elif method == PC_TYPE.FAST:
1324
+ th = 50
1325
+ fast = cv2.FastFeatureDetector_create(th)
1326
+ keypoints = fast.detect(gray)
1327
+ corner_coords = np.array([[kp.pt[1], kp.pt[0]] for kp in keypoints])
1328
+
1329
+ elif method == PC_TYPE.KAZE:
1330
+ # feature = cv2.SIFT_create()
1331
+ feature = cv2.KAZE_create()
1332
+
1333
+ keypoints = feature.detect(gray)
1334
+ corner_coords = np.array([[kp.pt[1], kp.pt[0]] for kp in keypoints])
1335
+
1336
+ if vis:
1337
+ view_img = img.copy()
1338
+ for corner in corner_coords:
1339
+ y, x = corner
1340
+ cv2.circle(view_img, (int(x), int(y)), 3, (255, 0, 0), 2)
1341
+ save_image(view_img, path="vis_corner.png")
1342
+
1343
+ lt, rt, rb, lb = get_corner(corner_coords)
1344
+
1345
+ src = np.float32([lt, rt, rb, lb])
1346
+
1347
+ dst = np.float32(
1348
+ [
1349
+ (0, 0),
1350
+ (img.shape[1] - 1, 0),
1351
+ (img.shape[1] - 1, img.shape[0] - 1),
1352
+ (0, img.shape[0] - 1),
1353
+ ]
1354
+ )
1355
+
1356
+ result, M = unwarp(img, src, dst)
1357
+ save_image(result, path="cv_result.png")
1358
+ return result
1359
+
1360
+
1361
+ if __name__ == "__main__":
1362
+ image_url = "https://d2reotjpatzlok.cloudfront.net/qr-place/item/QR_20240726_2441_2_LZ1ZFCT38HN7PPCEZR8H.jpg"
1363
+ img, imgdata, format = load_image(image_url, with_byte=True, with_format=True)
1364
+ perspective_correction(img, vis=True)
rect_main.py ADDED
@@ -0,0 +1,173 @@
1
+ import importlib
2
+ import warnings
3
+ from collections import defaultdict
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from config import Config
11
+ from data_utils.image_utils import _to_2d
12
+
13
+ warnings.filterwarnings("ignore")
14
+
15
+ DocTr_Plus = importlib.import_module("models.DocTr-Plus.inference")
16
+ DocScanner = importlib.import_module("models.DocScanner.inference")
17
+
18
+ cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ mask_dict = defaultdict(int)
21
+
22
+
23
+ def load_geotrp_model(cuda, path=""):
24
+
25
+ _GeoTrP = DocTr_Plus.GeoTrP()
26
+ _GeoTrP = _GeoTrP.to(cuda)
27
+ DocTr_Plus.reload_model(_GeoTrP.GeoTr, path)
28
+ _GeoTrP.eval()
29
+
30
+ return _GeoTrP
31
+
32
+
33
+ def load_docscanner_model(cuda, path_l="", path_m=""):
34
+
35
+ net = DocScanner.Net().to(cuda)
36
+ DocScanner.reload_seg_model(net.msk, path_m)
37
+ DocScanner.reload_rec_model(net.bm, path_l)
38
+ net.eval()
39
+
40
+ return net
41
+
42
+
43
+ def preprocess_image(img, target_size=[288, 288]):
44
+ im_ori = img[:, :, :3] / 255.0
45
+ h_, w_, _ = im_ori.shape
46
+ # Resize once to the model input size (the intermediate 288x288 resize was redundant).
47
+
48
+ im = cv2.resize(im_ori, tuple(target_size))
49
+ im = im.transpose(2, 0, 1)
50
+ im = torch.from_numpy(im).float().unsqueeze(0)
51
+
52
+ return im_ori, im, h_, w_
53
+
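+ # Example shapes: an (H, W, 3) uint8 input becomes a (1, 3, 288, 288) float
+ # tensor for the nets, while im_ori keeps the original resolution for remapping.
+ # im_ori, im, h, w = preprocess_image(img)  # im.shape == torch.Size([1, 3, 288, 288])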
54
+
55
+ def geotrp_rec(img, model):
56
+ im_ori, im, h_, w_ = preprocess_image(img)
57
+
58
+ with torch.no_grad():
59
+ bm = model(im.to(cuda))  # .to(cuda) also works on CPU-only machines
60
+ bm = bm.cpu().numpy()[0]
61
+ bm0 = bm[0, :, :]
62
+ bm1 = bm[1, :, :]
63
+ bm0 = cv2.blur(bm0, (3, 3))
64
+ bm1 = cv2.blur(bm1, (3, 3))
65
+
66
+ img_geo = cv2.remap(im_ori, bm0, bm1, cv2.INTER_LINEAR) * 255
67
+ img_geo = cv2.resize(img_geo, (w_, h_))
68
+
69
+ return img_geo
70
+
71
+
72
+ def docscanner_get_mask(img, model):
73
+ _, im, h, w = preprocess_image(img)
74
+
75
+ with torch.no_grad():
76
+ _, msk = model(im.to(cuda))
77
+ msk = msk.cpu()
78
+
79
+ mask_np = (msk[0, 0].numpy() * 255).astype(np.uint8)
80
+ mask_resized = cv2.resize(mask_np, (w, h))
81
+
82
+ return mask_resized
83
+
84
+
85
+ def docscanner_rec_img(img, model):
86
+ im_ori, im, h, w = preprocess_image(img)
87
+
88
+ with torch.no_grad():
89
+ bm = model(im.to(cuda))
90
+ bm = bm.cpu()
91
+
92
+ # save rectified image
93
+ bm0 = cv2.resize(bm[0, 0].numpy(), (w, h)) # x flow
94
+ bm1 = cv2.resize(bm[0, 1].numpy(), (w, h)) # y flow
95
+ bm0 = cv2.blur(bm0, (3, 3))
96
+ bm1 = cv2.blur(bm1, (3, 3))
97
+ lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
98
+ out = F.grid_sample(
99
+ torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(),
100
+ lbl,
101
+ align_corners=True,
102
+ )
103
+ img = (((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1]).astype(np.uint8)
104
+
105
+ return img
106
+
107
+
108
+
109
+ def docscanner_rec(img, model):
110
+ im_ori = img[:, :, :3] / 255.0
111
+ h, w, _ = im_ori.shape
112
+ im = cv2.resize(im_ori, (288, 288))
113
+ im = im.transpose(2, 0, 1)
114
+ im = torch.from_numpy(im).float().unsqueeze(0)
115
+
116
+ with torch.no_grad():
117
+ bm, msk = model(im.to(cuda))
118
+ bm = bm.cpu()
119
+ msk = msk.cpu()
120
+
121
+ mask_np = (msk[0, 0].numpy() * 255).astype(np.uint8)
122
+ mask_resized = cv2.resize(mask_np, (w, h))
123
+ mask_img = mask_resized
124
+
125
+ # save rectified image
126
+ bm0 = cv2.resize(bm[0, 0].numpy(), (w, h)) # x flow
127
+ bm1 = cv2.resize(bm[0, 1].numpy(), (w, h)) # y flow
128
+ bm0 = cv2.blur(bm0, (3, 3))
129
+ bm1 = cv2.blur(bm1, (3, 3))
130
+ lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0) # h * w * 2
131
+ out = F.grid_sample(
132
+ torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(),
133
+ lbl,
134
+ align_corners=True,
135
+ )
136
+ img = (((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1]).astype(np.uint8)
137
+
138
+ return img, mask_img
139
+
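+ # Usage sketch: one forward pass returns both the rectified image and the
+ # document mask, equivalent to docscanner_rec_img plus docscanner_get_mask.
+ # rectified, doc_mask = docscanner_rec(img, docscanner_net)  # docscanner_net from load_docscanner_model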
140
+
141
+ # To be moved into data_utils later
142
+ def get_mask_white_area(mask):
143
+ """
144
+ Get the white area (non-zero pixels) of a mask.
145
+
146
+ Args:
147
+ mask (np.ndarray): Input mask image (2D or 3D array)
148
+
149
+ Returns:
150
+ np.ndarray: Array of (y, x) coordinates of white pixels
151
+ """
152
+ mask = _to_2d(mask)
153
+ white_pixels = np.argwhere(mask > 0)
154
+ return white_pixels
155
+
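+ # Example: a 2x2 mask with one white pixel yields its (y, x) coordinate.
+ # m = np.array([[0, 255], [0, 0]], dtype=np.uint8)
+ # get_mask_white_area(m)  # -> array([[0, 1]])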
156
+
157
+ def main():
158
+
159
+ config = Config()
160
+
161
+ img = cv2.imread("input/test.jpg")  # adjust this path before running
162
+
163
+ docscanner = load_docscanner_model(
164
+ cuda, path_l=config.get_rec_model_path, path_m=config.get_seg_model_path
165
+ )
166
+ doctr = load_geotrp_model(cuda, path=config.get_geotr_model_path)
167
+
168
+ mask = docscanner_get_mask(img, docscanner)
169
+ mask_dict["input/test.jpg"] = get_mask_white_area(mask)  # defaultdict(int) has no .add(); store per image key instead
170
+
171
+
172
+ if __name__ == "__main__":
173
+ main()
seg.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb79fdec55a5ed435dc74d8112aa9285d8213bae475022f711c709744fb19dd4
3
+ size 4715923