Upload processing_kosmos2.py

processing_kosmos2.py  +117 -2  CHANGED
@@ -59,7 +59,7 @@ class Kosmos2Processor(ProcessorMixin):
     """
     attributes = ["image_processor", "tokenizer"]
     image_processor_class = "CLIPImageProcessor"
-    tokenizer_class =
+    tokenizer_class = ("Kosmos2Tokenizer", "Kosmos2TokenizerFast")
 
     def __init__(self, image_processor, tokenizer):
         tokenizer.return_token_type_ids = False
@@ -332,7 +332,9 @@ class Kosmos2Processor(ProcessorMixin):
         return self.tokenizer.decode(*args, **kwargs)
 
     def post_processor_generation(self, text):
-
+
+        caption = text.split("</image>")[-1]
+        return clean_text_and_extract_entities_with_bboxes(caption)
 
     @property
     # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
@@ -455,6 +457,7 @@ def coordinate_to_patch_index(bbox: Tuple[float, float, float, float], num_patch
 
 
 # copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L35C1-L75C38
+# (with format modifications)
 def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: int):
     """
     Given a grid of length `num_patches_per_side` and the indices of the upper-left and lower-right corners of a
@@ -496,3 +499,115 @@ def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: in
     y2 = lr_y * cell_size + cell_size / 2
 
     return x1, y1, x2, y2
+
+
+# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L4-L33
+# (with format modifications)
+def extract_entities_with_patch_indices(text):
+    # The regular expression pattern for matching the required formats
+    pattern = r'(?:(<phrase>([^<]+)</phrase>))?<object>((?:<patch_index_\d+><patch_index_\d+></delimiter_of_multi_objects/>)*<patch_index_\d+><patch_index_\d+>)</object>'
+
+    # Find all matches in the given string
+    matches = re.finditer(pattern, text)
+
+    # Initialize an empty list to store the valid patch_index combinations
+    entities_with_patch_indices = []
+
+    for match in matches:
+        # span of a `phrase` that is between <phrase> and </phrase>
+        span = match.span(2)
+        phrase_tag, phrase, match_content = match.groups()
+        if not phrase_tag:
+            phrase = None
+            span = (None, None)
+
+        # Split the match_content by the delimiter to get individual patch_index pairs
+        patch_index_pairs = match_content.split('</delimiter_of_multi_objects/>')
+
+        entity_bboxes = []
+        for pair in patch_index_pairs:
+            # Extract the xxxx and yyyy values from the patch_index pair
+            x = re.search(r'<patch_index_(\d+)>', pair)
+            y = re.search(r'<patch_index_(\d+)>', pair[1:])
+
+            if x and y:
+                if phrase:
+                    entity_bboxes.append((int(x.group(1)), int(y.group(1))))
+                else:
+                    entity_bboxes.append((int(x.group(1)), int(y.group(1))))
+
+        if phrase:
+            entities_with_patch_indices.append((phrase, span, entity_bboxes))
+        else:
+            for bbox in entity_bboxes:
+                # fake entity name
+                entity = f"<patch_index_{bbox[0]}><patch_index_{bbox[1]}>"
+                entities_with_patch_indices.append((entity, span, [bbox]))
+
+    return entities_with_patch_indices
+
+
+# TODO: Be careful
+def remove_special_fields(text):
+    return re.sub('<.*?>', '', text)
+
+
+def adjust_entity_positions(entity, text):
+
+    entity_name, (start, end) = entity
+    adjusted_start = len(remove_special_fields(text[:start]))
+    adjusted_end = len(remove_special_fields(text[:end]))
+    adjusted_entity = (entity_name, (adjusted_start, adjusted_end))
+    return adjusted_entity
+
+
+# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L77-L87
+# (with format modifications)
+def clean_text_and_extract_entities_with_bboxes(text, num_patches_per_side=32):
+
+    processed_text = remove_special_fields(text)
+
+    entities_with_patch_indices = extract_entities_with_patch_indices(text)
+    entities = []
+    for item in entities_with_patch_indices:
+        entity, bboxes = item[0:2], item[2]
+        adjusted_entity = adjust_entity_positions(entity, text)
+        bboxes_in_coords = list(map(lambda bbox: patch_index_to_coordinate(bbox[0], bbox[1], num_patches_per_side), bboxes))
+
+        entities.append((adjusted_entity) + (bboxes_in_coords,))
+
+    def cleanup_spaces(text, entities):
+        new_text = text.strip()
+        leading_spaces = len(text) - len(text.lstrip())
+
+        new_entities = []
+        for entity_name, (start, end), bboxes in entities:
+
+            entity_name_leading_spaces = len(entity_name) - len(entity_name.lstrip())
+            entity_name_trailing_spaces = len(entity_name) - len(entity_name.rstrip())
+
+            start = start - leading_spaces + entity_name_leading_spaces
+            end = end - leading_spaces - entity_name_trailing_spaces
+            entity_name = entity_name.strip()
+
+            new_entities.append((entity_name, (start, end), bboxes))
+
+        return new_text, new_entities
+
+    return cleanup_spaces(processed_text, entities)
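
For reference, a minimal usage sketch of the post-processing path added in this commit, assuming processing_kosmos2.py is importable from the working directory; the sample caption and its patch indices are invented for illustration (post_processor_generation applies the same decoding after splitting off everything up to "</image>"):

# A minimal, illustrative sketch (not part of the commit): decode a tagged caption
# into plain text plus (entity, span, bounding boxes) triples.
# Assumes processing_kosmos2.py is importable from the working directory.
from processing_kosmos2 import clean_text_and_extract_entities_with_bboxes

# Hypothetical generated caption (the part after "</image>"); the phrase and
# patch indices are made up for illustration.
caption = (
    "An image of<phrase> a snowman</phrase><object>"
    "<patch_index_0044><patch_index_0863></object> warming himself by a fire."
)

clean_text, entities = clean_text_and_extract_entities_with_bboxes(caption)

print(clean_text)  # "An image of a snowman warming himself by a fire."
for name, (start, end), bboxes in entities:
    # each bbox is an (x1, y1, x2, y2) tuple normalized to [0, 1]
    print(name, (start, end), bboxes)

On the default 32x32 grid, a patch index p maps to row p // 32 and column p % 32, and patch_index_to_coordinate returns the box spanning the centers of the upper-left and lower-right cells in normalized [0, 1] coordinates, which is what appears in bboxes above.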