ragavsachdeva commited on
Commit
b36f7c2
·
verified ·
1 Parent(s): 7bb1e4c

Upload model

Browse files
Files changed (6) hide show
  1. config.json +475 -0
  2. configuration_magi.py +38 -0
  3. modelling_magi.py +486 -0
  4. processing_magi.py +274 -0
  5. pytorch_model.bin +3 -0
  6. utils.py +391 -0
config.json ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "to_push",
3
+ "architectures": [
4
+ "MagiModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_magi.MagiConfig",
8
+ "AutoModel": "modelling_magi.MagiModel"
9
+ },
10
+ "crop_embedding_image_preprocessing_config": {
11
+ "_processor_class": null,
12
+ "do_normalize": true,
13
+ "do_rescale": true,
14
+ "do_resize": true,
15
+ "image_mean": [
16
+ 0.485,
17
+ 0.456,
18
+ 0.406
19
+ ],
20
+ "image_processor_type": "ViTImageProcessor",
21
+ "image_std": [
22
+ 0.229,
23
+ 0.224,
24
+ 0.225
25
+ ],
26
+ "resample": 2,
27
+ "rescale_factor": 0.00392156862745098,
28
+ "size": {
29
+ "height": 224,
30
+ "width": 224
31
+ }
32
+ },
33
+ "crop_embedding_model_config": {
34
+ "_name_or_path": "facebook/vit-mae-base",
35
+ "add_cross_attention": false,
36
+ "architectures": [
37
+ "ViTMAEForPreTraining"
38
+ ],
39
+ "attention_probs_dropout_prob": 0.0,
40
+ "bad_words_ids": null,
41
+ "begin_suppress_tokens": null,
42
+ "bos_token_id": null,
43
+ "chunk_size_feed_forward": 0,
44
+ "cross_attention_hidden_size": null,
45
+ "decoder_hidden_size": 512,
46
+ "decoder_intermediate_size": 2048,
47
+ "decoder_num_attention_heads": 16,
48
+ "decoder_num_hidden_layers": 8,
49
+ "decoder_start_token_id": null,
50
+ "diversity_penalty": 0.0,
51
+ "do_sample": false,
52
+ "early_stopping": false,
53
+ "encoder_no_repeat_ngram_size": 0,
54
+ "eos_token_id": null,
55
+ "exponential_decay_length_penalty": null,
56
+ "finetuning_task": null,
57
+ "forced_bos_token_id": null,
58
+ "forced_eos_token_id": null,
59
+ "hidden_act": "gelu",
60
+ "hidden_dropout_prob": 0.0,
61
+ "hidden_size": 768,
62
+ "id2label": {
63
+ "0": "LABEL_0",
64
+ "1": "LABEL_1"
65
+ },
66
+ "image_size": 224,
67
+ "initializer_range": 0.02,
68
+ "intermediate_size": 3072,
69
+ "is_decoder": false,
70
+ "is_encoder_decoder": false,
71
+ "label2id": {
72
+ "LABEL_0": 0,
73
+ "LABEL_1": 1
74
+ },
75
+ "layer_norm_eps": 1e-12,
76
+ "length_penalty": 1.0,
77
+ "mask_ratio": 0.75,
78
+ "max_length": 20,
79
+ "min_length": 0,
80
+ "model_type": "",
81
+ "no_repeat_ngram_size": 0,
82
+ "norm_pix_loss": false,
83
+ "num_attention_heads": 12,
84
+ "num_beam_groups": 1,
85
+ "num_beams": 1,
86
+ "num_channels": 3,
87
+ "num_hidden_layers": 12,
88
+ "num_return_sequences": 1,
89
+ "output_attentions": false,
90
+ "output_hidden_states": false,
91
+ "output_scores": false,
92
+ "pad_token_id": null,
93
+ "patch_size": 16,
94
+ "prefix": null,
95
+ "problem_type": null,
96
+ "pruned_heads": {},
97
+ "qkv_bias": true,
98
+ "remove_invalid_values": false,
99
+ "repetition_penalty": 1.0,
100
+ "return_dict": true,
101
+ "return_dict_in_generate": false,
102
+ "sep_token_id": null,
103
+ "suppress_tokens": null,
104
+ "task_specific_params": null,
105
+ "temperature": 1.0,
106
+ "tf_legacy_loss": false,
107
+ "tie_encoder_decoder": false,
108
+ "tie_word_embeddings": true,
109
+ "tokenizer_class": null,
110
+ "top_k": 50,
111
+ "top_p": 1.0,
112
+ "torch_dtype": "float32",
113
+ "torchscript": false,
114
+ "typical_p": 1.0,
115
+ "use_bfloat16": false
116
+ },
117
+ "detection_image_preprocessing_config": {
118
+ "_processor_class": null,
119
+ "do_normalize": true,
120
+ "do_pad": true,
121
+ "do_rescale": true,
122
+ "do_resize": true,
123
+ "format": "coco_detection",
124
+ "image_mean": [
125
+ 0.485,
126
+ 0.456,
127
+ 0.406
128
+ ],
129
+ "image_processor_type": "ConditionalDetrImageProcessor",
130
+ "image_std": [
131
+ 0.229,
132
+ 0.224,
133
+ 0.225
134
+ ],
135
+ "resample": 2,
136
+ "rescale_factor": 0.00392156862745098,
137
+ "size": {
138
+ "longest_edge": 1333,
139
+ "shortest_edge": 800
140
+ }
141
+ },
142
+ "detection_model_config": {
143
+ "_name_or_path": "microsoft/conditional-detr-resnet-50",
144
+ "activation_dropout": 0.0,
145
+ "activation_function": "relu",
146
+ "add_cross_attention": false,
147
+ "architectures": [
148
+ "ConditionalDETRForObjectDetection"
149
+ ],
150
+ "attention_dropout": 0.0,
151
+ "auxiliary_loss": false,
152
+ "backbone": "resnet50",
153
+ "backbone_config": null,
154
+ "bad_words_ids": null,
155
+ "bbox_cost": 5,
156
+ "bbox_loss_coefficient": 5,
157
+ "begin_suppress_tokens": null,
158
+ "bos_token_id": null,
159
+ "chunk_size_feed_forward": 0,
160
+ "class_cost": 2,
161
+ "cls_loss_coefficient": 2,
162
+ "cross_attention_hidden_size": null,
163
+ "d_model": 256,
164
+ "decoder_attention_heads": 8,
165
+ "decoder_ffn_dim": 2048,
166
+ "decoder_layerdrop": 0.0,
167
+ "decoder_layers": 6,
168
+ "decoder_start_token_id": null,
169
+ "dice_loss_coefficient": 1,
170
+ "dilation": false,
171
+ "diversity_penalty": 0.0,
172
+ "do_sample": false,
173
+ "dropout": 0.1,
174
+ "early_stopping": false,
175
+ "encoder_attention_heads": 8,
176
+ "encoder_ffn_dim": 2048,
177
+ "encoder_layerdrop": 0.0,
178
+ "encoder_layers": 6,
179
+ "encoder_no_repeat_ngram_size": 0,
180
+ "eos_token_id": null,
181
+ "exponential_decay_length_penalty": null,
182
+ "finetuning_task": null,
183
+ "focal_alpha": 0.25,
184
+ "forced_bos_token_id": null,
185
+ "forced_eos_token_id": null,
186
+ "giou_cost": 2,
187
+ "giou_loss_coefficient": 2,
188
+ "id2label": {
189
+ "0": "LABEL_0",
190
+ "1": "LABEL_1",
191
+ "2": "LABEL_2"
192
+ },
193
+ "init_std": 0.02,
194
+ "init_xavier_std": 1.0,
195
+ "is_decoder": false,
196
+ "is_encoder_decoder": true,
197
+ "label2id": {
198
+ "LABEL_0": 0,
199
+ "LABEL_1": 1,
200
+ "LABEL_2": 2
201
+ },
202
+ "length_penalty": 1.0,
203
+ "mask_loss_coefficient": 1,
204
+ "max_length": 20,
205
+ "max_position_embeddings": 1024,
206
+ "min_length": 0,
207
+ "model_type": "",
208
+ "no_repeat_ngram_size": 0,
209
+ "num_beam_groups": 1,
210
+ "num_beams": 1,
211
+ "num_channels": 3,
212
+ "num_hidden_layers": 6,
213
+ "num_queries": 305,
214
+ "num_return_sequences": 1,
215
+ "output_attentions": false,
216
+ "output_hidden_states": false,
217
+ "output_scores": false,
218
+ "pad_token_id": null,
219
+ "position_embedding_type": "sine",
220
+ "prefix": null,
221
+ "problem_type": null,
222
+ "pruned_heads": {},
223
+ "remove_invalid_values": false,
224
+ "repetition_penalty": 1.0,
225
+ "return_dict": true,
226
+ "return_dict_in_generate": false,
227
+ "scale_embedding": false,
228
+ "sep_token_id": null,
229
+ "suppress_tokens": null,
230
+ "task_specific_params": null,
231
+ "temperature": 1.0,
232
+ "tf_legacy_loss": false,
233
+ "tie_encoder_decoder": false,
234
+ "tie_word_embeddings": true,
235
+ "tokenizer_class": null,
236
+ "top_k": 50,
237
+ "top_p": 1.0,
238
+ "torch_dtype": "float32",
239
+ "torchscript": false,
240
+ "typical_p": 1.0,
241
+ "use_bfloat16": false,
242
+ "use_pretrained_backbone": true,
243
+ "use_timm_backbone": true
244
+ },
245
+ "disable_crop_embeddings": false,
246
+ "disable_detections": false,
247
+ "disable_ocr": false,
248
+ "model_type": "magi",
249
+ "ocr_model_config": {
250
+ "_name_or_path": "/work/rs/logs/manga_ocr/nt8rn2ul/",
251
+ "add_cross_attention": false,
252
+ "architectures": [
253
+ "VisionEncoderDecoderModel"
254
+ ],
255
+ "bad_words_ids": null,
256
+ "begin_suppress_tokens": null,
257
+ "bos_token_id": null,
258
+ "chunk_size_feed_forward": 0,
259
+ "cross_attention_hidden_size": null,
260
+ "decoder": {
261
+ "_name_or_path": "",
262
+ "activation_dropout": 0.0,
263
+ "activation_function": "gelu",
264
+ "add_cross_attention": true,
265
+ "architectures": null,
266
+ "attention_dropout": 0.0,
267
+ "bad_words_ids": null,
268
+ "begin_suppress_tokens": null,
269
+ "bos_token_id": 0,
270
+ "chunk_size_feed_forward": 0,
271
+ "classifier_dropout": 0.0,
272
+ "cross_attention_hidden_size": 768,
273
+ "d_model": 1024,
274
+ "decoder_attention_heads": 16,
275
+ "decoder_ffn_dim": 4096,
276
+ "decoder_layerdrop": 0.0,
277
+ "decoder_layers": 12,
278
+ "decoder_start_token_id": 2,
279
+ "diversity_penalty": 0.0,
280
+ "do_sample": false,
281
+ "dropout": 0.1,
282
+ "early_stopping": false,
283
+ "encoder_no_repeat_ngram_size": 0,
284
+ "eos_token_id": 2,
285
+ "exponential_decay_length_penalty": null,
286
+ "finetuning_task": null,
287
+ "forced_bos_token_id": null,
288
+ "forced_eos_token_id": null,
289
+ "id2label": {
290
+ "0": "LABEL_0",
291
+ "1": "LABEL_1"
292
+ },
293
+ "init_std": 0.02,
294
+ "is_decoder": true,
295
+ "is_encoder_decoder": false,
296
+ "label2id": {
297
+ "LABEL_0": 0,
298
+ "LABEL_1": 1
299
+ },
300
+ "layernorm_embedding": true,
301
+ "length_penalty": 1.0,
302
+ "max_length": 20,
303
+ "max_position_embeddings": 512,
304
+ "min_length": 0,
305
+ "model_type": "trocr",
306
+ "no_repeat_ngram_size": 0,
307
+ "num_beam_groups": 1,
308
+ "num_beams": 1,
309
+ "num_return_sequences": 1,
310
+ "output_attentions": false,
311
+ "output_hidden_states": false,
312
+ "output_scores": false,
313
+ "pad_token_id": 1,
314
+ "prefix": null,
315
+ "problem_type": null,
316
+ "pruned_heads": {},
317
+ "remove_invalid_values": false,
318
+ "repetition_penalty": 1.0,
319
+ "return_dict": true,
320
+ "return_dict_in_generate": false,
321
+ "scale_embedding": false,
322
+ "sep_token_id": null,
323
+ "suppress_tokens": null,
324
+ "task_specific_params": null,
325
+ "temperature": 1.0,
326
+ "tf_legacy_loss": false,
327
+ "tie_encoder_decoder": false,
328
+ "tie_word_embeddings": true,
329
+ "tokenizer_class": null,
330
+ "top_k": 50,
331
+ "top_p": 1.0,
332
+ "torch_dtype": null,
333
+ "torchscript": false,
334
+ "typical_p": 1.0,
335
+ "use_bfloat16": false,
336
+ "use_cache": false,
337
+ "use_learned_position_embeddings": true,
338
+ "vocab_size": 50265
339
+ },
340
+ "decoder_start_token_id": 0,
341
+ "diversity_penalty": 0.0,
342
+ "do_sample": false,
343
+ "early_stopping": true,
344
+ "encoder": {
345
+ "_name_or_path": "",
346
+ "add_cross_attention": false,
347
+ "architectures": null,
348
+ "attention_probs_dropout_prob": 0.0,
349
+ "bad_words_ids": null,
350
+ "begin_suppress_tokens": null,
351
+ "bos_token_id": null,
352
+ "chunk_size_feed_forward": 0,
353
+ "cross_attention_hidden_size": null,
354
+ "decoder_start_token_id": null,
355
+ "diversity_penalty": 0.0,
356
+ "do_sample": false,
357
+ "early_stopping": false,
358
+ "encoder_no_repeat_ngram_size": 0,
359
+ "encoder_stride": 16,
360
+ "eos_token_id": null,
361
+ "exponential_decay_length_penalty": null,
362
+ "finetuning_task": null,
363
+ "forced_bos_token_id": null,
364
+ "forced_eos_token_id": null,
365
+ "hidden_act": "gelu",
366
+ "hidden_dropout_prob": 0.0,
367
+ "hidden_size": 768,
368
+ "id2label": {
369
+ "0": "LABEL_0",
370
+ "1": "LABEL_1"
371
+ },
372
+ "image_size": 384,
373
+ "initializer_range": 0.02,
374
+ "intermediate_size": 3072,
375
+ "is_decoder": false,
376
+ "is_encoder_decoder": false,
377
+ "label2id": {
378
+ "LABEL_0": 0,
379
+ "LABEL_1": 1
380
+ },
381
+ "layer_norm_eps": 1e-12,
382
+ "length_penalty": 1.0,
383
+ "max_length": 20,
384
+ "min_length": 0,
385
+ "model_type": "vit",
386
+ "no_repeat_ngram_size": 0,
387
+ "num_attention_heads": 12,
388
+ "num_beam_groups": 1,
389
+ "num_beams": 1,
390
+ "num_channels": 3,
391
+ "num_hidden_layers": 12,
392
+ "num_return_sequences": 1,
393
+ "output_attentions": false,
394
+ "output_hidden_states": false,
395
+ "output_scores": false,
396
+ "pad_token_id": null,
397
+ "patch_size": 16,
398
+ "prefix": null,
399
+ "problem_type": null,
400
+ "pruned_heads": {},
401
+ "qkv_bias": false,
402
+ "remove_invalid_values": false,
403
+ "repetition_penalty": 1.0,
404
+ "return_dict": true,
405
+ "return_dict_in_generate": false,
406
+ "sep_token_id": null,
407
+ "suppress_tokens": null,
408
+ "task_specific_params": null,
409
+ "temperature": 1.0,
410
+ "tf_legacy_loss": false,
411
+ "tie_encoder_decoder": false,
412
+ "tie_word_embeddings": true,
413
+ "tokenizer_class": null,
414
+ "top_k": 50,
415
+ "top_p": 1.0,
416
+ "torch_dtype": null,
417
+ "torchscript": false,
418
+ "typical_p": 1.0,
419
+ "use_bfloat16": false
420
+ },
421
+ "encoder_no_repeat_ngram_size": 0,
422
+ "eos_token_id": 2,
423
+ "exponential_decay_length_penalty": null,
424
+ "finetuning_task": null,
425
+ "forced_bos_token_id": null,
426
+ "forced_eos_token_id": null,
427
+ "id2label": {
428
+ "0": "LABEL_0",
429
+ "1": "LABEL_1"
430
+ },
431
+ "is_decoder": false,
432
+ "is_encoder_decoder": true,
433
+ "label2id": {
434
+ "LABEL_0": 0,
435
+ "LABEL_1": 1
436
+ },
437
+ "length_penalty": 2.0,
438
+ "max_length": 300,
439
+ "min_length": 0,
440
+ "model_type": "vision-encoder-decoder",
441
+ "no_repeat_ngram_size": 3,
442
+ "num_beam_groups": 1,
443
+ "num_beams": 4,
444
+ "num_return_sequences": 1,
445
+ "output_attentions": false,
446
+ "output_hidden_states": false,
447
+ "output_scores": false,
448
+ "pad_token_id": 1,
449
+ "prefix": null,
450
+ "problem_type": null,
451
+ "pruned_heads": {},
452
+ "remove_invalid_values": false,
453
+ "repetition_penalty": 1.0,
454
+ "return_dict": true,
455
+ "return_dict_in_generate": false,
456
+ "sep_token_id": null,
457
+ "suppress_tokens": null,
458
+ "task_specific_params": null,
459
+ "temperature": 1.0,
460
+ "tf_legacy_loss": false,
461
+ "tie_encoder_decoder": false,
462
+ "tie_word_embeddings": false,
463
+ "tokenizer_class": null,
464
+ "top_k": 50,
465
+ "top_p": 1.0,
466
+ "torch_dtype": "float32",
467
+ "torchscript": false,
468
+ "typical_p": 1.0,
469
+ "use_bfloat16": false,
470
+ "vocab_size": 50265
471
+ },
472
+ "ocr_pretrained_processor_path": "microsoft/trocr-base-printed",
473
+ "torch_dtype": "float32",
474
+ "transformers_version": "4.34.0.dev0"
475
+ }
configuration_magi.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, VisionEncoderDecoderConfig
2
+ from typing import List
3
+
4
+
5
+ class MagiConfig(PretrainedConfig):
6
+ model_type = "magi"
7
+
8
+ def __init__(
9
+ self,
10
+ disable_ocr: bool = False,
11
+ disable_crop_embeddings: bool = False,
12
+ disable_detections: bool = False,
13
+ detection_model_config: dict = None,
14
+ ocr_model_config: dict = None,
15
+ crop_embedding_model_config: dict = None,
16
+ detection_image_preprocessing_config: dict = None,
17
+ ocr_pretrained_processor_path: str = None,
18
+ crop_embedding_image_preprocessing_config: dict = None,
19
+ **kwargs,
20
+ ):
21
+ self.disable_ocr = disable_ocr
22
+ self.disable_crop_embeddings = disable_crop_embeddings
23
+ self.disable_detections = disable_detections
24
+
25
+ self.detection_model_config = None
26
+ self.ocr_model_config = None
27
+ self.crop_embedding_model_config = None
28
+ if detection_model_config is not None:
29
+ self.detection_model_config = PretrainedConfig.from_dict(detection_model_config)
30
+ if ocr_model_config is not None:
31
+ self.ocr_model_config = VisionEncoderDecoderConfig.from_dict(ocr_model_config)
32
+ if crop_embedding_model_config is not None:
33
+ self.crop_embedding_model_config = PretrainedConfig.from_dict(crop_embedding_model_config)
34
+
35
+ self.detection_image_preprocessing_config = detection_image_preprocessing_config
36
+ self.ocr_pretrained_processor_path = ocr_pretrained_processor_path
37
+ self.crop_embedding_image_preprocessing_config = crop_embedding_image_preprocessing_config
38
+ super().__init__(**kwargs)
modelling_magi.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel
2
+ from transformers.models.conditional_detr.modeling_conditional_detr import (
3
+ ConditionalDetrMLPPredictionHead,
4
+ ConditionalDetrModelOutput,
5
+ ConditionalDetrHungarianMatcher,
6
+ inverse_sigmoid,
7
+ )
8
+ from .configuration_magi import MagiConfig
9
+ from .processing_magi import MagiProcessor
10
+ from torch import nn
11
+ from typing import Optional, List
12
+ import torch
13
+ from einops import rearrange, repeat, einsum
14
+ from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
15
+
16
+ class MagiModel(PreTrainedModel):
17
+ config_class = MagiConfig
18
+
19
+ def __init__(self, config):
20
+ super().__init__(config)
21
+ self.config = config
22
+ self.processor = MagiProcessor(config)
23
+ if not config.disable_ocr:
24
+ self.ocr_model = VisionEncoderDecoderModel(config.ocr_model_config)
25
+ if not config.disable_crop_embeddings:
26
+ self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)
27
+ if not config.disable_detections:
28
+ self.num_non_obj_tokens = 5
29
+ self.detection_transformer = ConditionalDetrModel(config.detection_model_config)
30
+ self.bbox_predictor = ConditionalDetrMLPPredictionHead(
31
+ input_dim=config.detection_model_config.d_model,
32
+ hidden_dim=config.detection_model_config.d_model,
33
+ output_dim=4, num_layers=3
34
+ )
35
+ self.is_this_text_a_dialogue = ConditionalDetrMLPPredictionHead(
36
+ input_dim=config.detection_model_config.d_model,
37
+ hidden_dim=config.detection_model_config.d_model,
38
+ output_dim=1,
39
+ num_layers=3
40
+ )
41
+ self.character_character_matching_head = ConditionalDetrMLPPredictionHead(
42
+ input_dim = 3 * config.detection_model_config.d_model + (2 * config.crop_embedding_model_config.hidden_size if not config.disable_crop_embeddings else 0),
43
+ hidden_dim=config.detection_model_config.d_model,
44
+ output_dim=1, num_layers=3
45
+ )
46
+ self.text_character_matching_head = ConditionalDetrMLPPredictionHead(
47
+ input_dim = 3 * config.detection_model_config.d_model,
48
+ hidden_dim=config.detection_model_config.d_model,
49
+ output_dim=1, num_layers=3
50
+ )
51
+ self.class_labels_classifier = nn.Linear(
52
+ config.detection_model_config.d_model, config.detection_model_config.num_labels
53
+ )
54
+ self.matcher = ConditionalDetrHungarianMatcher(
55
+ class_cost=config.detection_model_config.class_cost,
56
+ bbox_cost=config.detection_model_config.bbox_cost,
57
+ giou_cost=config.detection_model_config.giou_cost
58
+ )
59
+
60
+ def move_to_device(self, input):
61
+ return move_to_device(input, self.device)
62
+
63
+ def predict_detections_and_associations(
64
+ self,
65
+ images,
66
+ move_to_device_fn=None,
67
+ character_detection_threshold=0.3,
68
+ panel_detection_threshold=0.2,
69
+ text_detection_threshold=0.25,
70
+ character_character_matching_threshold=0.7,
71
+ text_character_matching_threshold=0.4,
72
+ ):
73
+ assert not self.config.disable_detections
74
+ move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
75
+
76
+ inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images)
77
+ inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
78
+
79
+ detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
80
+ predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
81
+
82
+ # create callback fn
83
+ def get_character_character_matching_scores(batch_character_indices, batch_bboxes):
84
+ predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
85
+ predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
86
+ crop_bboxes = [batch_bboxes[i][batch_character_indices[i]] for i in range(len(batch_character_indices))]
87
+ crop_embeddings_for_batch = self.predict_crop_embeddings(images, crop_bboxes, move_to_device_fn)
88
+ character_obj_tokens_for_batch = []
89
+ c2c_tokens_for_batch = []
90
+ for predicted_obj_tokens, predicted_c2c_tokens, character_indices in zip(predicted_obj_tokens_for_batch, predicted_c2c_tokens_for_batch, batch_character_indices):
91
+ character_obj_tokens_for_batch.append(predicted_obj_tokens[character_indices])
92
+ c2c_tokens_for_batch.append(predicted_c2c_tokens)
93
+ return self._get_character_character_affinity_matrices(
94
+ character_obj_tokens_for_batch=character_obj_tokens_for_batch,
95
+ crop_embeddings_for_batch=crop_embeddings_for_batch,
96
+ c2c_tokens_for_batch=c2c_tokens_for_batch,
97
+ apply_sigmoid=True,
98
+ )
99
+
100
+ # create callback fn
101
+ def get_text_character_matching_scores(batch_text_indices, batch_character_indices):
102
+ predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
103
+ predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
104
+ text_obj_tokens_for_batch = []
105
+ character_obj_tokens_for_batch = []
106
+ t2c_tokens_for_batch = []
107
+ for predicted_obj_tokens, predicted_t2c_tokens, text_indices, character_indices in zip(predicted_obj_tokens_for_batch, predicted_t2c_tokens_for_batch, batch_text_indices, batch_character_indices):
108
+ text_obj_tokens_for_batch.append(predicted_obj_tokens[text_indices])
109
+ character_obj_tokens_for_batch.append(predicted_obj_tokens[character_indices])
110
+ t2c_tokens_for_batch.append(predicted_t2c_tokens)
111
+ return self._get_text_character_affinity_matrices(
112
+ character_obj_tokens_for_batch=character_obj_tokens_for_batch,
113
+ text_obj_tokens_for_this_batch=text_obj_tokens_for_batch,
114
+ t2c_tokens_for_batch=t2c_tokens_for_batch,
115
+ apply_sigmoid=True,
116
+ )
117
+
118
+ # create callback fn
119
+ def get_dialog_confidence_scores(batch_text_indices):
120
+ predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
121
+ dialog_confidence = []
122
+ for predicted_obj_tokens, text_indices in zip(predicted_obj_tokens_for_batch, batch_text_indices):
123
+ confidence = self.is_this_text_a_dialogue(predicted_obj_tokens[text_indices]).sigmoid()
124
+ dialog_confidence.append(rearrange(confidence, "i 1 -> i"))
125
+ return dialog_confidence
126
+
127
+ return self.processor.postprocess_detections_and_associations(
128
+ predicted_bboxes=predicted_bboxes,
129
+ predicted_class_scores=predicted_class_scores,
130
+ original_image_sizes=torch.stack([torch.tensor(img.shape[:2]) for img in images], dim=0).to(predicted_bboxes.device),
131
+ get_character_character_matching_scores=get_character_character_matching_scores,
132
+ get_text_character_matching_scores=get_text_character_matching_scores,
133
+ get_dialog_confidence_scores=get_dialog_confidence_scores,
134
+ character_detection_threshold=character_detection_threshold,
135
+ panel_detection_threshold=panel_detection_threshold,
136
+ text_detection_threshold=text_detection_threshold,
137
+ character_character_matching_threshold=character_character_matching_threshold,
138
+ text_character_matching_threshold=text_character_matching_threshold,
139
+ )
140
+
141
+ def predict_crop_embeddings(self, images, crop_bboxes, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
142
+ if self.config.disable_crop_embeddings:
143
+ return None
144
+
145
+ assert isinstance(crop_bboxes, List), "please provide a list of bboxes for each image to get embeddings for"
146
+
147
+ move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
148
+
149
+ # temporarily change the mask ratio from default to the one specified
150
+ old_mask_ratio = self.crop_embedding_model.embeddings.config.mask_ratio
151
+ self.crop_embedding_model.embeddings.config.mask_ratio = mask_ratio
152
+
153
+ crops_per_image = []
154
+ num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
155
+ for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
156
+ crops = self.processor.crop_image(image, bboxes)
157
+ assert len(crops) == num_crops
158
+ crops_per_image.extend(crops)
159
+
160
+ if len(crops_per_image) == 0:
161
+ return [[] for _ in crop_bboxes]
162
+
163
+ crops_per_image = self.processor.preprocess_inputs_for_crop_embeddings(crops_per_image)
164
+ crops_per_image = move_to_device_fn(crops_per_image)
165
+
166
+ # process the crops in batches to avoid OOM
167
+ embeddings = []
168
+ for i in range(0, len(crops_per_image), batch_size):
169
+ crops = crops_per_image[i:i+batch_size]
170
+ embeddings_per_batch = self.crop_embedding_model(crops).last_hidden_state[:, 0]
171
+ embeddings.append(embeddings_per_batch)
172
+ embeddings = torch.cat(embeddings, dim=0)
173
+
174
+ crop_embeddings_for_batch = []
175
+ for num_crops in num_crops_per_batch:
176
+ crop_embeddings_for_batch.append(embeddings[:num_crops])
177
+ embeddings = embeddings[num_crops:]
178
+
179
+ # restore the mask ratio to the default
180
+ self.crop_embedding_model.embeddings.config.mask_ratio = old_mask_ratio
181
+
182
+ return crop_embeddings_for_batch
183
+
184
+ def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32):
185
+ assert not self.config.disable_ocr
186
+ move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
187
+
188
+ crops_per_image = []
189
+ num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
190
+ for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
191
+ crops = self.processor.crop_image(image, bboxes)
192
+ assert len(crops) == num_crops
193
+ crops_per_image.extend(crops)
194
+
195
+ if len(crops_per_image) == 0:
196
+ return [[] for _ in crop_bboxes]
197
+
198
+ crops_per_image = self.processor.preprocess_inputs_for_ocr(crops_per_image)
199
+ crops_per_image = move_to_device_fn(crops_per_image)
200
+
201
+ # process the crops in batches to avoid OOM
202
+ all_generated_texts = []
203
+ if use_tqdm:
204
+ from tqdm import tqdm
205
+ pbar = tqdm(range(0, len(crops_per_image), batch_size))
206
+ else:
207
+ pbar = range(0, len(crops_per_image), batch_size)
208
+ for i in pbar:
209
+ crops = crops_per_image[i:i+batch_size]
210
+ generated_ids = self.ocr_model.generate(crops)
211
+ generated_texts = self.processor.postprocess_ocr_tokens(generated_ids)
212
+ all_generated_texts.extend(generated_texts)
213
+
214
+ texts_for_images = []
215
+ for num_crops in num_crops_per_batch:
216
+ texts_for_images.append([x.replace("\n", "") for x in all_generated_texts[:num_crops]])
217
+ all_generated_texts = all_generated_texts[num_crops:]
218
+
219
+ return texts_for_images
220
+
221
+ def visualise_single_image_prediction(
222
+ self, image_as_np_array, predictions, filename=None
223
+ ):
224
+ return visualise_single_image_prediction(image_as_np_array, predictions, filename)
225
+
226
+ def generate_transcript_for_single_image(
227
+ self, predictions, ocr_results, filename=None
228
+ ):
229
+ character_clusters = predictions["character_cluster_labels"]
230
+ text_to_character = predictions["text_character_associations"]
231
+ text_to_character = {k: v for k, v in text_to_character}
232
+ transript = " ### Transcript ###\n"
233
+ for index, text in enumerate(ocr_results):
234
+ if index in text_to_character:
235
+ speaker = character_clusters[text_to_character[index]]
236
+ speaker = f"<{speaker}>"
237
+ else:
238
+ speaker = "<?>"
239
+ transript += f"{speaker}: {text}\n"
240
+ if filename is not None:
241
+ with open(filename, "w") as file:
242
+ file.write(transript)
243
+ return transript
244
+
245
+ def get_text_character_affinity_matrices_given_annotations(
246
+ self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
247
+ ):
248
+ assert not self.config.disable_detections
249
+ move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
250
+
251
+ inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
252
+ inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
253
+ processed_targets = inputs_to_detection_transformer.pop("labels")
254
+
255
+ detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
256
+ predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
257
+ predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
258
+
259
+ predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
260
+ matching_dict = {
261
+ "logits": predicted_class_scores,
262
+ "pred_boxes": predicted_bboxes,
263
+ }
264
+ indices = self.matcher(matching_dict, processed_targets)
265
+
266
+ matched_char_obj_tokens_for_batch = []
267
+ matched_text_obj_tokens_for_batch = []
268
+ t2c_tokens_for_batch = []
269
+
270
+ text_bboxes_for_batch = []
271
+ character_bboxes_for_batch = []
272
+
273
+ for j, (pred_idx, tgt_idx) in enumerate(indices):
274
+ target_idx_to_pred_idx = {tgt.item(): pred.item() for pred, tgt in zip(pred_idx, tgt_idx)}
275
+ targets_for_this_image = processed_targets[j]
276
+ indices_of_text_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 1]
277
+ indices_of_char_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 0]
278
+ predicted_text_indices = [target_idx_to_pred_idx[i] for i in indices_of_text_boxes_in_annotation]
279
+ predicted_char_indices = [target_idx_to_pred_idx[i] for i in indices_of_char_boxes_in_annotation]
280
+
281
+ text_bboxes_for_batch.append(
282
+ [annotations[j]["bboxes_as_x1y1x2y2"][k] for k in indices_of_text_boxes_in_annotation]
283
+ )
284
+ character_bboxes_for_batch.append(
285
+ [annotations[j]["bboxes_as_x1y1x2y2"][k] for k in indices_of_char_boxes_in_annotation]
286
+ )
287
+
288
+ matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
289
+ matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
290
+ t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
291
+
292
+ text_character_affinity_matrices = self._get_text_character_affinity_matrices(
293
+ character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
294
+ text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
295
+ t2c_tokens_for_batch=t2c_tokens_for_batch,
296
+ apply_sigmoid=apply_sigmoid,
297
+ )
298
+
299
+ return {
300
+ "text_character_affinity_matrices": text_character_affinity_matrices,
301
+ "text_bboxes_for_batch": text_bboxes_for_batch,
302
+ "character_bboxes_for_batch": character_bboxes_for_batch,
303
+ }
304
+
305
+ def get_obj_embeddings_corresponding_to_given_annotations(
306
+ self, images, annotations, move_to_device_fn=None
307
+ ):
308
+ assert not self.config.disable_detections
309
+ move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
310
+
311
+ inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
312
+ inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
313
+ processed_targets = inputs_to_detection_transformer.pop("labels")
314
+
315
+ detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
316
+ predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
317
+ predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
318
+ predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
319
+
320
+ predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
321
+ matching_dict = {
322
+ "logits": predicted_class_scores,
323
+ "pred_boxes": predicted_bboxes,
324
+ }
325
+ indices = self.matcher(matching_dict, processed_targets)
326
+
327
+ matched_char_obj_tokens_for_batch = []
328
+ matched_text_obj_tokens_for_batch = []
329
+ matched_panel_obj_tokens_for_batch = []
330
+ t2c_tokens_for_batch = []
331
+ c2c_tokens_for_batch = []
332
+
333
+ for j, (pred_idx, tgt_idx) in enumerate(indices):
334
+ target_idx_to_pred_idx = {tgt.item(): pred.item() for pred, tgt in zip(pred_idx, tgt_idx)}
335
+ targets_for_this_image = processed_targets[j]
336
+ indices_of_char_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 0]
337
+ indices_of_text_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 1]
338
+ indices_of_panel_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 2]
339
+ predicted_text_indices = [target_idx_to_pred_idx[i] for i in indices_of_text_boxes_in_annotation]
340
+ predicted_char_indices = [target_idx_to_pred_idx[i] for i in indices_of_char_boxes_in_annotation]
341
+ predicted_panel_indices = [target_idx_to_pred_idx[i] for i in indices_of_panel_boxes_in_annotation]
342
+
343
+ matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
344
+ matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
345
+ matched_panel_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_panel_indices])
346
+ t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
347
+ c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])
348
+
349
+ return {
350
+ "character": matched_char_obj_tokens_for_batch,
351
+ "text": matched_text_obj_tokens_for_batch,
352
+ "panel": matched_panel_obj_tokens_for_batch,
353
+ "t2c": t2c_tokens_for_batch,
354
+ "c2c": c2c_tokens_for_batch,
355
+ }
356
+
357
+ def sort_panels_and_text_bboxes_in_reading_order(
358
+ self,
359
+ batch_panel_bboxes,
360
+ batch_text_bboxes,
361
+ ):
362
+ batch_sorted_panel_indices = []
363
+ batch_sorted_text_indices = []
364
+ for batch_index in range(len(batch_text_bboxes)):
365
+ panel_bboxes = batch_panel_bboxes[batch_index]
366
+ text_bboxes = batch_text_bboxes[batch_index]
367
+ sorted_panel_indices = sort_panels(panel_bboxes)
368
+ sorted_panels = [panel_bboxes[i] for i in sorted_panel_indices]
369
+ sorted_text_indices = sort_text_boxes_in_reading_order(text_bboxes, sorted_panels)
370
+ batch_sorted_panel_indices.append(sorted_panel_indices)
371
+ batch_sorted_text_indices.append(sorted_text_indices)
372
+ return batch_sorted_panel_indices, batch_sorted_text_indices
373
+
374
+ def _get_detection_transformer_output(
375
+ self,
376
+ pixel_values: torch.FloatTensor,
377
+ pixel_mask: Optional[torch.LongTensor] = None
378
+ ):
379
+ if self.config.disable_detections:
380
+ raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")
381
+ return self.detection_transformer(
382
+ pixel_values=pixel_values,
383
+ pixel_mask=pixel_mask,
384
+ return_dict=True
385
+ )
386
+
387
+ def _get_predicted_obj_tokens(
388
+ self,
389
+ detection_transformer_output: ConditionalDetrModelOutput
390
+ ):
391
+ return detection_transformer_output.last_hidden_state[:, :-self.num_non_obj_tokens]
392
+
393
+ def _get_predicted_c2c_tokens(
394
+ self,
395
+ detection_transformer_output: ConditionalDetrModelOutput
396
+ ):
397
+ return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens]
398
+
399
+ def _get_predicted_t2c_tokens(
400
+ self,
401
+ detection_transformer_output: ConditionalDetrModelOutput
402
+ ):
403
+ return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens+1]
404
+
405
+ def _get_predicted_bboxes_and_classes(
406
+ self,
407
+ detection_transformer_output: ConditionalDetrModelOutput,
408
+ ):
409
+ if self.config.disable_detections:
410
+ raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")
411
+
412
+ obj = self._get_predicted_obj_tokens(detection_transformer_output)
413
+
414
+ predicted_class_scores = self.class_labels_classifier(obj)
415
+ reference = detection_transformer_output.reference_points[:-self.num_non_obj_tokens]
416
+ reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
417
+ predicted_boxes = self.bbox_predictor(obj)
418
+ predicted_boxes[..., :2] += reference_before_sigmoid
419
+ predicted_boxes = predicted_boxes.sigmoid()
420
+
421
+ return predicted_class_scores, predicted_boxes
422
+
423
+ def _get_character_character_affinity_matrices(
424
+ self,
425
+ character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
426
+ crop_embeddings_for_batch: List[torch.FloatTensor] = None,
427
+ c2c_tokens_for_batch: List[torch.FloatTensor] = None,
428
+ apply_sigmoid=True,
429
+ ):
430
+ assert self.config.disable_detections or (character_obj_tokens_for_batch is not None and c2c_tokens_for_batch is not None)
431
+ assert self.config.disable_crop_embeddings or crop_embeddings_for_batch is not None
432
+ assert not self.config.disable_detections or not self.config.disable_crop_embeddings
433
+
434
+ if self.config.disable_detections:
435
+ affinity_matrices = []
436
+ for crop_embeddings in crop_embeddings_for_batch:
437
+ crop_embeddings = crop_embeddings / crop_embeddings.norm(dim=-1, keepdim=True)
438
+ affinity_matrix = einsum("i d, j d -> i j", affinity_matrix)
439
+ affinity_matrices.append(affinity_matrix)
440
+ return affinity_matrices
441
+ affinity_matrices = []
442
+ for batch_index, (character_obj_tokens, c2c) in enumerate(zip(character_obj_tokens_for_batch, c2c_tokens_for_batch)):
443
+ if character_obj_tokens.shape[0] == 0:
444
+ affinity_matrices.append(torch.zeros(0, 0).type_as(character_obj_tokens))
445
+ continue
446
+ if not self.config.disable_crop_embeddings:
447
+ crop_embeddings = crop_embeddings_for_batch[batch_index]
448
+ assert character_obj_tokens.shape[0] == crop_embeddings.shape[0]
449
+ character_obj_tokens = torch.cat([character_obj_tokens, crop_embeddings], dim=-1)
450
+ char_i = repeat(character_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
451
+ char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=character_obj_tokens.shape[0])
452
+ char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
453
+ c2c = repeat(c2c, "d -> repeat d", repeat = char_ij.shape[0])
454
+ char_ij_c2c = torch.cat([char_ij, c2c], dim=-1)
455
+ character_character_affinities = self.character_character_matching_head(char_ij_c2c)
456
+ character_character_affinities = rearrange(character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
457
+ if apply_sigmoid:
458
+ character_character_affinities = character_character_affinities.sigmoid()
459
+ affinity_matrices.append(character_character_affinities)
460
+ return affinity_matrices
461
+
462
+ def _get_text_character_affinity_matrices(
463
+ self,
464
+ character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
465
+ text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
466
+ t2c_tokens_for_batch: List[torch.FloatTensor] = None,
467
+ apply_sigmoid=True,
468
+ ):
469
+ assert not self.config.disable_detections
470
+ assert character_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None and t2c_tokens_for_batch is not None
471
+ affinity_matrices = []
472
+ for character_obj_tokens, text_obj_tokens, t2c in zip(character_obj_tokens_for_batch, text_obj_tokens_for_this_batch, t2c_tokens_for_batch):
473
+ if character_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
474
+ affinity_matrices.append(torch.zeros(text_obj_tokens.shape[0], character_obj_tokens.shape[0]).type_as(character_obj_tokens))
475
+ continue
476
+ text_i = repeat(text_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
477
+ char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=text_obj_tokens.shape[0])
478
+ text_char = rearrange([text_i, char_j], "two i j d -> (i j) (two d)")
479
+ t2c = repeat(t2c, "d -> repeat d", repeat = text_char.shape[0])
480
+ text_char_t2c = torch.cat([text_char, t2c], dim=-1)
481
+ text_character_affinities = self.text_character_matching_head(text_char_t2c)
482
+ text_character_affinities = rearrange(text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
483
+ if apply_sigmoid:
484
+ text_character_affinities = text_character_affinities.sigmoid()
485
+ affinity_matrices.append(text_character_affinities)
486
+ return affinity_matrices
processing_magi.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor
2
+ from transformers.image_transforms import center_to_corners_format
3
+ import torch
4
+ from typing import List
5
+ from shapely.geometry import box
6
+ from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order, x1y1x2y2_to_xywh
7
+ import numpy as np
8
+
9
+ class MagiProcessor():
10
+ def __init__(self, config):
11
+ self.config = config
12
+ self.detection_image_preprocessor = None
13
+ self.ocr_preprocessor = None
14
+ self.crop_embedding_image_preprocessor = None
15
+ if not config.disable_detections:
16
+ assert config.detection_image_preprocessing_config is not None
17
+ self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(config.detection_image_preprocessing_config)
18
+ if not config.disable_ocr:
19
+ assert config.ocr_pretrained_processor_path is not None
20
+ self.ocr_preprocessor = TrOCRProcessor.from_pretrained(config.ocr_pretrained_processor_path)
21
+ if not config.disable_crop_embeddings:
22
+ assert config.crop_embedding_image_preprocessing_config is not None
23
+ self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)
24
+
25
+ def preprocess_inputs_for_detection(self, images, annotations=None):
26
+ images = list(images)
27
+ assert isinstance(images[0], np.ndarray)
28
+ annotations = self._convert_annotations_to_coco_format(annotations)
29
+ inputs = self.detection_image_preprocessor(images, annotations=annotations, return_tensors="pt")
30
+ return inputs
31
+
32
+ def preprocess_inputs_for_ocr(self, images):
33
+ images = list(images)
34
+ assert isinstance(images[0], np.ndarray)
35
+ return self.ocr_preprocessor(images, return_tensors="pt").pixel_values
36
+
37
+ def preprocess_inputs_for_crop_embeddings(self, images):
38
+ images = list(images)
39
+ assert isinstance(images[0], np.ndarray)
40
+ return self.crop_embedding_image_preprocessor(images, return_tensors="pt").pixel_values
41
+
42
+ def postprocess_detections_and_associations(
43
+ self,
44
+ predicted_bboxes,
45
+ predicted_class_scores,
46
+ original_image_sizes,
47
+ get_character_character_matching_scores,
48
+ get_text_character_matching_scores,
49
+ get_dialog_confidence_scores,
50
+ character_detection_threshold=0.3,
51
+ panel_detection_threshold=0.2,
52
+ text_detection_threshold=0.25,
53
+ character_character_matching_threshold=0.7,
54
+ text_character_matching_threshold=0.4,
55
+ ):
56
+ assert self.config.disable_detections is False
57
+ batch_scores, batch_labels = predicted_class_scores.max(-1)
58
+ batch_scores = batch_scores.sigmoid()
59
+ batch_labels = batch_labels.long()
60
+ batch_bboxes = center_to_corners_format(predicted_bboxes)
61
+
62
+ # scale the bboxes back to the original image size
63
+ if isinstance(original_image_sizes, List):
64
+ img_h = torch.Tensor([i[0] for i in original_image_sizes])
65
+ img_w = torch.Tensor([i[1] for i in original_image_sizes])
66
+ else:
67
+ img_h, img_w = original_image_sizes.unbind(1)
68
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(batch_bboxes.device)
69
+ batch_bboxes = batch_bboxes * scale_fct[:, None, :]
70
+
71
+ batch_panel_indices = self._get_indices_of_panels_to_keep(batch_scores, batch_labels, batch_bboxes, panel_detection_threshold)
72
+ batch_character_indices = self._get_indices_of_characters_to_keep(batch_scores, batch_labels, batch_bboxes, character_detection_threshold)
73
+ batch_text_indices = self._get_indices_of_texts_to_keep(batch_scores, batch_labels, batch_bboxes, text_detection_threshold)
74
+
75
+ batch_character_character_matching_scores = get_character_character_matching_scores(batch_character_indices, batch_bboxes)
76
+ batch_text_character_matching_scores = get_text_character_matching_scores(batch_text_indices, batch_character_indices)
77
+ batch_dialog_confidence_scores = get_dialog_confidence_scores(batch_text_indices)
78
+
79
+ # sort panels and texts in the reading order
80
+ for batch_index in range(len(batch_scores)):
81
+ panel_bboxes = batch_bboxes[batch_index][batch_panel_indices[batch_index]]
82
+ panel_scores = batch_scores[batch_index][batch_panel_indices[batch_index]]
83
+ text_bboxes = batch_bboxes[batch_index][batch_text_indices[batch_index]]
84
+ text_scores = batch_scores[batch_index][batch_text_indices[batch_index]]
85
+
86
+ sorted_panel_indices = sort_panels(panel_bboxes)
87
+ batch_bboxes[batch_index][batch_panel_indices[batch_index]] = panel_bboxes[sorted_panel_indices]
88
+ batch_scores[batch_index][batch_panel_indices[batch_index]] = panel_scores[sorted_panel_indices]
89
+ sorted_panels = batch_bboxes[batch_index][batch_panel_indices[batch_index]]
90
+
91
+ sorted_text_indices = sort_text_boxes_in_reading_order(text_bboxes, sorted_panels)
92
+ batch_bboxes[batch_index][batch_text_indices[batch_index]] = text_bboxes[sorted_text_indices]
93
+ batch_scores[batch_index][batch_text_indices[batch_index]] = text_scores[sorted_text_indices]
94
+ batch_text_character_matching_scores[batch_index] = batch_text_character_matching_scores[batch_index][sorted_text_indices]
95
+ batch_dialog_confidence_scores[batch_index] = batch_dialog_confidence_scores[batch_index][sorted_text_indices]
96
+
97
+ results = []
98
+ for batch_index in range(len(batch_scores)):
99
+ panel_bboxes = batch_bboxes[batch_index][batch_panel_indices[batch_index]]
100
+ panel_scores = batch_scores[batch_index][batch_panel_indices[batch_index]]
101
+ text_bboxes = batch_bboxes[batch_index][batch_text_indices[batch_index]]
102
+ text_scores = batch_scores[batch_index][batch_text_indices[batch_index]]
103
+ character_bboxes = batch_bboxes[batch_index][batch_character_indices[batch_index]]
104
+ character_scores = batch_scores[batch_index][batch_character_indices[batch_index]]
105
+ char_i, char_j = torch.where(batch_character_character_matching_scores[batch_index] > character_character_matching_threshold)
106
+ character_character_associations = torch.stack([char_i, char_j], dim=1)
107
+ text_boxes_to_match = batch_dialog_confidence_scores[batch_index] > text_character_matching_threshold
108
+ if 0 in batch_text_character_matching_scores[batch_index].shape:
109
+ text_character_associations = torch.zeros((0, 2), dtype=torch.long)
110
+ else:
111
+ most_likely_speaker_for_each_text = torch.argmax(batch_text_character_matching_scores[batch_index], dim=1)[text_boxes_to_match]
112
+ text_indices = torch.arange(len(text_bboxes)).type_as(most_likely_speaker_for_each_text)[text_boxes_to_match]
113
+ text_character_associations = torch.stack([text_indices, most_likely_speaker_for_each_text], dim=1)
114
+
115
+ character_ufds = UnionFind.from_adj_matrix(
116
+ batch_character_character_matching_scores[batch_index] > character_character_matching_threshold
117
+ )
118
+ results.append({
119
+ "panels": panel_bboxes.tolist(),
120
+ "panel_scores": panel_scores.tolist(),
121
+ "texts": text_bboxes.tolist(),
122
+ "text_scores": text_scores.tolist(),
123
+ "characters": character_bboxes.tolist(),
124
+ "character_scores": character_scores.tolist(),
125
+ "character_character_associations": character_character_associations.tolist(),
126
+ "text_character_associations": text_character_associations.tolist(),
127
+ "character_cluster_labels": character_ufds.get_labels_for_connected_components(),
128
+ "dialog_confidences": batch_dialog_confidence_scores[batch_index].tolist(),
129
+ })
130
+ return results
131
+
132
+ def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
133
+ return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)
134
+
135
+ def crop_image(self, image, bboxes):
136
+ crops_for_image = []
137
+ for bbox in bboxes:
138
+ x1, y1, x2, y2 = bbox
139
+
140
+ # fix the bounding box in case it is out of bounds or too small
141
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
142
+ x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2) # just incase
143
+ x1, y1 = max(0, x1), max(0, y1)
144
+ x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
145
+ x2, y2 = max(0, x2), max(0, y2)
146
+ x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)
147
+ if x2 - x1 < 10:
148
+ if image.shape[1] - x1 > 10:
149
+ x2 = x1 + 10
150
+ else:
151
+ x1 = x2 - 10
152
+ if y2 - y1 < 10:
153
+ if image.shape[0] - y1 > 10:
154
+ y2 = y1 + 10
155
+ else:
156
+ y1 = y2 - 10
157
+
158
+ crop = image[y1:y2, x1:x2]
159
+ crops_for_image.append(crop)
160
+ return crops_for_image
161
+
162
+ def _get_indices_of_characters_to_keep(self, batch_scores, batch_labels, batch_bboxes, character_detection_threshold):
163
+ indices_of_characters_to_keep = []
164
+ for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes):
165
+ indices = torch.where((labels == 0) & (scores > character_detection_threshold))[0]
166
+ indices_of_characters_to_keep.append(indices)
167
+ return indices_of_characters_to_keep
168
+
169
+ def _get_indices_of_panels_to_keep(self, batch_scores, batch_labels, batch_bboxes, panel_detection_threshold):
170
+ indices_of_panels_to_keep = []
171
+ for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
172
+ indices = torch.where(labels == 2)[0]
173
+ bboxes = bboxes[indices]
174
+ scores = scores[indices]
175
+ labels = labels[indices]
176
+ if len(indices) == 0:
177
+ indices_of_panels_to_keep.append([])
178
+ continue
179
+ scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
180
+ panels_to_keep = []
181
+ union_of_panels_so_far = box(0, 0, 0, 0)
182
+ for ps, pb, pl, pi in zip(scores, bboxes, labels, indices):
183
+ panel_polygon = box(pb[0], pb[1], pb[2], pb[3])
184
+ if ps < panel_detection_threshold:
185
+ continue
186
+ if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
187
+ continue
188
+ panels_to_keep.append((ps, pl, pb, pi))
189
+ union_of_panels_so_far = union_of_panels_so_far.union(panel_polygon)
190
+ indices_of_panels_to_keep.append([p[3].item() for p in panels_to_keep])
191
+ return indices_of_panels_to_keep
192
+
193
+ def _get_indices_of_texts_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
194
+ indices_of_texts_to_keep = []
195
+ for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
196
+ indices = torch.where((labels == 1) & (scores > text_detection_threshold))[0]
197
+ bboxes = bboxes[indices]
198
+ scores = scores[indices]
199
+ labels = labels[indices]
200
+ if len(indices) == 0:
201
+ indices_of_texts_to_keep.append([])
202
+ continue
203
+ scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
204
+ texts_to_keep = []
205
+ texts_to_keep_as_shapely_objects = []
206
+ for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
207
+ text_polygon = box(tb[0], tb[1], tb[2], tb[3])
208
+ should_append = True
209
+ for t in texts_to_keep_as_shapely_objects:
210
+ if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5:
211
+ should_append = False
212
+ break
213
+ if should_append:
214
+ texts_to_keep.append((ts, tl, tb, ti))
215
+ texts_to_keep_as_shapely_objects.append(text_polygon)
216
+ indices_of_texts_to_keep.append([t[3].item() for t in texts_to_keep])
217
+ return indices_of_texts_to_keep
218
+
219
+ def _convert_annotations_to_coco_format(self, annotations):
220
+ if annotations is None:
221
+ return None
222
+ self._verify_annotations_are_in_correct_format(annotations)
223
+ coco_annotations = []
224
+ for annotation in annotations:
225
+ coco_annotation = {
226
+ "image_id": annotation["image_id"],
227
+ "annotations": [],
228
+ }
229
+ for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"]):
230
+ coco_annotation["annotations"].append({
231
+ "bbox": x1y1x2y2_to_xywh(bbox),
232
+ "category_id": label,
233
+ "area": (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]),
234
+ })
235
+ coco_annotations.append(coco_annotation)
236
+ return coco_annotations
237
+
238
+ def _verify_annotations_are_in_correct_format(self, annotations):
239
+ error_msg = """
240
+ Annotations must be in the following format:
241
+ [
242
+ {
243
+ "image_id": 0,
244
+ "bboxes_as_x1y1x2y2": [[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]],
245
+ "labels": [0, 1, 2],
246
+ },
247
+ ...
248
+ ]
249
+ Labels: 0 for characters, 1 for text, 2 for panels.
250
+ """
251
+ if annotations is None:
252
+ return
253
+ if not isinstance(annotations, List) and not isinstance(annotations, tuple):
254
+ raise ValueError(
255
+ f"{error_msg} Expected a List/Tuple, found {type(annotations)}."
256
+ )
257
+ if len(annotations) == 0:
258
+ return
259
+ if not isinstance(annotations[0], dict):
260
+ raise ValueError(
261
+ f"{error_msg} Expected a List[Dict], found {type(annotations[0])}."
262
+ )
263
+ if "image_id" not in annotations[0]:
264
+ raise ValueError(
265
+ f"{error_msg} Dict must contain 'image_id'."
266
+ )
267
+ if "bboxes_as_x1y1x2y2" not in annotations[0]:
268
+ raise ValueError(
269
+ f"{error_msg} Dict must contain 'bboxes_as_x1y1x2y2'."
270
+ )
271
+ if "labels" not in annotations[0]:
272
+ raise ValueError(
273
+ f"{error_msg} Dict must contain 'labels'."
274
+ )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:219c2b80e741b1d02e92f22701a38358a5606d6460ad8b6335091e909b212011
3
+ size 2063428286
utils.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.patches as patches
6
+ from shapely.geometry import Point, box
7
+ import networkx as nx
8
+ from copy import deepcopy
9
+ from itertools import groupby
10
+
11
+ def move_to_device(inputs, device):
12
+ if hasattr(inputs, "keys"):
13
+ return {k: move_to_device(v, device) for k, v in inputs.items()}
14
+ elif isinstance(inputs, list):
15
+ return [move_to_device(v, device) for v in inputs]
16
+ elif isinstance(inputs, tuple):
17
+ return tuple([move_to_device(v, device) for v in inputs])
18
+ elif isinstance(inputs, np.ndarray):
19
+ return torch.from_numpy(inputs).to(device)
20
+ else:
21
+ return inputs.to(device)
22
+
23
+ class UnionFind:
24
+ def __init__(self, n):
25
+ self.parent = list(range(n))
26
+ self.size = [1] * n
27
+ self.num_components = n
28
+
29
+ @classmethod
30
+ def from_adj_matrix(cls, adj_matrix):
31
+ ufds = cls(adj_matrix.shape[0])
32
+ for i in range(adj_matrix.shape[0]):
33
+ for j in range(adj_matrix.shape[1]):
34
+ if adj_matrix[i, j] > 0:
35
+ ufds.unite(i, j)
36
+ return ufds
37
+
38
+ @classmethod
39
+ def from_adj_list(cls, adj_list):
40
+ ufds = cls(len(adj_list))
41
+ for i in range(len(adj_list)):
42
+ for j in adj_list[i]:
43
+ ufds.unite(i, j)
44
+ return ufds
45
+
46
+ @classmethod
47
+ def from_edge_list(cls, edge_list, num_nodes):
48
+ ufds = cls(num_nodes)
49
+ for edge in edge_list:
50
+ ufds.unite(edge[0], edge[1])
51
+ return ufds
52
+
53
+ def find(self, x):
54
+ if self.parent[x] == x:
55
+ return x
56
+ self.parent[x] = self.find(self.parent[x])
57
+ return self.parent[x]
58
+
59
+ def unite(self, x, y):
60
+ x = self.find(x)
61
+ y = self.find(y)
62
+ if x != y:
63
+ if self.size[x] < self.size[y]:
64
+ x, y = y, x
65
+ self.parent[y] = x
66
+ self.size[x] += self.size[y]
67
+ self.num_components -= 1
68
+
69
+ def get_components_of(self, x):
70
+ x = self.find(x)
71
+ return [i for i in range(len(self.parent)) if self.find(i) == x]
72
+
73
+ def are_connected(self, x, y):
74
+ return self.find(x) == self.find(y)
75
+
76
+ def get_size(self, x):
77
+ return self.size[self.find(x)]
78
+
79
+ def get_num_components(self):
80
+ return self.num_components
81
+
82
+ def get_labels_for_connected_components(self):
83
+ map_parent_to_label = {}
84
+ labels = []
85
+ for i in range(len(self.parent)):
86
+ parent = self.find(i)
87
+ if parent not in map_parent_to_label:
88
+ map_parent_to_label[parent] = len(map_parent_to_label)
89
+ labels.append(map_parent_to_label[parent])
90
+ return labels
91
+
92
+ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
93
+ figure, subplot = plt.subplots(1, 1, figsize=(10, 10))
94
+ subplot.imshow(image_as_np_array)
95
+ plot_bboxes(subplot, predictions["panels"], color="green")
96
+ plot_bboxes(subplot, predictions["texts"], color="red", add_index=True)
97
+ plot_bboxes(subplot, predictions["characters"], color="blue")
98
+
99
+ COLOURS = [
100
+ "#b7ff51", # green
101
+ "#f50a8f", # pink
102
+ "#4b13b6", # purple
103
+ "#ddaa34", # orange
104
+ "#bea2a2", # brown
105
+ ]
106
+ colour_index = 0
107
+ character_cluster_labels = predictions["character_cluster_labels"]
108
+ unique_label_sorted_by_frequency = sorted(list(set(character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
109
+ for label in unique_label_sorted_by_frequency:
110
+ root = None
111
+ others = []
112
+ for i in range(len(predictions["characters"])):
113
+ if character_cluster_labels[i] == label:
114
+ if root is None:
115
+ root = i
116
+ else:
117
+ others.append(i)
118
+ if colour_index >= len(COLOURS):
119
+ random_colour = COLOURS[0]
120
+ while random_colour in COLOURS:
121
+ random_colour = "#" + "".join([random.choice("0123456789ABCDEF") for j in range(6)])
122
+ else:
123
+ random_colour = COLOURS[colour_index]
124
+ colour_index += 1
125
+ bbox_i = predictions["characters"][root]
126
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
127
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
128
+ subplot.plot([x1], [y1], color=random_colour, marker="o", markersize=5)
129
+ for j in others:
130
+ # draw line from centre of bbox i to centre of bbox j
131
+ bbox_j = predictions["characters"][j]
132
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
133
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
134
+ x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
135
+ y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
136
+ subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
137
+ subplot.plot([x2], [y2], color=random_colour, marker="o", markersize=5)
138
+
139
+ for (i, j) in predictions["text_character_associations"]:
140
+ score = predictions["dialog_confidences"][i]
141
+ bbox_i = predictions["texts"][i]
142
+ bbox_j = predictions["characters"][j]
143
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
144
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
145
+ x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
146
+ y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
147
+ subplot.plot([x1, x2], [y1, y2], color="red", linewidth=2, linestyle="dashed", alpha=score)
148
+
149
+ subplot.axis("off")
150
+ if filename is not None:
151
+ plt.savefig(filename, bbox_inches="tight", pad_inches=0)
152
+
153
+ figure.canvas.draw()
154
+ image = np.array(figure.canvas.renderer._renderer)
155
+ plt.close()
156
+ return image
157
+
158
+ def plot_bboxes(subplot, bboxes, color="red", add_index=False):
159
+ for id, bbox in enumerate(bboxes):
160
+ w = bbox[2] - bbox[0]
161
+ h = bbox[3] - bbox[1]
162
+ rect = patches.Rectangle(
163
+ bbox[:2], w, h, linewidth=1, edgecolor=color, facecolor="none", linestyle="solid"
164
+ )
165
+ subplot.add_patch(rect)
166
+ if add_index:
167
+ cx, cy = bbox[0] + w / 2, bbox[1] + h / 2
168
+ subplot.text(cx, cy, str(id), color=color, fontsize=10, ha="center", va="center")
169
+
170
+ def sort_panels(rects):
171
+ before_rects = convert_to_list_of_lists(rects)
172
+ # slightly erode all rectangles initially to account for imperfect detections
173
+ rects = [erode_rectangle(rect, 0.05) for rect in before_rects]
174
+ G = nx.DiGraph()
175
+ G.add_nodes_from(range(len(rects)))
176
+ for i in range(len(rects)):
177
+ for j in range(len(rects)):
178
+ if i == j:
179
+ continue
180
+ if is_there_a_directed_edge(i, j, rects):
181
+ G.add_edge(i, j, weight=get_distance(rects[i], rects[j]))
182
+ else:
183
+ G.add_edge(j, i, weight=get_distance(rects[i], rects[j]))
184
+ while True:
185
+ cycles = sorted(nx.simple_cycles(G))
186
+ cycles = [cycle for cycle in cycles if len(cycle) > 1]
187
+ if len(cycles) == 0:
188
+ break
189
+ cycle = cycles[0]
190
+ edges = [e for e in zip(cycle, cycle[1:] + cycle[:1])]
191
+ max_cyclic_edge = max(edges, key=lambda x: G.edges[x]["weight"])
192
+ G.remove_edge(*max_cyclic_edge)
193
+ return list(nx.topological_sort(G))
194
+
195
+ def is_strictly_above(rectA, rectB):
196
+ x1A, y1A, x2A, y2A = rectA
197
+ x1B, y1B, x2B, y2B = rectB
198
+ return y2A < y1B
199
+
200
+ def is_strictly_below(rectA, rectB):
201
+ x1A, y1A, x2A, y2A = rectA
202
+ x1B, y1B, x2B, y2B = rectB
203
+ return y2B < y1A
204
+
205
+ def is_strictly_left_of(rectA, rectB):
206
+ x1A, y1A, x2A, y2A = rectA
207
+ x1B, y1B, x2B, y2B = rectB
208
+ return x2A < x1B
209
+
210
+ def is_strictly_right_of(rectA, rectB):
211
+ x1A, y1A, x2A, y2A = rectA
212
+ x1B, y1B, x2B, y2B = rectB
213
+ return x2B < x1A
214
+
215
+ def intersects(rectA, rectB):
216
+ return box(*rectA).intersects(box(*rectB))
217
+
218
+ def is_there_a_directed_edge(a, b, rects):
219
+ rectA = rects[a]
220
+ rectB = rects[b]
221
+ centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2, rectA[1] + (rectA[3] - rectA[1]) / 2]
222
+ centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2, rectB[1] + (rectB[3] - rectB[1]) / 2]
223
+ if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
224
+ return box(*rectA).area > (box(*rectB)).area
225
+ copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
226
+ copy_B = [rectB[0], rectB[1], rectB[2], rectB[3]]
227
+ while True:
228
+ if is_strictly_above(copy_A, copy_B) and not is_strictly_left_of(copy_A, copy_B):
229
+ return 1
230
+ if is_strictly_above(copy_B, copy_A) and not is_strictly_left_of(copy_B, copy_A):
231
+ return 0
232
+ if is_strictly_right_of(copy_A, copy_B) and not is_strictly_below(copy_A, copy_B):
233
+ return 1
234
+ if is_strictly_right_of(copy_B, copy_A) and not is_strictly_below(copy_B, copy_A):
235
+ return 0
236
+ if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
237
+ return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
238
+ if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
239
+ return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
240
+ # otherwise they intersect
241
+ copy_A = erode_rectangle(copy_A, 0.05)
242
+ copy_B = erode_rectangle(copy_B, 0.05)
243
+
244
+ def get_distance(rectA, rectB):
245
+ return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))
246
+
247
+ def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
248
+ rects = deepcopy(rects)
249
+ while True:
250
+ xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
251
+ rect_index = [i for i in range(len(rects)) if intersects(rects[i], [xmin, ymin, xmax, ymax])]
252
+ rects_copy = [rect for rect in rects if intersects(rect, [xmin, ymin, xmax, ymax])]
253
+
254
+ # try to split the panels using a "horizontal" lines
255
+ overlapping_y_ranges = merge_overlapping_ranges([(y1, y2) for x1, y1, x2, y2 in rects_copy])
256
+ panel_index_to_split = {}
257
+ for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
258
+ for i, index in enumerate(rect_index):
259
+ if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
260
+ panel_index_to_split[index] = split_index
261
+
262
+ if panel_index_to_split[a] != panel_index_to_split[b]:
263
+ return panel_index_to_split[a] < panel_index_to_split[b]
264
+
265
+ # try to split the panels using a "vertical" lines
266
+ overlapping_x_ranges = merge_overlapping_ranges([(x1, x2) for x1, y1, x2, y2 in rects_copy])
267
+ panel_index_to_split = {}
268
+ for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
269
+ for i, index in enumerate(rect_index):
270
+ if x1 <= rects_copy[i][0] <= rects_copy[i][2] <= x2:
271
+ panel_index_to_split[index] = split_index
272
+ if panel_index_to_split[a] != panel_index_to_split[b]:
273
+ return panel_index_to_split[a] < panel_index_to_split[b]
274
+
275
+ # otherwise, erode the rectangles and try again
276
+ rects = [erode_rectangle(rect, 0.05) for rect in rects]
277
+
278
+ def erode_rectangle(bbox, erosion_factor):
279
+ x1, y1, x2, y2 = bbox
280
+ w, h = x2 - x1, y2 - y1
281
+ cx, cy = x1 + w / 2, y1 + h / 2
282
+ if w < h:
283
+ aspect_ratio = w / h
284
+ erosion_factor_width = erosion_factor * aspect_ratio
285
+ erosion_factor_height = erosion_factor
286
+ else:
287
+ aspect_ratio = h / w
288
+ erosion_factor_width = erosion_factor
289
+ erosion_factor_height = erosion_factor * aspect_ratio
290
+ w = w - w * erosion_factor_width
291
+ h = h - h * erosion_factor_height
292
+ x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
293
+ return [x1, y1, x2, y2]
294
+
295
+ def merge_overlapping_ranges(ranges):
296
+ """
297
+ ranges: list of tuples (x1, x2)
298
+ """
299
+ if len(ranges) == 0:
300
+ return []
301
+ ranges = sorted(ranges, key=lambda x: x[0])
302
+ merged_ranges = []
303
+ for i, r in enumerate(ranges):
304
+ if i == 0:
305
+ prev_x1, prev_x2 = r
306
+ continue
307
+ x1, x2 = r
308
+ if x1 > prev_x2:
309
+ merged_ranges.append((prev_x1, prev_x2))
310
+ prev_x1, prev_x2 = x1, x2
311
+ else:
312
+ prev_x2 = max(prev_x2, x2)
313
+ merged_ranges.append((prev_x1, prev_x2))
314
+ return merged_ranges
315
+
316
+ def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
317
+ text_bboxes = convert_to_list_of_lists(text_bboxes)
318
+ sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
319
+
320
+ if len(text_bboxes) == 0:
321
+ return []
322
+
323
+ def indices_of_same_elements(nums):
324
+ groups = groupby(range(len(nums)), key=lambda i: nums[i])
325
+ return [list(indices) for _, indices in groups]
326
+
327
+ panel_id_for_text = get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes)
328
+ indices_of_texts = list(range(len(text_bboxes)))
329
+ indices_of_texts, panel_id_for_text = zip(*sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
330
+ indices_of_texts = list(indices_of_texts)
331
+ grouped_indices = indices_of_same_elements(panel_id_for_text)
332
+ for group in grouped_indices:
333
+ subset_of_text_indices = [indices_of_texts[i] for i in group]
334
+ text_bboxes_of_subset = [text_bboxes[i] for i in subset_of_text_indices]
335
+ sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
336
+ indices_of_texts[group[0] : group[-1] + 1] = [subset_of_text_indices[i] for i in sorted_subset_indices]
337
+ return indices_of_texts
338
+
339
+ def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
340
+ text_to_panel_mapping = []
341
+ for text_bbox in text_bboxes:
342
+ shapely_text_polygon = box(*text_bbox)
343
+ all_intersections = []
344
+ all_distances = []
345
+ if len(sorted_panel_bboxes) == 0:
346
+ text_to_panel_mapping.append(-1)
347
+ continue
348
+ for j, annotation in enumerate(sorted_panel_bboxes):
349
+ shapely_annotation_polygon = box(*annotation)
350
+ if shapely_text_polygon.intersects(shapely_annotation_polygon):
351
+ all_intersections.append((shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
352
+ all_distances.append((shapely_text_polygon.distance(shapely_annotation_polygon), j))
353
+ if len(all_intersections) == 0:
354
+ text_to_panel_mapping.append(min(all_distances, key=lambda x: x[0])[1])
355
+ else:
356
+ text_to_panel_mapping.append(max(all_intersections, key=lambda x: x[0])[1])
357
+ return text_to_panel_mapping
358
+
359
+ def sort_texts_within_panel(rects):
360
+ smallest_y = float("inf")
361
+ greatest_x = float("-inf")
362
+ for i, rect in enumerate(rects):
363
+ x1, y1, x2, y2 = rect
364
+ smallest_y = min(smallest_y, y1)
365
+ greatest_x = max(greatest_x, x2)
366
+
367
+ reference_point = Point(greatest_x, smallest_y)
368
+
369
+ polygons_and_index = []
370
+ for i, rect in enumerate(rects):
371
+ x1, y1, x2, y2 = rect
372
+ polygons_and_index.append((box(x1,y1,x2,y2), i))
373
+ # sort points by closest to reference point
374
+ polygons_and_index = sorted(polygons_and_index, key=lambda x: reference_point.distance(x[0]))
375
+ indices = [x[1] for x in polygons_and_index]
376
+ return indices
377
+
378
+ def x1y1wh_to_x1y1x2y2(bbox):
379
+ x1, y1, w, h = bbox
380
+ return [x1, y1, x1 + w, y1 + h]
381
+
382
+ def x1y1x2y2_to_xywh(bbox):
383
+ x1, y1, x2, y2 = bbox
384
+ return [x1, y1, x2 - x1, y2 - y1]
385
+
386
+ def convert_to_list_of_lists(rects):
387
+ if isinstance(rects, torch.Tensor):
388
+ return rects.tolist()
389
+ if isinstance(rects, np.ndarray):
390
+ return rects.tolist()
391
+ return [[a, b, c, d] for a, b, c, d in rects]