ragavsachdeva
commited on
Commit
•
b36f7c2
1
Parent(s):
7bb1e4c
Upload model
Browse files- config.json +475 -0
- configuration_magi.py +38 -0
- modelling_magi.py +486 -0
- processing_magi.py +274 -0
- pytorch_model.bin +3 -0
- utils.py +391 -0
config.json
ADDED
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "to_push",
|
3 |
+
"architectures": [
|
4 |
+
"MagiModel"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_magi.MagiConfig",
|
8 |
+
"AutoModel": "modelling_magi.MagiModel"
|
9 |
+
},
|
10 |
+
"crop_embedding_image_preprocessing_config": {
|
11 |
+
"_processor_class": null,
|
12 |
+
"do_normalize": true,
|
13 |
+
"do_rescale": true,
|
14 |
+
"do_resize": true,
|
15 |
+
"image_mean": [
|
16 |
+
0.485,
|
17 |
+
0.456,
|
18 |
+
0.406
|
19 |
+
],
|
20 |
+
"image_processor_type": "ViTImageProcessor",
|
21 |
+
"image_std": [
|
22 |
+
0.229,
|
23 |
+
0.224,
|
24 |
+
0.225
|
25 |
+
],
|
26 |
+
"resample": 2,
|
27 |
+
"rescale_factor": 0.00392156862745098,
|
28 |
+
"size": {
|
29 |
+
"height": 224,
|
30 |
+
"width": 224
|
31 |
+
}
|
32 |
+
},
|
33 |
+
"crop_embedding_model_config": {
|
34 |
+
"_name_or_path": "facebook/vit-mae-base",
|
35 |
+
"add_cross_attention": false,
|
36 |
+
"architectures": [
|
37 |
+
"ViTMAEForPreTraining"
|
38 |
+
],
|
39 |
+
"attention_probs_dropout_prob": 0.0,
|
40 |
+
"bad_words_ids": null,
|
41 |
+
"begin_suppress_tokens": null,
|
42 |
+
"bos_token_id": null,
|
43 |
+
"chunk_size_feed_forward": 0,
|
44 |
+
"cross_attention_hidden_size": null,
|
45 |
+
"decoder_hidden_size": 512,
|
46 |
+
"decoder_intermediate_size": 2048,
|
47 |
+
"decoder_num_attention_heads": 16,
|
48 |
+
"decoder_num_hidden_layers": 8,
|
49 |
+
"decoder_start_token_id": null,
|
50 |
+
"diversity_penalty": 0.0,
|
51 |
+
"do_sample": false,
|
52 |
+
"early_stopping": false,
|
53 |
+
"encoder_no_repeat_ngram_size": 0,
|
54 |
+
"eos_token_id": null,
|
55 |
+
"exponential_decay_length_penalty": null,
|
56 |
+
"finetuning_task": null,
|
57 |
+
"forced_bos_token_id": null,
|
58 |
+
"forced_eos_token_id": null,
|
59 |
+
"hidden_act": "gelu",
|
60 |
+
"hidden_dropout_prob": 0.0,
|
61 |
+
"hidden_size": 768,
|
62 |
+
"id2label": {
|
63 |
+
"0": "LABEL_0",
|
64 |
+
"1": "LABEL_1"
|
65 |
+
},
|
66 |
+
"image_size": 224,
|
67 |
+
"initializer_range": 0.02,
|
68 |
+
"intermediate_size": 3072,
|
69 |
+
"is_decoder": false,
|
70 |
+
"is_encoder_decoder": false,
|
71 |
+
"label2id": {
|
72 |
+
"LABEL_0": 0,
|
73 |
+
"LABEL_1": 1
|
74 |
+
},
|
75 |
+
"layer_norm_eps": 1e-12,
|
76 |
+
"length_penalty": 1.0,
|
77 |
+
"mask_ratio": 0.75,
|
78 |
+
"max_length": 20,
|
79 |
+
"min_length": 0,
|
80 |
+
"model_type": "",
|
81 |
+
"no_repeat_ngram_size": 0,
|
82 |
+
"norm_pix_loss": false,
|
83 |
+
"num_attention_heads": 12,
|
84 |
+
"num_beam_groups": 1,
|
85 |
+
"num_beams": 1,
|
86 |
+
"num_channels": 3,
|
87 |
+
"num_hidden_layers": 12,
|
88 |
+
"num_return_sequences": 1,
|
89 |
+
"output_attentions": false,
|
90 |
+
"output_hidden_states": false,
|
91 |
+
"output_scores": false,
|
92 |
+
"pad_token_id": null,
|
93 |
+
"patch_size": 16,
|
94 |
+
"prefix": null,
|
95 |
+
"problem_type": null,
|
96 |
+
"pruned_heads": {},
|
97 |
+
"qkv_bias": true,
|
98 |
+
"remove_invalid_values": false,
|
99 |
+
"repetition_penalty": 1.0,
|
100 |
+
"return_dict": true,
|
101 |
+
"return_dict_in_generate": false,
|
102 |
+
"sep_token_id": null,
|
103 |
+
"suppress_tokens": null,
|
104 |
+
"task_specific_params": null,
|
105 |
+
"temperature": 1.0,
|
106 |
+
"tf_legacy_loss": false,
|
107 |
+
"tie_encoder_decoder": false,
|
108 |
+
"tie_word_embeddings": true,
|
109 |
+
"tokenizer_class": null,
|
110 |
+
"top_k": 50,
|
111 |
+
"top_p": 1.0,
|
112 |
+
"torch_dtype": "float32",
|
113 |
+
"torchscript": false,
|
114 |
+
"typical_p": 1.0,
|
115 |
+
"use_bfloat16": false
|
116 |
+
},
|
117 |
+
"detection_image_preprocessing_config": {
|
118 |
+
"_processor_class": null,
|
119 |
+
"do_normalize": true,
|
120 |
+
"do_pad": true,
|
121 |
+
"do_rescale": true,
|
122 |
+
"do_resize": true,
|
123 |
+
"format": "coco_detection",
|
124 |
+
"image_mean": [
|
125 |
+
0.485,
|
126 |
+
0.456,
|
127 |
+
0.406
|
128 |
+
],
|
129 |
+
"image_processor_type": "ConditionalDetrImageProcessor",
|
130 |
+
"image_std": [
|
131 |
+
0.229,
|
132 |
+
0.224,
|
133 |
+
0.225
|
134 |
+
],
|
135 |
+
"resample": 2,
|
136 |
+
"rescale_factor": 0.00392156862745098,
|
137 |
+
"size": {
|
138 |
+
"longest_edge": 1333,
|
139 |
+
"shortest_edge": 800
|
140 |
+
}
|
141 |
+
},
|
142 |
+
"detection_model_config": {
|
143 |
+
"_name_or_path": "microsoft/conditional-detr-resnet-50",
|
144 |
+
"activation_dropout": 0.0,
|
145 |
+
"activation_function": "relu",
|
146 |
+
"add_cross_attention": false,
|
147 |
+
"architectures": [
|
148 |
+
"ConditionalDETRForObjectDetection"
|
149 |
+
],
|
150 |
+
"attention_dropout": 0.0,
|
151 |
+
"auxiliary_loss": false,
|
152 |
+
"backbone": "resnet50",
|
153 |
+
"backbone_config": null,
|
154 |
+
"bad_words_ids": null,
|
155 |
+
"bbox_cost": 5,
|
156 |
+
"bbox_loss_coefficient": 5,
|
157 |
+
"begin_suppress_tokens": null,
|
158 |
+
"bos_token_id": null,
|
159 |
+
"chunk_size_feed_forward": 0,
|
160 |
+
"class_cost": 2,
|
161 |
+
"cls_loss_coefficient": 2,
|
162 |
+
"cross_attention_hidden_size": null,
|
163 |
+
"d_model": 256,
|
164 |
+
"decoder_attention_heads": 8,
|
165 |
+
"decoder_ffn_dim": 2048,
|
166 |
+
"decoder_layerdrop": 0.0,
|
167 |
+
"decoder_layers": 6,
|
168 |
+
"decoder_start_token_id": null,
|
169 |
+
"dice_loss_coefficient": 1,
|
170 |
+
"dilation": false,
|
171 |
+
"diversity_penalty": 0.0,
|
172 |
+
"do_sample": false,
|
173 |
+
"dropout": 0.1,
|
174 |
+
"early_stopping": false,
|
175 |
+
"encoder_attention_heads": 8,
|
176 |
+
"encoder_ffn_dim": 2048,
|
177 |
+
"encoder_layerdrop": 0.0,
|
178 |
+
"encoder_layers": 6,
|
179 |
+
"encoder_no_repeat_ngram_size": 0,
|
180 |
+
"eos_token_id": null,
|
181 |
+
"exponential_decay_length_penalty": null,
|
182 |
+
"finetuning_task": null,
|
183 |
+
"focal_alpha": 0.25,
|
184 |
+
"forced_bos_token_id": null,
|
185 |
+
"forced_eos_token_id": null,
|
186 |
+
"giou_cost": 2,
|
187 |
+
"giou_loss_coefficient": 2,
|
188 |
+
"id2label": {
|
189 |
+
"0": "LABEL_0",
|
190 |
+
"1": "LABEL_1",
|
191 |
+
"2": "LABEL_2"
|
192 |
+
},
|
193 |
+
"init_std": 0.02,
|
194 |
+
"init_xavier_std": 1.0,
|
195 |
+
"is_decoder": false,
|
196 |
+
"is_encoder_decoder": true,
|
197 |
+
"label2id": {
|
198 |
+
"LABEL_0": 0,
|
199 |
+
"LABEL_1": 1,
|
200 |
+
"LABEL_2": 2
|
201 |
+
},
|
202 |
+
"length_penalty": 1.0,
|
203 |
+
"mask_loss_coefficient": 1,
|
204 |
+
"max_length": 20,
|
205 |
+
"max_position_embeddings": 1024,
|
206 |
+
"min_length": 0,
|
207 |
+
"model_type": "",
|
208 |
+
"no_repeat_ngram_size": 0,
|
209 |
+
"num_beam_groups": 1,
|
210 |
+
"num_beams": 1,
|
211 |
+
"num_channels": 3,
|
212 |
+
"num_hidden_layers": 6,
|
213 |
+
"num_queries": 305,
|
214 |
+
"num_return_sequences": 1,
|
215 |
+
"output_attentions": false,
|
216 |
+
"output_hidden_states": false,
|
217 |
+
"output_scores": false,
|
218 |
+
"pad_token_id": null,
|
219 |
+
"position_embedding_type": "sine",
|
220 |
+
"prefix": null,
|
221 |
+
"problem_type": null,
|
222 |
+
"pruned_heads": {},
|
223 |
+
"remove_invalid_values": false,
|
224 |
+
"repetition_penalty": 1.0,
|
225 |
+
"return_dict": true,
|
226 |
+
"return_dict_in_generate": false,
|
227 |
+
"scale_embedding": false,
|
228 |
+
"sep_token_id": null,
|
229 |
+
"suppress_tokens": null,
|
230 |
+
"task_specific_params": null,
|
231 |
+
"temperature": 1.0,
|
232 |
+
"tf_legacy_loss": false,
|
233 |
+
"tie_encoder_decoder": false,
|
234 |
+
"tie_word_embeddings": true,
|
235 |
+
"tokenizer_class": null,
|
236 |
+
"top_k": 50,
|
237 |
+
"top_p": 1.0,
|
238 |
+
"torch_dtype": "float32",
|
239 |
+
"torchscript": false,
|
240 |
+
"typical_p": 1.0,
|
241 |
+
"use_bfloat16": false,
|
242 |
+
"use_pretrained_backbone": true,
|
243 |
+
"use_timm_backbone": true
|
244 |
+
},
|
245 |
+
"disable_crop_embeddings": false,
|
246 |
+
"disable_detections": false,
|
247 |
+
"disable_ocr": false,
|
248 |
+
"model_type": "magi",
|
249 |
+
"ocr_model_config": {
|
250 |
+
"_name_or_path": "/work/rs/logs/manga_ocr/nt8rn2ul/",
|
251 |
+
"add_cross_attention": false,
|
252 |
+
"architectures": [
|
253 |
+
"VisionEncoderDecoderModel"
|
254 |
+
],
|
255 |
+
"bad_words_ids": null,
|
256 |
+
"begin_suppress_tokens": null,
|
257 |
+
"bos_token_id": null,
|
258 |
+
"chunk_size_feed_forward": 0,
|
259 |
+
"cross_attention_hidden_size": null,
|
260 |
+
"decoder": {
|
261 |
+
"_name_or_path": "",
|
262 |
+
"activation_dropout": 0.0,
|
263 |
+
"activation_function": "gelu",
|
264 |
+
"add_cross_attention": true,
|
265 |
+
"architectures": null,
|
266 |
+
"attention_dropout": 0.0,
|
267 |
+
"bad_words_ids": null,
|
268 |
+
"begin_suppress_tokens": null,
|
269 |
+
"bos_token_id": 0,
|
270 |
+
"chunk_size_feed_forward": 0,
|
271 |
+
"classifier_dropout": 0.0,
|
272 |
+
"cross_attention_hidden_size": 768,
|
273 |
+
"d_model": 1024,
|
274 |
+
"decoder_attention_heads": 16,
|
275 |
+
"decoder_ffn_dim": 4096,
|
276 |
+
"decoder_layerdrop": 0.0,
|
277 |
+
"decoder_layers": 12,
|
278 |
+
"decoder_start_token_id": 2,
|
279 |
+
"diversity_penalty": 0.0,
|
280 |
+
"do_sample": false,
|
281 |
+
"dropout": 0.1,
|
282 |
+
"early_stopping": false,
|
283 |
+
"encoder_no_repeat_ngram_size": 0,
|
284 |
+
"eos_token_id": 2,
|
285 |
+
"exponential_decay_length_penalty": null,
|
286 |
+
"finetuning_task": null,
|
287 |
+
"forced_bos_token_id": null,
|
288 |
+
"forced_eos_token_id": null,
|
289 |
+
"id2label": {
|
290 |
+
"0": "LABEL_0",
|
291 |
+
"1": "LABEL_1"
|
292 |
+
},
|
293 |
+
"init_std": 0.02,
|
294 |
+
"is_decoder": true,
|
295 |
+
"is_encoder_decoder": false,
|
296 |
+
"label2id": {
|
297 |
+
"LABEL_0": 0,
|
298 |
+
"LABEL_1": 1
|
299 |
+
},
|
300 |
+
"layernorm_embedding": true,
|
301 |
+
"length_penalty": 1.0,
|
302 |
+
"max_length": 20,
|
303 |
+
"max_position_embeddings": 512,
|
304 |
+
"min_length": 0,
|
305 |
+
"model_type": "trocr",
|
306 |
+
"no_repeat_ngram_size": 0,
|
307 |
+
"num_beam_groups": 1,
|
308 |
+
"num_beams": 1,
|
309 |
+
"num_return_sequences": 1,
|
310 |
+
"output_attentions": false,
|
311 |
+
"output_hidden_states": false,
|
312 |
+
"output_scores": false,
|
313 |
+
"pad_token_id": 1,
|
314 |
+
"prefix": null,
|
315 |
+
"problem_type": null,
|
316 |
+
"pruned_heads": {},
|
317 |
+
"remove_invalid_values": false,
|
318 |
+
"repetition_penalty": 1.0,
|
319 |
+
"return_dict": true,
|
320 |
+
"return_dict_in_generate": false,
|
321 |
+
"scale_embedding": false,
|
322 |
+
"sep_token_id": null,
|
323 |
+
"suppress_tokens": null,
|
324 |
+
"task_specific_params": null,
|
325 |
+
"temperature": 1.0,
|
326 |
+
"tf_legacy_loss": false,
|
327 |
+
"tie_encoder_decoder": false,
|
328 |
+
"tie_word_embeddings": true,
|
329 |
+
"tokenizer_class": null,
|
330 |
+
"top_k": 50,
|
331 |
+
"top_p": 1.0,
|
332 |
+
"torch_dtype": null,
|
333 |
+
"torchscript": false,
|
334 |
+
"typical_p": 1.0,
|
335 |
+
"use_bfloat16": false,
|
336 |
+
"use_cache": false,
|
337 |
+
"use_learned_position_embeddings": true,
|
338 |
+
"vocab_size": 50265
|
339 |
+
},
|
340 |
+
"decoder_start_token_id": 0,
|
341 |
+
"diversity_penalty": 0.0,
|
342 |
+
"do_sample": false,
|
343 |
+
"early_stopping": true,
|
344 |
+
"encoder": {
|
345 |
+
"_name_or_path": "",
|
346 |
+
"add_cross_attention": false,
|
347 |
+
"architectures": null,
|
348 |
+
"attention_probs_dropout_prob": 0.0,
|
349 |
+
"bad_words_ids": null,
|
350 |
+
"begin_suppress_tokens": null,
|
351 |
+
"bos_token_id": null,
|
352 |
+
"chunk_size_feed_forward": 0,
|
353 |
+
"cross_attention_hidden_size": null,
|
354 |
+
"decoder_start_token_id": null,
|
355 |
+
"diversity_penalty": 0.0,
|
356 |
+
"do_sample": false,
|
357 |
+
"early_stopping": false,
|
358 |
+
"encoder_no_repeat_ngram_size": 0,
|
359 |
+
"encoder_stride": 16,
|
360 |
+
"eos_token_id": null,
|
361 |
+
"exponential_decay_length_penalty": null,
|
362 |
+
"finetuning_task": null,
|
363 |
+
"forced_bos_token_id": null,
|
364 |
+
"forced_eos_token_id": null,
|
365 |
+
"hidden_act": "gelu",
|
366 |
+
"hidden_dropout_prob": 0.0,
|
367 |
+
"hidden_size": 768,
|
368 |
+
"id2label": {
|
369 |
+
"0": "LABEL_0",
|
370 |
+
"1": "LABEL_1"
|
371 |
+
},
|
372 |
+
"image_size": 384,
|
373 |
+
"initializer_range": 0.02,
|
374 |
+
"intermediate_size": 3072,
|
375 |
+
"is_decoder": false,
|
376 |
+
"is_encoder_decoder": false,
|
377 |
+
"label2id": {
|
378 |
+
"LABEL_0": 0,
|
379 |
+
"LABEL_1": 1
|
380 |
+
},
|
381 |
+
"layer_norm_eps": 1e-12,
|
382 |
+
"length_penalty": 1.0,
|
383 |
+
"max_length": 20,
|
384 |
+
"min_length": 0,
|
385 |
+
"model_type": "vit",
|
386 |
+
"no_repeat_ngram_size": 0,
|
387 |
+
"num_attention_heads": 12,
|
388 |
+
"num_beam_groups": 1,
|
389 |
+
"num_beams": 1,
|
390 |
+
"num_channels": 3,
|
391 |
+
"num_hidden_layers": 12,
|
392 |
+
"num_return_sequences": 1,
|
393 |
+
"output_attentions": false,
|
394 |
+
"output_hidden_states": false,
|
395 |
+
"output_scores": false,
|
396 |
+
"pad_token_id": null,
|
397 |
+
"patch_size": 16,
|
398 |
+
"prefix": null,
|
399 |
+
"problem_type": null,
|
400 |
+
"pruned_heads": {},
|
401 |
+
"qkv_bias": false,
|
402 |
+
"remove_invalid_values": false,
|
403 |
+
"repetition_penalty": 1.0,
|
404 |
+
"return_dict": true,
|
405 |
+
"return_dict_in_generate": false,
|
406 |
+
"sep_token_id": null,
|
407 |
+
"suppress_tokens": null,
|
408 |
+
"task_specific_params": null,
|
409 |
+
"temperature": 1.0,
|
410 |
+
"tf_legacy_loss": false,
|
411 |
+
"tie_encoder_decoder": false,
|
412 |
+
"tie_word_embeddings": true,
|
413 |
+
"tokenizer_class": null,
|
414 |
+
"top_k": 50,
|
415 |
+
"top_p": 1.0,
|
416 |
+
"torch_dtype": null,
|
417 |
+
"torchscript": false,
|
418 |
+
"typical_p": 1.0,
|
419 |
+
"use_bfloat16": false
|
420 |
+
},
|
421 |
+
"encoder_no_repeat_ngram_size": 0,
|
422 |
+
"eos_token_id": 2,
|
423 |
+
"exponential_decay_length_penalty": null,
|
424 |
+
"finetuning_task": null,
|
425 |
+
"forced_bos_token_id": null,
|
426 |
+
"forced_eos_token_id": null,
|
427 |
+
"id2label": {
|
428 |
+
"0": "LABEL_0",
|
429 |
+
"1": "LABEL_1"
|
430 |
+
},
|
431 |
+
"is_decoder": false,
|
432 |
+
"is_encoder_decoder": true,
|
433 |
+
"label2id": {
|
434 |
+
"LABEL_0": 0,
|
435 |
+
"LABEL_1": 1
|
436 |
+
},
|
437 |
+
"length_penalty": 2.0,
|
438 |
+
"max_length": 300,
|
439 |
+
"min_length": 0,
|
440 |
+
"model_type": "vision-encoder-decoder",
|
441 |
+
"no_repeat_ngram_size": 3,
|
442 |
+
"num_beam_groups": 1,
|
443 |
+
"num_beams": 4,
|
444 |
+
"num_return_sequences": 1,
|
445 |
+
"output_attentions": false,
|
446 |
+
"output_hidden_states": false,
|
447 |
+
"output_scores": false,
|
448 |
+
"pad_token_id": 1,
|
449 |
+
"prefix": null,
|
450 |
+
"problem_type": null,
|
451 |
+
"pruned_heads": {},
|
452 |
+
"remove_invalid_values": false,
|
453 |
+
"repetition_penalty": 1.0,
|
454 |
+
"return_dict": true,
|
455 |
+
"return_dict_in_generate": false,
|
456 |
+
"sep_token_id": null,
|
457 |
+
"suppress_tokens": null,
|
458 |
+
"task_specific_params": null,
|
459 |
+
"temperature": 1.0,
|
460 |
+
"tf_legacy_loss": false,
|
461 |
+
"tie_encoder_decoder": false,
|
462 |
+
"tie_word_embeddings": false,
|
463 |
+
"tokenizer_class": null,
|
464 |
+
"top_k": 50,
|
465 |
+
"top_p": 1.0,
|
466 |
+
"torch_dtype": "float32",
|
467 |
+
"torchscript": false,
|
468 |
+
"typical_p": 1.0,
|
469 |
+
"use_bfloat16": false,
|
470 |
+
"vocab_size": 50265
|
471 |
+
},
|
472 |
+
"ocr_pretrained_processor_path": "microsoft/trocr-base-printed",
|
473 |
+
"torch_dtype": "float32",
|
474 |
+
"transformers_version": "4.34.0.dev0"
|
475 |
+
}
|
configuration_magi.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, VisionEncoderDecoderConfig
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
|
5 |
+
class MagiConfig(PretrainedConfig):
|
6 |
+
model_type = "magi"
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
disable_ocr: bool = False,
|
11 |
+
disable_crop_embeddings: bool = False,
|
12 |
+
disable_detections: bool = False,
|
13 |
+
detection_model_config: dict = None,
|
14 |
+
ocr_model_config: dict = None,
|
15 |
+
crop_embedding_model_config: dict = None,
|
16 |
+
detection_image_preprocessing_config: dict = None,
|
17 |
+
ocr_pretrained_processor_path: str = None,
|
18 |
+
crop_embedding_image_preprocessing_config: dict = None,
|
19 |
+
**kwargs,
|
20 |
+
):
|
21 |
+
self.disable_ocr = disable_ocr
|
22 |
+
self.disable_crop_embeddings = disable_crop_embeddings
|
23 |
+
self.disable_detections = disable_detections
|
24 |
+
|
25 |
+
self.detection_model_config = None
|
26 |
+
self.ocr_model_config = None
|
27 |
+
self.crop_embedding_model_config = None
|
28 |
+
if detection_model_config is not None:
|
29 |
+
self.detection_model_config = PretrainedConfig.from_dict(detection_model_config)
|
30 |
+
if ocr_model_config is not None:
|
31 |
+
self.ocr_model_config = VisionEncoderDecoderConfig.from_dict(ocr_model_config)
|
32 |
+
if crop_embedding_model_config is not None:
|
33 |
+
self.crop_embedding_model_config = PretrainedConfig.from_dict(crop_embedding_model_config)
|
34 |
+
|
35 |
+
self.detection_image_preprocessing_config = detection_image_preprocessing_config
|
36 |
+
self.ocr_pretrained_processor_path = ocr_pretrained_processor_path
|
37 |
+
self.crop_embedding_image_preprocessing_config = crop_embedding_image_preprocessing_config
|
38 |
+
super().__init__(**kwargs)
|
modelling_magi.py
ADDED
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel
|
2 |
+
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
3 |
+
ConditionalDetrMLPPredictionHead,
|
4 |
+
ConditionalDetrModelOutput,
|
5 |
+
ConditionalDetrHungarianMatcher,
|
6 |
+
inverse_sigmoid,
|
7 |
+
)
|
8 |
+
from .configuration_magi import MagiConfig
|
9 |
+
from .processing_magi import MagiProcessor
|
10 |
+
from torch import nn
|
11 |
+
from typing import Optional, List
|
12 |
+
import torch
|
13 |
+
from einops import rearrange, repeat, einsum
|
14 |
+
from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
|
15 |
+
|
16 |
+
class MagiModel(PreTrainedModel):
|
17 |
+
config_class = MagiConfig
|
18 |
+
|
19 |
+
def __init__(self, config):
|
20 |
+
super().__init__(config)
|
21 |
+
self.config = config
|
22 |
+
self.processor = MagiProcessor(config)
|
23 |
+
if not config.disable_ocr:
|
24 |
+
self.ocr_model = VisionEncoderDecoderModel(config.ocr_model_config)
|
25 |
+
if not config.disable_crop_embeddings:
|
26 |
+
self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)
|
27 |
+
if not config.disable_detections:
|
28 |
+
self.num_non_obj_tokens = 5
|
29 |
+
self.detection_transformer = ConditionalDetrModel(config.detection_model_config)
|
30 |
+
self.bbox_predictor = ConditionalDetrMLPPredictionHead(
|
31 |
+
input_dim=config.detection_model_config.d_model,
|
32 |
+
hidden_dim=config.detection_model_config.d_model,
|
33 |
+
output_dim=4, num_layers=3
|
34 |
+
)
|
35 |
+
self.is_this_text_a_dialogue = ConditionalDetrMLPPredictionHead(
|
36 |
+
input_dim=config.detection_model_config.d_model,
|
37 |
+
hidden_dim=config.detection_model_config.d_model,
|
38 |
+
output_dim=1,
|
39 |
+
num_layers=3
|
40 |
+
)
|
41 |
+
self.character_character_matching_head = ConditionalDetrMLPPredictionHead(
|
42 |
+
input_dim = 3 * config.detection_model_config.d_model + (2 * config.crop_embedding_model_config.hidden_size if not config.disable_crop_embeddings else 0),
|
43 |
+
hidden_dim=config.detection_model_config.d_model,
|
44 |
+
output_dim=1, num_layers=3
|
45 |
+
)
|
46 |
+
self.text_character_matching_head = ConditionalDetrMLPPredictionHead(
|
47 |
+
input_dim = 3 * config.detection_model_config.d_model,
|
48 |
+
hidden_dim=config.detection_model_config.d_model,
|
49 |
+
output_dim=1, num_layers=3
|
50 |
+
)
|
51 |
+
self.class_labels_classifier = nn.Linear(
|
52 |
+
config.detection_model_config.d_model, config.detection_model_config.num_labels
|
53 |
+
)
|
54 |
+
self.matcher = ConditionalDetrHungarianMatcher(
|
55 |
+
class_cost=config.detection_model_config.class_cost,
|
56 |
+
bbox_cost=config.detection_model_config.bbox_cost,
|
57 |
+
giou_cost=config.detection_model_config.giou_cost
|
58 |
+
)
|
59 |
+
|
60 |
+
def move_to_device(self, input):
|
61 |
+
return move_to_device(input, self.device)
|
62 |
+
|
63 |
+
def predict_detections_and_associations(
|
64 |
+
self,
|
65 |
+
images,
|
66 |
+
move_to_device_fn=None,
|
67 |
+
character_detection_threshold=0.3,
|
68 |
+
panel_detection_threshold=0.2,
|
69 |
+
text_detection_threshold=0.25,
|
70 |
+
character_character_matching_threshold=0.7,
|
71 |
+
text_character_matching_threshold=0.4,
|
72 |
+
):
|
73 |
+
assert not self.config.disable_detections
|
74 |
+
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
75 |
+
|
76 |
+
inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images)
|
77 |
+
inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
|
78 |
+
|
79 |
+
detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
|
80 |
+
predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
|
81 |
+
|
82 |
+
# create callback fn
|
83 |
+
def get_character_character_matching_scores(batch_character_indices, batch_bboxes):
|
84 |
+
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
|
85 |
+
predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
|
86 |
+
crop_bboxes = [batch_bboxes[i][batch_character_indices[i]] for i in range(len(batch_character_indices))]
|
87 |
+
crop_embeddings_for_batch = self.predict_crop_embeddings(images, crop_bboxes, move_to_device_fn)
|
88 |
+
character_obj_tokens_for_batch = []
|
89 |
+
c2c_tokens_for_batch = []
|
90 |
+
for predicted_obj_tokens, predicted_c2c_tokens, character_indices in zip(predicted_obj_tokens_for_batch, predicted_c2c_tokens_for_batch, batch_character_indices):
|
91 |
+
character_obj_tokens_for_batch.append(predicted_obj_tokens[character_indices])
|
92 |
+
c2c_tokens_for_batch.append(predicted_c2c_tokens)
|
93 |
+
return self._get_character_character_affinity_matrices(
|
94 |
+
character_obj_tokens_for_batch=character_obj_tokens_for_batch,
|
95 |
+
crop_embeddings_for_batch=crop_embeddings_for_batch,
|
96 |
+
c2c_tokens_for_batch=c2c_tokens_for_batch,
|
97 |
+
apply_sigmoid=True,
|
98 |
+
)
|
99 |
+
|
100 |
+
# create callback fn
|
101 |
+
def get_text_character_matching_scores(batch_text_indices, batch_character_indices):
|
102 |
+
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
|
103 |
+
predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
|
104 |
+
text_obj_tokens_for_batch = []
|
105 |
+
character_obj_tokens_for_batch = []
|
106 |
+
t2c_tokens_for_batch = []
|
107 |
+
for predicted_obj_tokens, predicted_t2c_tokens, text_indices, character_indices in zip(predicted_obj_tokens_for_batch, predicted_t2c_tokens_for_batch, batch_text_indices, batch_character_indices):
|
108 |
+
text_obj_tokens_for_batch.append(predicted_obj_tokens[text_indices])
|
109 |
+
character_obj_tokens_for_batch.append(predicted_obj_tokens[character_indices])
|
110 |
+
t2c_tokens_for_batch.append(predicted_t2c_tokens)
|
111 |
+
return self._get_text_character_affinity_matrices(
|
112 |
+
character_obj_tokens_for_batch=character_obj_tokens_for_batch,
|
113 |
+
text_obj_tokens_for_this_batch=text_obj_tokens_for_batch,
|
114 |
+
t2c_tokens_for_batch=t2c_tokens_for_batch,
|
115 |
+
apply_sigmoid=True,
|
116 |
+
)
|
117 |
+
|
118 |
+
# create callback fn
|
119 |
+
def get_dialog_confidence_scores(batch_text_indices):
|
120 |
+
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
|
121 |
+
dialog_confidence = []
|
122 |
+
for predicted_obj_tokens, text_indices in zip(predicted_obj_tokens_for_batch, batch_text_indices):
|
123 |
+
confidence = self.is_this_text_a_dialogue(predicted_obj_tokens[text_indices]).sigmoid()
|
124 |
+
dialog_confidence.append(rearrange(confidence, "i 1 -> i"))
|
125 |
+
return dialog_confidence
|
126 |
+
|
127 |
+
return self.processor.postprocess_detections_and_associations(
|
128 |
+
predicted_bboxes=predicted_bboxes,
|
129 |
+
predicted_class_scores=predicted_class_scores,
|
130 |
+
original_image_sizes=torch.stack([torch.tensor(img.shape[:2]) for img in images], dim=0).to(predicted_bboxes.device),
|
131 |
+
get_character_character_matching_scores=get_character_character_matching_scores,
|
132 |
+
get_text_character_matching_scores=get_text_character_matching_scores,
|
133 |
+
get_dialog_confidence_scores=get_dialog_confidence_scores,
|
134 |
+
character_detection_threshold=character_detection_threshold,
|
135 |
+
panel_detection_threshold=panel_detection_threshold,
|
136 |
+
text_detection_threshold=text_detection_threshold,
|
137 |
+
character_character_matching_threshold=character_character_matching_threshold,
|
138 |
+
text_character_matching_threshold=text_character_matching_threshold,
|
139 |
+
)
|
140 |
+
|
141 |
+
def predict_crop_embeddings(self, images, crop_bboxes, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
|
142 |
+
if self.config.disable_crop_embeddings:
|
143 |
+
return None
|
144 |
+
|
145 |
+
assert isinstance(crop_bboxes, List), "please provide a list of bboxes for each image to get embeddings for"
|
146 |
+
|
147 |
+
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
148 |
+
|
149 |
+
# temporarily change the mask ratio from default to the one specified
|
150 |
+
old_mask_ratio = self.crop_embedding_model.embeddings.config.mask_ratio
|
151 |
+
self.crop_embedding_model.embeddings.config.mask_ratio = mask_ratio
|
152 |
+
|
153 |
+
crops_per_image = []
|
154 |
+
num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
|
155 |
+
for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
|
156 |
+
crops = self.processor.crop_image(image, bboxes)
|
157 |
+
assert len(crops) == num_crops
|
158 |
+
crops_per_image.extend(crops)
|
159 |
+
|
160 |
+
if len(crops_per_image) == 0:
|
161 |
+
return [[] for _ in crop_bboxes]
|
162 |
+
|
163 |
+
crops_per_image = self.processor.preprocess_inputs_for_crop_embeddings(crops_per_image)
|
164 |
+
crops_per_image = move_to_device_fn(crops_per_image)
|
165 |
+
|
166 |
+
# process the crops in batches to avoid OOM
|
167 |
+
embeddings = []
|
168 |
+
for i in range(0, len(crops_per_image), batch_size):
|
169 |
+
crops = crops_per_image[i:i+batch_size]
|
170 |
+
embeddings_per_batch = self.crop_embedding_model(crops).last_hidden_state[:, 0]
|
171 |
+
embeddings.append(embeddings_per_batch)
|
172 |
+
embeddings = torch.cat(embeddings, dim=0)
|
173 |
+
|
174 |
+
crop_embeddings_for_batch = []
|
175 |
+
for num_crops in num_crops_per_batch:
|
176 |
+
crop_embeddings_for_batch.append(embeddings[:num_crops])
|
177 |
+
embeddings = embeddings[num_crops:]
|
178 |
+
|
179 |
+
# restore the mask ratio to the default
|
180 |
+
self.crop_embedding_model.embeddings.config.mask_ratio = old_mask_ratio
|
181 |
+
|
182 |
+
return crop_embeddings_for_batch
|
183 |
+
|
184 |
+
def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32):
|
185 |
+
assert not self.config.disable_ocr
|
186 |
+
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
187 |
+
|
188 |
+
crops_per_image = []
|
189 |
+
num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
|
190 |
+
for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
|
191 |
+
crops = self.processor.crop_image(image, bboxes)
|
192 |
+
assert len(crops) == num_crops
|
193 |
+
crops_per_image.extend(crops)
|
194 |
+
|
195 |
+
if len(crops_per_image) == 0:
|
196 |
+
return [[] for _ in crop_bboxes]
|
197 |
+
|
198 |
+
crops_per_image = self.processor.preprocess_inputs_for_ocr(crops_per_image)
|
199 |
+
crops_per_image = move_to_device_fn(crops_per_image)
|
200 |
+
|
201 |
+
# process the crops in batches to avoid OOM
|
202 |
+
all_generated_texts = []
|
203 |
+
if use_tqdm:
|
204 |
+
from tqdm import tqdm
|
205 |
+
pbar = tqdm(range(0, len(crops_per_image), batch_size))
|
206 |
+
else:
|
207 |
+
pbar = range(0, len(crops_per_image), batch_size)
|
208 |
+
for i in pbar:
|
209 |
+
crops = crops_per_image[i:i+batch_size]
|
210 |
+
generated_ids = self.ocr_model.generate(crops)
|
211 |
+
generated_texts = self.processor.postprocess_ocr_tokens(generated_ids)
|
212 |
+
all_generated_texts.extend(generated_texts)
|
213 |
+
|
214 |
+
texts_for_images = []
|
215 |
+
for num_crops in num_crops_per_batch:
|
216 |
+
texts_for_images.append([x.replace("\n", "") for x in all_generated_texts[:num_crops]])
|
217 |
+
all_generated_texts = all_generated_texts[num_crops:]
|
218 |
+
|
219 |
+
return texts_for_images
|
220 |
+
|
221 |
+
def visualise_single_image_prediction(
|
222 |
+
self, image_as_np_array, predictions, filename=None
|
223 |
+
):
|
224 |
+
return visualise_single_image_prediction(image_as_np_array, predictions, filename)
|
225 |
+
|
226 |
+
def generate_transcript_for_single_image(
|
227 |
+
self, predictions, ocr_results, filename=None
|
228 |
+
):
|
229 |
+
character_clusters = predictions["character_cluster_labels"]
|
230 |
+
text_to_character = predictions["text_character_associations"]
|
231 |
+
text_to_character = {k: v for k, v in text_to_character}
|
232 |
+
transript = " ### Transcript ###\n"
|
233 |
+
for index, text in enumerate(ocr_results):
|
234 |
+
if index in text_to_character:
|
235 |
+
speaker = character_clusters[text_to_character[index]]
|
236 |
+
speaker = f"<{speaker}>"
|
237 |
+
else:
|
238 |
+
speaker = "<?>"
|
239 |
+
transript += f"{speaker}: {text}\n"
|
240 |
+
if filename is not None:
|
241 |
+
with open(filename, "w") as file:
|
242 |
+
file.write(transript)
|
243 |
+
return transript
|
244 |
+
|
245 |
+
def get_text_character_affinity_matrices_given_annotations(
|
246 |
+
self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
|
247 |
+
):
|
248 |
+
assert not self.config.disable_detections
|
249 |
+
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
250 |
+
|
251 |
+
inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
|
252 |
+
inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
|
253 |
+
processed_targets = inputs_to_detection_transformer.pop("labels")
|
254 |
+
|
255 |
+
detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
|
256 |
+
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
|
257 |
+
predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
|
258 |
+
|
259 |
+
predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
|
260 |
+
matching_dict = {
|
261 |
+
"logits": predicted_class_scores,
|
262 |
+
"pred_boxes": predicted_bboxes,
|
263 |
+
}
|
264 |
+
indices = self.matcher(matching_dict, processed_targets)
|
265 |
+
|
266 |
+
matched_char_obj_tokens_for_batch = []
|
267 |
+
matched_text_obj_tokens_for_batch = []
|
268 |
+
t2c_tokens_for_batch = []
|
269 |
+
|
270 |
+
text_bboxes_for_batch = []
|
271 |
+
character_bboxes_for_batch = []
|
272 |
+
|
273 |
+
for j, (pred_idx, tgt_idx) in enumerate(indices):
|
274 |
+
target_idx_to_pred_idx = {tgt.item(): pred.item() for pred, tgt in zip(pred_idx, tgt_idx)}
|
275 |
+
targets_for_this_image = processed_targets[j]
|
276 |
+
indices_of_text_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 1]
|
277 |
+
indices_of_char_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 0]
|
278 |
+
predicted_text_indices = [target_idx_to_pred_idx[i] for i in indices_of_text_boxes_in_annotation]
|
279 |
+
predicted_char_indices = [target_idx_to_pred_idx[i] for i in indices_of_char_boxes_in_annotation]
|
280 |
+
|
281 |
+
text_bboxes_for_batch.append(
|
282 |
+
[annotations[j]["bboxes_as_x1y1x2y2"][k] for k in indices_of_text_boxes_in_annotation]
|
283 |
+
)
|
284 |
+
character_bboxes_for_batch.append(
|
285 |
+
[annotations[j]["bboxes_as_x1y1x2y2"][k] for k in indices_of_char_boxes_in_annotation]
|
286 |
+
)
|
287 |
+
|
288 |
+
matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
|
289 |
+
matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
|
290 |
+
t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
|
291 |
+
|
292 |
+
text_character_affinity_matrices = self._get_text_character_affinity_matrices(
|
293 |
+
character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
|
294 |
+
text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
|
295 |
+
t2c_tokens_for_batch=t2c_tokens_for_batch,
|
296 |
+
apply_sigmoid=apply_sigmoid,
|
297 |
+
)
|
298 |
+
|
299 |
+
return {
|
300 |
+
"text_character_affinity_matrices": text_character_affinity_matrices,
|
301 |
+
"text_bboxes_for_batch": text_bboxes_for_batch,
|
302 |
+
"character_bboxes_for_batch": character_bboxes_for_batch,
|
303 |
+
}
|
304 |
+
|
305 |
+
def get_obj_embeddings_corresponding_to_given_annotations(
|
306 |
+
self, images, annotations, move_to_device_fn=None
|
307 |
+
):
|
308 |
+
assert not self.config.disable_detections
|
309 |
+
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
310 |
+
|
311 |
+
inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
|
312 |
+
inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
|
313 |
+
processed_targets = inputs_to_detection_transformer.pop("labels")
|
314 |
+
|
315 |
+
detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
|
316 |
+
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
|
317 |
+
predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
|
318 |
+
predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
|
319 |
+
|
320 |
+
predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
|
321 |
+
matching_dict = {
|
322 |
+
"logits": predicted_class_scores,
|
323 |
+
"pred_boxes": predicted_bboxes,
|
324 |
+
}
|
325 |
+
indices = self.matcher(matching_dict, processed_targets)
|
326 |
+
|
327 |
+
matched_char_obj_tokens_for_batch = []
|
328 |
+
matched_text_obj_tokens_for_batch = []
|
329 |
+
matched_panel_obj_tokens_for_batch = []
|
330 |
+
t2c_tokens_for_batch = []
|
331 |
+
c2c_tokens_for_batch = []
|
332 |
+
|
333 |
+
for j, (pred_idx, tgt_idx) in enumerate(indices):
|
334 |
+
target_idx_to_pred_idx = {tgt.item(): pred.item() for pred, tgt in zip(pred_idx, tgt_idx)}
|
335 |
+
targets_for_this_image = processed_targets[j]
|
336 |
+
indices_of_char_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 0]
|
337 |
+
indices_of_text_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 1]
|
338 |
+
indices_of_panel_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 2]
|
339 |
+
predicted_text_indices = [target_idx_to_pred_idx[i] for i in indices_of_text_boxes_in_annotation]
|
340 |
+
predicted_char_indices = [target_idx_to_pred_idx[i] for i in indices_of_char_boxes_in_annotation]
|
341 |
+
predicted_panel_indices = [target_idx_to_pred_idx[i] for i in indices_of_panel_boxes_in_annotation]
|
342 |
+
|
343 |
+
matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
|
344 |
+
matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
|
345 |
+
matched_panel_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_panel_indices])
|
346 |
+
t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
|
347 |
+
c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])
|
348 |
+
|
349 |
+
return {
|
350 |
+
"character": matched_char_obj_tokens_for_batch,
|
351 |
+
"text": matched_text_obj_tokens_for_batch,
|
352 |
+
"panel": matched_panel_obj_tokens_for_batch,
|
353 |
+
"t2c": t2c_tokens_for_batch,
|
354 |
+
"c2c": c2c_tokens_for_batch,
|
355 |
+
}
|
356 |
+
|
357 |
+
def sort_panels_and_text_bboxes_in_reading_order(
|
358 |
+
self,
|
359 |
+
batch_panel_bboxes,
|
360 |
+
batch_text_bboxes,
|
361 |
+
):
|
362 |
+
batch_sorted_panel_indices = []
|
363 |
+
batch_sorted_text_indices = []
|
364 |
+
for batch_index in range(len(batch_text_bboxes)):
|
365 |
+
panel_bboxes = batch_panel_bboxes[batch_index]
|
366 |
+
text_bboxes = batch_text_bboxes[batch_index]
|
367 |
+
sorted_panel_indices = sort_panels(panel_bboxes)
|
368 |
+
sorted_panels = [panel_bboxes[i] for i in sorted_panel_indices]
|
369 |
+
sorted_text_indices = sort_text_boxes_in_reading_order(text_bboxes, sorted_panels)
|
370 |
+
batch_sorted_panel_indices.append(sorted_panel_indices)
|
371 |
+
batch_sorted_text_indices.append(sorted_text_indices)
|
372 |
+
return batch_sorted_panel_indices, batch_sorted_text_indices
|
373 |
+
|
374 |
+
def _get_detection_transformer_output(
|
375 |
+
self,
|
376 |
+
pixel_values: torch.FloatTensor,
|
377 |
+
pixel_mask: Optional[torch.LongTensor] = None
|
378 |
+
):
|
379 |
+
if self.config.disable_detections:
|
380 |
+
raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")
|
381 |
+
return self.detection_transformer(
|
382 |
+
pixel_values=pixel_values,
|
383 |
+
pixel_mask=pixel_mask,
|
384 |
+
return_dict=True
|
385 |
+
)
|
386 |
+
|
387 |
+
def _get_predicted_obj_tokens(
|
388 |
+
self,
|
389 |
+
detection_transformer_output: ConditionalDetrModelOutput
|
390 |
+
):
|
391 |
+
return detection_transformer_output.last_hidden_state[:, :-self.num_non_obj_tokens]
|
392 |
+
|
393 |
+
def _get_predicted_c2c_tokens(
|
394 |
+
self,
|
395 |
+
detection_transformer_output: ConditionalDetrModelOutput
|
396 |
+
):
|
397 |
+
return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens]
|
398 |
+
|
399 |
+
def _get_predicted_t2c_tokens(
|
400 |
+
self,
|
401 |
+
detection_transformer_output: ConditionalDetrModelOutput
|
402 |
+
):
|
403 |
+
return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens+1]
|
404 |
+
|
405 |
+
def _get_predicted_bboxes_and_classes(
|
406 |
+
self,
|
407 |
+
detection_transformer_output: ConditionalDetrModelOutput,
|
408 |
+
):
|
409 |
+
if self.config.disable_detections:
|
410 |
+
raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")
|
411 |
+
|
412 |
+
obj = self._get_predicted_obj_tokens(detection_transformer_output)
|
413 |
+
|
414 |
+
predicted_class_scores = self.class_labels_classifier(obj)
|
415 |
+
reference = detection_transformer_output.reference_points[:-self.num_non_obj_tokens]
|
416 |
+
reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
|
417 |
+
predicted_boxes = self.bbox_predictor(obj)
|
418 |
+
predicted_boxes[..., :2] += reference_before_sigmoid
|
419 |
+
predicted_boxes = predicted_boxes.sigmoid()
|
420 |
+
|
421 |
+
return predicted_class_scores, predicted_boxes
|
422 |
+
|
423 |
+
def _get_character_character_affinity_matrices(
|
424 |
+
self,
|
425 |
+
character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
|
426 |
+
crop_embeddings_for_batch: List[torch.FloatTensor] = None,
|
427 |
+
c2c_tokens_for_batch: List[torch.FloatTensor] = None,
|
428 |
+
apply_sigmoid=True,
|
429 |
+
):
|
430 |
+
assert self.config.disable_detections or (character_obj_tokens_for_batch is not None and c2c_tokens_for_batch is not None)
|
431 |
+
assert self.config.disable_crop_embeddings or crop_embeddings_for_batch is not None
|
432 |
+
assert not self.config.disable_detections or not self.config.disable_crop_embeddings
|
433 |
+
|
434 |
+
if self.config.disable_detections:
|
435 |
+
affinity_matrices = []
|
436 |
+
for crop_embeddings in crop_embeddings_for_batch:
|
437 |
+
crop_embeddings = crop_embeddings / crop_embeddings.norm(dim=-1, keepdim=True)
|
438 |
+
affinity_matrix = einsum("i d, j d -> i j", affinity_matrix)
|
439 |
+
affinity_matrices.append(affinity_matrix)
|
440 |
+
return affinity_matrices
|
441 |
+
affinity_matrices = []
|
442 |
+
for batch_index, (character_obj_tokens, c2c) in enumerate(zip(character_obj_tokens_for_batch, c2c_tokens_for_batch)):
|
443 |
+
if character_obj_tokens.shape[0] == 0:
|
444 |
+
affinity_matrices.append(torch.zeros(0, 0).type_as(character_obj_tokens))
|
445 |
+
continue
|
446 |
+
if not self.config.disable_crop_embeddings:
|
447 |
+
crop_embeddings = crop_embeddings_for_batch[batch_index]
|
448 |
+
assert character_obj_tokens.shape[0] == crop_embeddings.shape[0]
|
449 |
+
character_obj_tokens = torch.cat([character_obj_tokens, crop_embeddings], dim=-1)
|
450 |
+
char_i = repeat(character_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
|
451 |
+
char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=character_obj_tokens.shape[0])
|
452 |
+
char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
|
453 |
+
c2c = repeat(c2c, "d -> repeat d", repeat = char_ij.shape[0])
|
454 |
+
char_ij_c2c = torch.cat([char_ij, c2c], dim=-1)
|
455 |
+
character_character_affinities = self.character_character_matching_head(char_ij_c2c)
|
456 |
+
character_character_affinities = rearrange(character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
|
457 |
+
if apply_sigmoid:
|
458 |
+
character_character_affinities = character_character_affinities.sigmoid()
|
459 |
+
affinity_matrices.append(character_character_affinities)
|
460 |
+
return affinity_matrices
|
461 |
+
|
462 |
+
def _get_text_character_affinity_matrices(
|
463 |
+
self,
|
464 |
+
character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
|
465 |
+
text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
|
466 |
+
t2c_tokens_for_batch: List[torch.FloatTensor] = None,
|
467 |
+
apply_sigmoid=True,
|
468 |
+
):
|
469 |
+
assert not self.config.disable_detections
|
470 |
+
assert character_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None and t2c_tokens_for_batch is not None
|
471 |
+
affinity_matrices = []
|
472 |
+
for character_obj_tokens, text_obj_tokens, t2c in zip(character_obj_tokens_for_batch, text_obj_tokens_for_this_batch, t2c_tokens_for_batch):
|
473 |
+
if character_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
|
474 |
+
affinity_matrices.append(torch.zeros(text_obj_tokens.shape[0], character_obj_tokens.shape[0]).type_as(character_obj_tokens))
|
475 |
+
continue
|
476 |
+
text_i = repeat(text_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
|
477 |
+
char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=text_obj_tokens.shape[0])
|
478 |
+
text_char = rearrange([text_i, char_j], "two i j d -> (i j) (two d)")
|
479 |
+
t2c = repeat(t2c, "d -> repeat d", repeat = text_char.shape[0])
|
480 |
+
text_char_t2c = torch.cat([text_char, t2c], dim=-1)
|
481 |
+
text_character_affinities = self.text_character_matching_head(text_char_t2c)
|
482 |
+
text_character_affinities = rearrange(text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
|
483 |
+
if apply_sigmoid:
|
484 |
+
text_character_affinities = text_character_affinities.sigmoid()
|
485 |
+
affinity_matrices.append(text_character_affinities)
|
486 |
+
return affinity_matrices
|
processing_magi.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor
|
2 |
+
from transformers.image_transforms import center_to_corners_format
|
3 |
+
import torch
|
4 |
+
from typing import List
|
5 |
+
from shapely.geometry import box
|
6 |
+
from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order, x1y1x2y2_to_xywh
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
class MagiProcessor():
|
10 |
+
def __init__(self, config):
|
11 |
+
self.config = config
|
12 |
+
self.detection_image_preprocessor = None
|
13 |
+
self.ocr_preprocessor = None
|
14 |
+
self.crop_embedding_image_preprocessor = None
|
15 |
+
if not config.disable_detections:
|
16 |
+
assert config.detection_image_preprocessing_config is not None
|
17 |
+
self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(config.detection_image_preprocessing_config)
|
18 |
+
if not config.disable_ocr:
|
19 |
+
assert config.ocr_pretrained_processor_path is not None
|
20 |
+
self.ocr_preprocessor = TrOCRProcessor.from_pretrained(config.ocr_pretrained_processor_path)
|
21 |
+
if not config.disable_crop_embeddings:
|
22 |
+
assert config.crop_embedding_image_preprocessing_config is not None
|
23 |
+
self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)
|
24 |
+
|
25 |
+
def preprocess_inputs_for_detection(self, images, annotations=None):
|
26 |
+
images = list(images)
|
27 |
+
assert isinstance(images[0], np.ndarray)
|
28 |
+
annotations = self._convert_annotations_to_coco_format(annotations)
|
29 |
+
inputs = self.detection_image_preprocessor(images, annotations=annotations, return_tensors="pt")
|
30 |
+
return inputs
|
31 |
+
|
32 |
+
def preprocess_inputs_for_ocr(self, images):
|
33 |
+
images = list(images)
|
34 |
+
assert isinstance(images[0], np.ndarray)
|
35 |
+
return self.ocr_preprocessor(images, return_tensors="pt").pixel_values
|
36 |
+
|
37 |
+
def preprocess_inputs_for_crop_embeddings(self, images):
|
38 |
+
images = list(images)
|
39 |
+
assert isinstance(images[0], np.ndarray)
|
40 |
+
return self.crop_embedding_image_preprocessor(images, return_tensors="pt").pixel_values
|
41 |
+
|
42 |
+
def postprocess_detections_and_associations(
|
43 |
+
self,
|
44 |
+
predicted_bboxes,
|
45 |
+
predicted_class_scores,
|
46 |
+
original_image_sizes,
|
47 |
+
get_character_character_matching_scores,
|
48 |
+
get_text_character_matching_scores,
|
49 |
+
get_dialog_confidence_scores,
|
50 |
+
character_detection_threshold=0.3,
|
51 |
+
panel_detection_threshold=0.2,
|
52 |
+
text_detection_threshold=0.25,
|
53 |
+
character_character_matching_threshold=0.7,
|
54 |
+
text_character_matching_threshold=0.4,
|
55 |
+
):
|
56 |
+
assert self.config.disable_detections is False
|
57 |
+
batch_scores, batch_labels = predicted_class_scores.max(-1)
|
58 |
+
batch_scores = batch_scores.sigmoid()
|
59 |
+
batch_labels = batch_labels.long()
|
60 |
+
batch_bboxes = center_to_corners_format(predicted_bboxes)
|
61 |
+
|
62 |
+
# scale the bboxes back to the original image size
|
63 |
+
if isinstance(original_image_sizes, List):
|
64 |
+
img_h = torch.Tensor([i[0] for i in original_image_sizes])
|
65 |
+
img_w = torch.Tensor([i[1] for i in original_image_sizes])
|
66 |
+
else:
|
67 |
+
img_h, img_w = original_image_sizes.unbind(1)
|
68 |
+
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(batch_bboxes.device)
|
69 |
+
batch_bboxes = batch_bboxes * scale_fct[:, None, :]
|
70 |
+
|
71 |
+
batch_panel_indices = self._get_indices_of_panels_to_keep(batch_scores, batch_labels, batch_bboxes, panel_detection_threshold)
|
72 |
+
batch_character_indices = self._get_indices_of_characters_to_keep(batch_scores, batch_labels, batch_bboxes, character_detection_threshold)
|
73 |
+
batch_text_indices = self._get_indices_of_texts_to_keep(batch_scores, batch_labels, batch_bboxes, text_detection_threshold)
|
74 |
+
|
75 |
+
batch_character_character_matching_scores = get_character_character_matching_scores(batch_character_indices, batch_bboxes)
|
76 |
+
batch_text_character_matching_scores = get_text_character_matching_scores(batch_text_indices, batch_character_indices)
|
77 |
+
batch_dialog_confidence_scores = get_dialog_confidence_scores(batch_text_indices)
|
78 |
+
|
79 |
+
# sort panels and texts in the reading order
|
80 |
+
for batch_index in range(len(batch_scores)):
|
81 |
+
panel_bboxes = batch_bboxes[batch_index][batch_panel_indices[batch_index]]
|
82 |
+
panel_scores = batch_scores[batch_index][batch_panel_indices[batch_index]]
|
83 |
+
text_bboxes = batch_bboxes[batch_index][batch_text_indices[batch_index]]
|
84 |
+
text_scores = batch_scores[batch_index][batch_text_indices[batch_index]]
|
85 |
+
|
86 |
+
sorted_panel_indices = sort_panels(panel_bboxes)
|
87 |
+
batch_bboxes[batch_index][batch_panel_indices[batch_index]] = panel_bboxes[sorted_panel_indices]
|
88 |
+
batch_scores[batch_index][batch_panel_indices[batch_index]] = panel_scores[sorted_panel_indices]
|
89 |
+
sorted_panels = batch_bboxes[batch_index][batch_panel_indices[batch_index]]
|
90 |
+
|
91 |
+
sorted_text_indices = sort_text_boxes_in_reading_order(text_bboxes, sorted_panels)
|
92 |
+
batch_bboxes[batch_index][batch_text_indices[batch_index]] = text_bboxes[sorted_text_indices]
|
93 |
+
batch_scores[batch_index][batch_text_indices[batch_index]] = text_scores[sorted_text_indices]
|
94 |
+
batch_text_character_matching_scores[batch_index] = batch_text_character_matching_scores[batch_index][sorted_text_indices]
|
95 |
+
batch_dialog_confidence_scores[batch_index] = batch_dialog_confidence_scores[batch_index][sorted_text_indices]
|
96 |
+
|
97 |
+
results = []
|
98 |
+
for batch_index in range(len(batch_scores)):
|
99 |
+
panel_bboxes = batch_bboxes[batch_index][batch_panel_indices[batch_index]]
|
100 |
+
panel_scores = batch_scores[batch_index][batch_panel_indices[batch_index]]
|
101 |
+
text_bboxes = batch_bboxes[batch_index][batch_text_indices[batch_index]]
|
102 |
+
text_scores = batch_scores[batch_index][batch_text_indices[batch_index]]
|
103 |
+
character_bboxes = batch_bboxes[batch_index][batch_character_indices[batch_index]]
|
104 |
+
character_scores = batch_scores[batch_index][batch_character_indices[batch_index]]
|
105 |
+
char_i, char_j = torch.where(batch_character_character_matching_scores[batch_index] > character_character_matching_threshold)
|
106 |
+
character_character_associations = torch.stack([char_i, char_j], dim=1)
|
107 |
+
text_boxes_to_match = batch_dialog_confidence_scores[batch_index] > text_character_matching_threshold
|
108 |
+
if 0 in batch_text_character_matching_scores[batch_index].shape:
|
109 |
+
text_character_associations = torch.zeros((0, 2), dtype=torch.long)
|
110 |
+
else:
|
111 |
+
most_likely_speaker_for_each_text = torch.argmax(batch_text_character_matching_scores[batch_index], dim=1)[text_boxes_to_match]
|
112 |
+
text_indices = torch.arange(len(text_bboxes)).type_as(most_likely_speaker_for_each_text)[text_boxes_to_match]
|
113 |
+
text_character_associations = torch.stack([text_indices, most_likely_speaker_for_each_text], dim=1)
|
114 |
+
|
115 |
+
character_ufds = UnionFind.from_adj_matrix(
|
116 |
+
batch_character_character_matching_scores[batch_index] > character_character_matching_threshold
|
117 |
+
)
|
118 |
+
results.append({
|
119 |
+
"panels": panel_bboxes.tolist(),
|
120 |
+
"panel_scores": panel_scores.tolist(),
|
121 |
+
"texts": text_bboxes.tolist(),
|
122 |
+
"text_scores": text_scores.tolist(),
|
123 |
+
"characters": character_bboxes.tolist(),
|
124 |
+
"character_scores": character_scores.tolist(),
|
125 |
+
"character_character_associations": character_character_associations.tolist(),
|
126 |
+
"text_character_associations": text_character_associations.tolist(),
|
127 |
+
"character_cluster_labels": character_ufds.get_labels_for_connected_components(),
|
128 |
+
"dialog_confidences": batch_dialog_confidence_scores[batch_index].tolist(),
|
129 |
+
})
|
130 |
+
return results
|
131 |
+
|
132 |
+
def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
|
133 |
+
return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)
|
134 |
+
|
135 |
+
def crop_image(self, image, bboxes):
|
136 |
+
crops_for_image = []
|
137 |
+
for bbox in bboxes:
|
138 |
+
x1, y1, x2, y2 = bbox
|
139 |
+
|
140 |
+
# fix the bounding box in case it is out of bounds or too small
|
141 |
+
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
142 |
+
x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2) # just incase
|
143 |
+
x1, y1 = max(0, x1), max(0, y1)
|
144 |
+
x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
|
145 |
+
x2, y2 = max(0, x2), max(0, y2)
|
146 |
+
x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)
|
147 |
+
if x2 - x1 < 10:
|
148 |
+
if image.shape[1] - x1 > 10:
|
149 |
+
x2 = x1 + 10
|
150 |
+
else:
|
151 |
+
x1 = x2 - 10
|
152 |
+
if y2 - y1 < 10:
|
153 |
+
if image.shape[0] - y1 > 10:
|
154 |
+
y2 = y1 + 10
|
155 |
+
else:
|
156 |
+
y1 = y2 - 10
|
157 |
+
|
158 |
+
crop = image[y1:y2, x1:x2]
|
159 |
+
crops_for_image.append(crop)
|
160 |
+
return crops_for_image
|
161 |
+
|
162 |
+
def _get_indices_of_characters_to_keep(self, batch_scores, batch_labels, batch_bboxes, character_detection_threshold):
|
163 |
+
indices_of_characters_to_keep = []
|
164 |
+
for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes):
|
165 |
+
indices = torch.where((labels == 0) & (scores > character_detection_threshold))[0]
|
166 |
+
indices_of_characters_to_keep.append(indices)
|
167 |
+
return indices_of_characters_to_keep
|
168 |
+
|
169 |
+
def _get_indices_of_panels_to_keep(self, batch_scores, batch_labels, batch_bboxes, panel_detection_threshold):
|
170 |
+
indices_of_panels_to_keep = []
|
171 |
+
for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
|
172 |
+
indices = torch.where(labels == 2)[0]
|
173 |
+
bboxes = bboxes[indices]
|
174 |
+
scores = scores[indices]
|
175 |
+
labels = labels[indices]
|
176 |
+
if len(indices) == 0:
|
177 |
+
indices_of_panels_to_keep.append([])
|
178 |
+
continue
|
179 |
+
scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
|
180 |
+
panels_to_keep = []
|
181 |
+
union_of_panels_so_far = box(0, 0, 0, 0)
|
182 |
+
for ps, pb, pl, pi in zip(scores, bboxes, labels, indices):
|
183 |
+
panel_polygon = box(pb[0], pb[1], pb[2], pb[3])
|
184 |
+
if ps < panel_detection_threshold:
|
185 |
+
continue
|
186 |
+
if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
|
187 |
+
continue
|
188 |
+
panels_to_keep.append((ps, pl, pb, pi))
|
189 |
+
union_of_panels_so_far = union_of_panels_so_far.union(panel_polygon)
|
190 |
+
indices_of_panels_to_keep.append([p[3].item() for p in panels_to_keep])
|
191 |
+
return indices_of_panels_to_keep
|
192 |
+
|
193 |
+
def _get_indices_of_texts_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
|
194 |
+
indices_of_texts_to_keep = []
|
195 |
+
for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
|
196 |
+
indices = torch.where((labels == 1) & (scores > text_detection_threshold))[0]
|
197 |
+
bboxes = bboxes[indices]
|
198 |
+
scores = scores[indices]
|
199 |
+
labels = labels[indices]
|
200 |
+
if len(indices) == 0:
|
201 |
+
indices_of_texts_to_keep.append([])
|
202 |
+
continue
|
203 |
+
scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
|
204 |
+
texts_to_keep = []
|
205 |
+
texts_to_keep_as_shapely_objects = []
|
206 |
+
for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
|
207 |
+
text_polygon = box(tb[0], tb[1], tb[2], tb[3])
|
208 |
+
should_append = True
|
209 |
+
for t in texts_to_keep_as_shapely_objects:
|
210 |
+
if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5:
|
211 |
+
should_append = False
|
212 |
+
break
|
213 |
+
if should_append:
|
214 |
+
texts_to_keep.append((ts, tl, tb, ti))
|
215 |
+
texts_to_keep_as_shapely_objects.append(text_polygon)
|
216 |
+
indices_of_texts_to_keep.append([t[3].item() for t in texts_to_keep])
|
217 |
+
return indices_of_texts_to_keep
|
218 |
+
|
219 |
+
def _convert_annotations_to_coco_format(self, annotations):
|
220 |
+
if annotations is None:
|
221 |
+
return None
|
222 |
+
self._verify_annotations_are_in_correct_format(annotations)
|
223 |
+
coco_annotations = []
|
224 |
+
for annotation in annotations:
|
225 |
+
coco_annotation = {
|
226 |
+
"image_id": annotation["image_id"],
|
227 |
+
"annotations": [],
|
228 |
+
}
|
229 |
+
for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"]):
|
230 |
+
coco_annotation["annotations"].append({
|
231 |
+
"bbox": x1y1x2y2_to_xywh(bbox),
|
232 |
+
"category_id": label,
|
233 |
+
"area": (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]),
|
234 |
+
})
|
235 |
+
coco_annotations.append(coco_annotation)
|
236 |
+
return coco_annotations
|
237 |
+
|
238 |
+
def _verify_annotations_are_in_correct_format(self, annotations):
|
239 |
+
error_msg = """
|
240 |
+
Annotations must be in the following format:
|
241 |
+
[
|
242 |
+
{
|
243 |
+
"image_id": 0,
|
244 |
+
"bboxes_as_x1y1x2y2": [[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]],
|
245 |
+
"labels": [0, 1, 2],
|
246 |
+
},
|
247 |
+
...
|
248 |
+
]
|
249 |
+
Labels: 0 for characters, 1 for text, 2 for panels.
|
250 |
+
"""
|
251 |
+
if annotations is None:
|
252 |
+
return
|
253 |
+
if not isinstance(annotations, List) and not isinstance(annotations, tuple):
|
254 |
+
raise ValueError(
|
255 |
+
f"{error_msg} Expected a List/Tuple, found {type(annotations)}."
|
256 |
+
)
|
257 |
+
if len(annotations) == 0:
|
258 |
+
return
|
259 |
+
if not isinstance(annotations[0], dict):
|
260 |
+
raise ValueError(
|
261 |
+
f"{error_msg} Expected a List[Dict], found {type(annotations[0])}."
|
262 |
+
)
|
263 |
+
if "image_id" not in annotations[0]:
|
264 |
+
raise ValueError(
|
265 |
+
f"{error_msg} Dict must contain 'image_id'."
|
266 |
+
)
|
267 |
+
if "bboxes_as_x1y1x2y2" not in annotations[0]:
|
268 |
+
raise ValueError(
|
269 |
+
f"{error_msg} Dict must contain 'bboxes_as_x1y1x2y2'."
|
270 |
+
)
|
271 |
+
if "labels" not in annotations[0]:
|
272 |
+
raise ValueError(
|
273 |
+
f"{error_msg} Dict must contain 'labels'."
|
274 |
+
)
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:219c2b80e741b1d02e92f22701a38358a5606d6460ad8b6335091e909b212011
|
3 |
+
size 2063428286
|
utils.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import matplotlib.patches as patches
|
6 |
+
from shapely.geometry import Point, box
|
7 |
+
import networkx as nx
|
8 |
+
from copy import deepcopy
|
9 |
+
from itertools import groupby
|
10 |
+
|
11 |
+
def move_to_device(inputs, device):
|
12 |
+
if hasattr(inputs, "keys"):
|
13 |
+
return {k: move_to_device(v, device) for k, v in inputs.items()}
|
14 |
+
elif isinstance(inputs, list):
|
15 |
+
return [move_to_device(v, device) for v in inputs]
|
16 |
+
elif isinstance(inputs, tuple):
|
17 |
+
return tuple([move_to_device(v, device) for v in inputs])
|
18 |
+
elif isinstance(inputs, np.ndarray):
|
19 |
+
return torch.from_numpy(inputs).to(device)
|
20 |
+
else:
|
21 |
+
return inputs.to(device)
|
22 |
+
|
23 |
+
class UnionFind:
|
24 |
+
def __init__(self, n):
|
25 |
+
self.parent = list(range(n))
|
26 |
+
self.size = [1] * n
|
27 |
+
self.num_components = n
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def from_adj_matrix(cls, adj_matrix):
|
31 |
+
ufds = cls(adj_matrix.shape[0])
|
32 |
+
for i in range(adj_matrix.shape[0]):
|
33 |
+
for j in range(adj_matrix.shape[1]):
|
34 |
+
if adj_matrix[i, j] > 0:
|
35 |
+
ufds.unite(i, j)
|
36 |
+
return ufds
|
37 |
+
|
38 |
+
@classmethod
|
39 |
+
def from_adj_list(cls, adj_list):
|
40 |
+
ufds = cls(len(adj_list))
|
41 |
+
for i in range(len(adj_list)):
|
42 |
+
for j in adj_list[i]:
|
43 |
+
ufds.unite(i, j)
|
44 |
+
return ufds
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def from_edge_list(cls, edge_list, num_nodes):
|
48 |
+
ufds = cls(num_nodes)
|
49 |
+
for edge in edge_list:
|
50 |
+
ufds.unite(edge[0], edge[1])
|
51 |
+
return ufds
|
52 |
+
|
53 |
+
def find(self, x):
|
54 |
+
if self.parent[x] == x:
|
55 |
+
return x
|
56 |
+
self.parent[x] = self.find(self.parent[x])
|
57 |
+
return self.parent[x]
|
58 |
+
|
59 |
+
def unite(self, x, y):
|
60 |
+
x = self.find(x)
|
61 |
+
y = self.find(y)
|
62 |
+
if x != y:
|
63 |
+
if self.size[x] < self.size[y]:
|
64 |
+
x, y = y, x
|
65 |
+
self.parent[y] = x
|
66 |
+
self.size[x] += self.size[y]
|
67 |
+
self.num_components -= 1
|
68 |
+
|
69 |
+
def get_components_of(self, x):
|
70 |
+
x = self.find(x)
|
71 |
+
return [i for i in range(len(self.parent)) if self.find(i) == x]
|
72 |
+
|
73 |
+
def are_connected(self, x, y):
|
74 |
+
return self.find(x) == self.find(y)
|
75 |
+
|
76 |
+
def get_size(self, x):
|
77 |
+
return self.size[self.find(x)]
|
78 |
+
|
79 |
+
def get_num_components(self):
|
80 |
+
return self.num_components
|
81 |
+
|
82 |
+
def get_labels_for_connected_components(self):
|
83 |
+
map_parent_to_label = {}
|
84 |
+
labels = []
|
85 |
+
for i in range(len(self.parent)):
|
86 |
+
parent = self.find(i)
|
87 |
+
if parent not in map_parent_to_label:
|
88 |
+
map_parent_to_label[parent] = len(map_parent_to_label)
|
89 |
+
labels.append(map_parent_to_label[parent])
|
90 |
+
return labels
|
91 |
+
|
92 |
+
def visualise_single_image_prediction(image_as_np_array, predictions, filename):
|
93 |
+
figure, subplot = plt.subplots(1, 1, figsize=(10, 10))
|
94 |
+
subplot.imshow(image_as_np_array)
|
95 |
+
plot_bboxes(subplot, predictions["panels"], color="green")
|
96 |
+
plot_bboxes(subplot, predictions["texts"], color="red", add_index=True)
|
97 |
+
plot_bboxes(subplot, predictions["characters"], color="blue")
|
98 |
+
|
99 |
+
COLOURS = [
|
100 |
+
"#b7ff51", # green
|
101 |
+
"#f50a8f", # pink
|
102 |
+
"#4b13b6", # purple
|
103 |
+
"#ddaa34", # orange
|
104 |
+
"#bea2a2", # brown
|
105 |
+
]
|
106 |
+
colour_index = 0
|
107 |
+
character_cluster_labels = predictions["character_cluster_labels"]
|
108 |
+
unique_label_sorted_by_frequency = sorted(list(set(character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
|
109 |
+
for label in unique_label_sorted_by_frequency:
|
110 |
+
root = None
|
111 |
+
others = []
|
112 |
+
for i in range(len(predictions["characters"])):
|
113 |
+
if character_cluster_labels[i] == label:
|
114 |
+
if root is None:
|
115 |
+
root = i
|
116 |
+
else:
|
117 |
+
others.append(i)
|
118 |
+
if colour_index >= len(COLOURS):
|
119 |
+
random_colour = COLOURS[0]
|
120 |
+
while random_colour in COLOURS:
|
121 |
+
random_colour = "#" + "".join([random.choice("0123456789ABCDEF") for j in range(6)])
|
122 |
+
else:
|
123 |
+
random_colour = COLOURS[colour_index]
|
124 |
+
colour_index += 1
|
125 |
+
bbox_i = predictions["characters"][root]
|
126 |
+
x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
|
127 |
+
y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
|
128 |
+
subplot.plot([x1], [y1], color=random_colour, marker="o", markersize=5)
|
129 |
+
for j in others:
|
130 |
+
# draw line from centre of bbox i to centre of bbox j
|
131 |
+
bbox_j = predictions["characters"][j]
|
132 |
+
x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
|
133 |
+
y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
|
134 |
+
x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
|
135 |
+
y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
|
136 |
+
subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
|
137 |
+
subplot.plot([x2], [y2], color=random_colour, marker="o", markersize=5)
|
138 |
+
|
139 |
+
for (i, j) in predictions["text_character_associations"]:
|
140 |
+
score = predictions["dialog_confidences"][i]
|
141 |
+
bbox_i = predictions["texts"][i]
|
142 |
+
bbox_j = predictions["characters"][j]
|
143 |
+
x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
|
144 |
+
y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
|
145 |
+
x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
|
146 |
+
y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
|
147 |
+
subplot.plot([x1, x2], [y1, y2], color="red", linewidth=2, linestyle="dashed", alpha=score)
|
148 |
+
|
149 |
+
subplot.axis("off")
|
150 |
+
if filename is not None:
|
151 |
+
plt.savefig(filename, bbox_inches="tight", pad_inches=0)
|
152 |
+
|
153 |
+
figure.canvas.draw()
|
154 |
+
image = np.array(figure.canvas.renderer._renderer)
|
155 |
+
plt.close()
|
156 |
+
return image
|
157 |
+
|
158 |
+
def plot_bboxes(subplot, bboxes, color="red", add_index=False):
|
159 |
+
for id, bbox in enumerate(bboxes):
|
160 |
+
w = bbox[2] - bbox[0]
|
161 |
+
h = bbox[3] - bbox[1]
|
162 |
+
rect = patches.Rectangle(
|
163 |
+
bbox[:2], w, h, linewidth=1, edgecolor=color, facecolor="none", linestyle="solid"
|
164 |
+
)
|
165 |
+
subplot.add_patch(rect)
|
166 |
+
if add_index:
|
167 |
+
cx, cy = bbox[0] + w / 2, bbox[1] + h / 2
|
168 |
+
subplot.text(cx, cy, str(id), color=color, fontsize=10, ha="center", va="center")
|
169 |
+
|
170 |
+
def sort_panels(rects):
|
171 |
+
before_rects = convert_to_list_of_lists(rects)
|
172 |
+
# slightly erode all rectangles initially to account for imperfect detections
|
173 |
+
rects = [erode_rectangle(rect, 0.05) for rect in before_rects]
|
174 |
+
G = nx.DiGraph()
|
175 |
+
G.add_nodes_from(range(len(rects)))
|
176 |
+
for i in range(len(rects)):
|
177 |
+
for j in range(len(rects)):
|
178 |
+
if i == j:
|
179 |
+
continue
|
180 |
+
if is_there_a_directed_edge(i, j, rects):
|
181 |
+
G.add_edge(i, j, weight=get_distance(rects[i], rects[j]))
|
182 |
+
else:
|
183 |
+
G.add_edge(j, i, weight=get_distance(rects[i], rects[j]))
|
184 |
+
while True:
|
185 |
+
cycles = sorted(nx.simple_cycles(G))
|
186 |
+
cycles = [cycle for cycle in cycles if len(cycle) > 1]
|
187 |
+
if len(cycles) == 0:
|
188 |
+
break
|
189 |
+
cycle = cycles[0]
|
190 |
+
edges = [e for e in zip(cycle, cycle[1:] + cycle[:1])]
|
191 |
+
max_cyclic_edge = max(edges, key=lambda x: G.edges[x]["weight"])
|
192 |
+
G.remove_edge(*max_cyclic_edge)
|
193 |
+
return list(nx.topological_sort(G))
|
194 |
+
|
195 |
+
def is_strictly_above(rectA, rectB):
|
196 |
+
x1A, y1A, x2A, y2A = rectA
|
197 |
+
x1B, y1B, x2B, y2B = rectB
|
198 |
+
return y2A < y1B
|
199 |
+
|
200 |
+
def is_strictly_below(rectA, rectB):
|
201 |
+
x1A, y1A, x2A, y2A = rectA
|
202 |
+
x1B, y1B, x2B, y2B = rectB
|
203 |
+
return y2B < y1A
|
204 |
+
|
205 |
+
def is_strictly_left_of(rectA, rectB):
|
206 |
+
x1A, y1A, x2A, y2A = rectA
|
207 |
+
x1B, y1B, x2B, y2B = rectB
|
208 |
+
return x2A < x1B
|
209 |
+
|
210 |
+
def is_strictly_right_of(rectA, rectB):
|
211 |
+
x1A, y1A, x2A, y2A = rectA
|
212 |
+
x1B, y1B, x2B, y2B = rectB
|
213 |
+
return x2B < x1A
|
214 |
+
|
215 |
+
def intersects(rectA, rectB):
|
216 |
+
return box(*rectA).intersects(box(*rectB))
|
217 |
+
|
218 |
+
def is_there_a_directed_edge(a, b, rects):
|
219 |
+
rectA = rects[a]
|
220 |
+
rectB = rects[b]
|
221 |
+
centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2, rectA[1] + (rectA[3] - rectA[1]) / 2]
|
222 |
+
centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2, rectB[1] + (rectB[3] - rectB[1]) / 2]
|
223 |
+
if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
|
224 |
+
return box(*rectA).area > (box(*rectB)).area
|
225 |
+
copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
|
226 |
+
copy_B = [rectB[0], rectB[1], rectB[2], rectB[3]]
|
227 |
+
while True:
|
228 |
+
if is_strictly_above(copy_A, copy_B) and not is_strictly_left_of(copy_A, copy_B):
|
229 |
+
return 1
|
230 |
+
if is_strictly_above(copy_B, copy_A) and not is_strictly_left_of(copy_B, copy_A):
|
231 |
+
return 0
|
232 |
+
if is_strictly_right_of(copy_A, copy_B) and not is_strictly_below(copy_A, copy_B):
|
233 |
+
return 1
|
234 |
+
if is_strictly_right_of(copy_B, copy_A) and not is_strictly_below(copy_B, copy_A):
|
235 |
+
return 0
|
236 |
+
if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
|
237 |
+
return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
|
238 |
+
if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
|
239 |
+
return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
|
240 |
+
# otherwise they intersect
|
241 |
+
copy_A = erode_rectangle(copy_A, 0.05)
|
242 |
+
copy_B = erode_rectangle(copy_B, 0.05)
|
243 |
+
|
244 |
+
def get_distance(rectA, rectB):
|
245 |
+
return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))
|
246 |
+
|
247 |
+
def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
|
248 |
+
rects = deepcopy(rects)
|
249 |
+
while True:
|
250 |
+
xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
|
251 |
+
rect_index = [i for i in range(len(rects)) if intersects(rects[i], [xmin, ymin, xmax, ymax])]
|
252 |
+
rects_copy = [rect for rect in rects if intersects(rect, [xmin, ymin, xmax, ymax])]
|
253 |
+
|
254 |
+
# try to split the panels using a "horizontal" lines
|
255 |
+
overlapping_y_ranges = merge_overlapping_ranges([(y1, y2) for x1, y1, x2, y2 in rects_copy])
|
256 |
+
panel_index_to_split = {}
|
257 |
+
for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
|
258 |
+
for i, index in enumerate(rect_index):
|
259 |
+
if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
|
260 |
+
panel_index_to_split[index] = split_index
|
261 |
+
|
262 |
+
if panel_index_to_split[a] != panel_index_to_split[b]:
|
263 |
+
return panel_index_to_split[a] < panel_index_to_split[b]
|
264 |
+
|
265 |
+
# try to split the panels using a "vertical" lines
|
266 |
+
overlapping_x_ranges = merge_overlapping_ranges([(x1, x2) for x1, y1, x2, y2 in rects_copy])
|
267 |
+
panel_index_to_split = {}
|
268 |
+
for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
|
269 |
+
for i, index in enumerate(rect_index):
|
270 |
+
if x1 <= rects_copy[i][0] <= rects_copy[i][2] <= x2:
|
271 |
+
panel_index_to_split[index] = split_index
|
272 |
+
if panel_index_to_split[a] != panel_index_to_split[b]:
|
273 |
+
return panel_index_to_split[a] < panel_index_to_split[b]
|
274 |
+
|
275 |
+
# otherwise, erode the rectangles and try again
|
276 |
+
rects = [erode_rectangle(rect, 0.05) for rect in rects]
|
277 |
+
|
278 |
+
def erode_rectangle(bbox, erosion_factor):
|
279 |
+
x1, y1, x2, y2 = bbox
|
280 |
+
w, h = x2 - x1, y2 - y1
|
281 |
+
cx, cy = x1 + w / 2, y1 + h / 2
|
282 |
+
if w < h:
|
283 |
+
aspect_ratio = w / h
|
284 |
+
erosion_factor_width = erosion_factor * aspect_ratio
|
285 |
+
erosion_factor_height = erosion_factor
|
286 |
+
else:
|
287 |
+
aspect_ratio = h / w
|
288 |
+
erosion_factor_width = erosion_factor
|
289 |
+
erosion_factor_height = erosion_factor * aspect_ratio
|
290 |
+
w = w - w * erosion_factor_width
|
291 |
+
h = h - h * erosion_factor_height
|
292 |
+
x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
|
293 |
+
return [x1, y1, x2, y2]
|
294 |
+
|
295 |
+
def merge_overlapping_ranges(ranges):
|
296 |
+
"""
|
297 |
+
ranges: list of tuples (x1, x2)
|
298 |
+
"""
|
299 |
+
if len(ranges) == 0:
|
300 |
+
return []
|
301 |
+
ranges = sorted(ranges, key=lambda x: x[0])
|
302 |
+
merged_ranges = []
|
303 |
+
for i, r in enumerate(ranges):
|
304 |
+
if i == 0:
|
305 |
+
prev_x1, prev_x2 = r
|
306 |
+
continue
|
307 |
+
x1, x2 = r
|
308 |
+
if x1 > prev_x2:
|
309 |
+
merged_ranges.append((prev_x1, prev_x2))
|
310 |
+
prev_x1, prev_x2 = x1, x2
|
311 |
+
else:
|
312 |
+
prev_x2 = max(prev_x2, x2)
|
313 |
+
merged_ranges.append((prev_x1, prev_x2))
|
314 |
+
return merged_ranges
|
315 |
+
|
316 |
+
def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
|
317 |
+
text_bboxes = convert_to_list_of_lists(text_bboxes)
|
318 |
+
sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
|
319 |
+
|
320 |
+
if len(text_bboxes) == 0:
|
321 |
+
return []
|
322 |
+
|
323 |
+
def indices_of_same_elements(nums):
|
324 |
+
groups = groupby(range(len(nums)), key=lambda i: nums[i])
|
325 |
+
return [list(indices) for _, indices in groups]
|
326 |
+
|
327 |
+
panel_id_for_text = get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes)
|
328 |
+
indices_of_texts = list(range(len(text_bboxes)))
|
329 |
+
indices_of_texts, panel_id_for_text = zip(*sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
|
330 |
+
indices_of_texts = list(indices_of_texts)
|
331 |
+
grouped_indices = indices_of_same_elements(panel_id_for_text)
|
332 |
+
for group in grouped_indices:
|
333 |
+
subset_of_text_indices = [indices_of_texts[i] for i in group]
|
334 |
+
text_bboxes_of_subset = [text_bboxes[i] for i in subset_of_text_indices]
|
335 |
+
sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
|
336 |
+
indices_of_texts[group[0] : group[-1] + 1] = [subset_of_text_indices[i] for i in sorted_subset_indices]
|
337 |
+
return indices_of_texts
|
338 |
+
|
339 |
+
def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
|
340 |
+
text_to_panel_mapping = []
|
341 |
+
for text_bbox in text_bboxes:
|
342 |
+
shapely_text_polygon = box(*text_bbox)
|
343 |
+
all_intersections = []
|
344 |
+
all_distances = []
|
345 |
+
if len(sorted_panel_bboxes) == 0:
|
346 |
+
text_to_panel_mapping.append(-1)
|
347 |
+
continue
|
348 |
+
for j, annotation in enumerate(sorted_panel_bboxes):
|
349 |
+
shapely_annotation_polygon = box(*annotation)
|
350 |
+
if shapely_text_polygon.intersects(shapely_annotation_polygon):
|
351 |
+
all_intersections.append((shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
|
352 |
+
all_distances.append((shapely_text_polygon.distance(shapely_annotation_polygon), j))
|
353 |
+
if len(all_intersections) == 0:
|
354 |
+
text_to_panel_mapping.append(min(all_distances, key=lambda x: x[0])[1])
|
355 |
+
else:
|
356 |
+
text_to_panel_mapping.append(max(all_intersections, key=lambda x: x[0])[1])
|
357 |
+
return text_to_panel_mapping
|
358 |
+
|
359 |
+
def sort_texts_within_panel(rects):
|
360 |
+
smallest_y = float("inf")
|
361 |
+
greatest_x = float("-inf")
|
362 |
+
for i, rect in enumerate(rects):
|
363 |
+
x1, y1, x2, y2 = rect
|
364 |
+
smallest_y = min(smallest_y, y1)
|
365 |
+
greatest_x = max(greatest_x, x2)
|
366 |
+
|
367 |
+
reference_point = Point(greatest_x, smallest_y)
|
368 |
+
|
369 |
+
polygons_and_index = []
|
370 |
+
for i, rect in enumerate(rects):
|
371 |
+
x1, y1, x2, y2 = rect
|
372 |
+
polygons_and_index.append((box(x1,y1,x2,y2), i))
|
373 |
+
# sort points by closest to reference point
|
374 |
+
polygons_and_index = sorted(polygons_and_index, key=lambda x: reference_point.distance(x[0]))
|
375 |
+
indices = [x[1] for x in polygons_and_index]
|
376 |
+
return indices
|
377 |
+
|
378 |
+
def x1y1wh_to_x1y1x2y2(bbox):
|
379 |
+
x1, y1, w, h = bbox
|
380 |
+
return [x1, y1, x1 + w, y1 + h]
|
381 |
+
|
382 |
+
def x1y1x2y2_to_xywh(bbox):
|
383 |
+
x1, y1, x2, y2 = bbox
|
384 |
+
return [x1, y1, x2 - x1, y2 - y1]
|
385 |
+
|
386 |
+
def convert_to_list_of_lists(rects):
|
387 |
+
if isinstance(rects, torch.Tensor):
|
388 |
+
return rects.tolist()
|
389 |
+
if isinstance(rects, np.ndarray):
|
390 |
+
return rects.tolist()
|
391 |
+
return [[a, b, c, d] for a, b, c, d in rects]
|