ydshieh (HF staff) committed
Commit 69254da
Parent: 0c3ab1a

Update README.md

Files changed (1):
  1. README.md (+133, -0)
README.md CHANGED
@@ -51,3 +51,136 @@ print(processed_text)
  print(entities)
  # `[('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]`
  ```
+
+ ## Draw the bounding boxes of the entities on the image
+
+ Once you have the `entities`, you can use the following helper function to draw their bounding boxes on the image:
+
+ ```python
+ import os
+ import numpy as np
+ import torch
+ from PIL import Image
+ import torchvision.transforms as T
+ import cv2
+ import requests
+
+
+ def is_overlapping(rect1, rect2):
+     # Two (x1, y1, x2, y2) rectangles overlap unless one lies entirely to one side of the other.
+     x1, y1, x2, y2 = rect1
+     x3, y3, x4, y4 = rect2
+     return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
+
+
+ def draw_entity_boxes_on_image(image, entities, show=False, save_path=None):
+     """Draw the bounding boxes of `entities` on `image`.
+
+     Args:
+         image: a `PIL.Image.Image`, a path to an image file, or a normalized image tensor.
+         entities: list of `(entity_name, (start, end), bboxes)` tuples, where each bbox is a
+             normalized `(x1, y1, x2, y2)` quadruple.
+         show: if `True`, display the annotated image.
+         save_path: if given, save the annotated image to this path.
+     """
+     if isinstance(image, Image.Image):
+         image_h = image.height
+         image_w = image.width
+         image = np.array(image)[:, :, [2, 1, 0]]
+     elif isinstance(image, str):
+         if os.path.exists(image):
+             pil_img = Image.open(image).convert("RGB")
+             image = np.array(pil_img)[:, :, [2, 1, 0]]
+             image_h = pil_img.height
+             image_w = pil_img.width
+         else:
+             raise ValueError(f"invalid image path, {image}")
+     elif isinstance(image, torch.Tensor):
+         # Undo the normalization applied by the image processor before converting back to PIL.
+         image_tensor = image.cpu()
+         reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
+         reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
+         image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
+         pil_img = T.ToPILImage()(image_tensor)
+         image_h = pil_img.height
+         image_w = pil_img.width
+         image = np.array(pil_img)[:, :, [2, 1, 0]]
+     else:
+         raise ValueError(f"invalid image format, {type(image)} for {image}")
+
+     if len(entities) == 0:
+         return image
+
+     new_image = image.copy()
+     previous_bboxes = []
+     # size of text
+     text_size = 2
+     # thickness of text
+     text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
+     box_line = 3
+     (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
+     base_height = int(text_height * 0.675)
+     text_offset_original = text_height - base_height
+     text_spaces = 3
+
+     for entity_name, (start, end), bboxes in entities:
+         for (x1_norm, y1_norm, x2_norm, y2_norm) in bboxes:
+             # Scale the normalized coordinates back to pixel coordinates.
+             orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)
+             # Draw the bounding box in a random color.
+             color = tuple(np.random.randint(0, 255, size=3).tolist())
+             new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
+
+             l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
+
+             x1 = orig_x1 - l_o
+             y1 = orig_y1 - l_o
+
+             if y1 < text_height + text_offset_original + 2 * text_spaces:
+                 y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
+                 x1 = orig_x1 + r_o
+
+             # Compute the label background, shifting it down while it overlaps a previously drawn label.
+             (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
+             text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
+
+             for prev_bbox in previous_bboxes:
+                 while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox):
+                     text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
+                     text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
+                     y1 += (text_height + text_offset_original + 2 * text_spaces)
+
+                     if text_bg_y2 >= image_h:
+                         text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
+                         text_bg_y2 = image_h
+                         y1 = image_h
+                         break
+
+             # Blend the label background with the underlying image (entity color on the left, white elsewhere).
+             alpha = 0.5
+             for i in range(text_bg_y1, text_bg_y2):
+                 for j in range(text_bg_x1, text_bg_x2):
+                     if i < image_h and j < image_w:
+                         if j < text_bg_x1 + 1.35 * c_width:
+                             # original color
+                             bg_color = color
+                         else:
+                             # white
+                             bg_color = [255, 255, 255]
+                         new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(np.uint8)
+
+             cv2.putText(
+                 new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
+             )
+             previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))
+
+     pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
+     if save_path:
+         pil_image.save(save_path)
+     if show:
+         pil_image.show()
+
+     return new_image
+
+
+ # From the previous code example
+ entities = [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]
+
+ # Draw the bounding boxes
+ draw_entity_boxes_on_image(image, entities, show=True)
+ ```
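
The final call in the added section assumes the `image` variable from the earlier inference example in this README. For reference, below is a minimal sketch of calling the helper on a freshly loaded image and saving the result; the file names are placeholders, and the `draw_entity_boxes_on_image` definition above is assumed to be in scope:

```python
from PIL import Image

# Hypothetical input file -- substitute the image you actually ran the model on.
# The helper also accepts a path string or a normalized image tensor.
image = Image.open("snowman.jpg").convert("RGB")

# `entities` as produced in the previous code example.
entities = [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]),
            ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]

# Draw the boxes and also write the annotated image to disk via `save_path`.
annotated = draw_entity_boxes_on_image(image, entities, show=False, save_path="annotated_snowman.jpg")

# The return value is a NumPy array in BGR (OpenCV) channel order; flip to RGB if needed.
annotated_rgb = annotated[:, :, ::-1]
```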