Spaces:
Runtime error
Runtime error
""" | |
Author: Siyuan Li | |
Licensed: Apache-2.0 License | |
""" | |
import copy | |
import logging | |
import re | |
import warnings | |
from mmdet.registry import MODELS | |
from mmengine.logging import MMLogger, print_log | |
from mmengine.model.weight_init import (PretrainedInit, initialize, | |
update_init_info) | |
from .grounding_dino import GroundingDINO | |
def clean_label_name(name: str) -> str: | |
name = re.sub(r"\(.*\)", "", name) | |
name = re.sub(r"_", " ", name) | |
name = re.sub(r" ", " ", name) | |
return name | |
def chunks(lst: list, n: int) -> list: | |
"""Yield successive n-sized chunks from lst.""" | |
all_ = [] | |
for i in range(0, len(lst), n): | |
data_index = lst[i : i + n] | |
all_.append(data_index) | |
counter = 0 | |
for i in all_: | |
counter += len(i) | |
assert counter == len(lst) | |
return all_ | |
class GroundingDINOMasa(GroundingDINO): | |
"""Implementation of `Grounding DINO: Marrying DINO with Grounded Pre- | |
Training for Open-Set Object Detection. | |
<https://arxiv.org/abs/2303.05499>`_ | |
Code is modified from the `official github repo | |
<https://github.com/IDEA-Research/GroundingDINO>`_. | |
""" | |
def __init__(self, *args, **kwargs) -> None: | |
super().__init__(*args, **kwargs) | |
self.track_text_prompt = None | |
self.track_text_dict = None | |
self.token_positive_maps = None | |
self.track_entities = None | |
def init_weights(self) -> None: | |
"""Initialize weights for Transformer and other components.""" | |
if self.init_cfg: | |
print_log( | |
f"initialize {self.__class__.__name__} with init_cfg {self.init_cfg}", | |
logger="current", | |
level=logging.DEBUG, | |
) | |
init_cfgs = self.init_cfg | |
if isinstance(self.init_cfg, dict): | |
init_cfgs = [self.init_cfg] | |
# PretrainedInit has higher priority than any other init_cfg. | |
# Therefore we initialize `pretrained_cfg` last to overwrite | |
# the previous initialized weights. | |
# See details in https://github.com/open-mmlab/mmengine/issues/691 # noqa E501 | |
other_cfgs = [] | |
pretrained_cfg = [] | |
for init_cfg in init_cfgs: | |
assert isinstance(init_cfg, dict) | |
if ( | |
init_cfg["type"] == "Pretrained" | |
or init_cfg["type"] is PretrainedInit | |
): | |
pretrained_cfg.append(init_cfg) | |
else: | |
other_cfgs.append(init_cfg) | |
initialize(self, other_cfgs) | |
else: | |
super().init_weights() | |
initialize(self, pretrained_cfg) | |
def predict( | |
self, batch_inputs, detection_features, batch_data_samples, rescale: bool = True | |
): | |
text_prompts = [] | |
enhanced_text_prompts = [] | |
tokens_positives = [] | |
for data_samples in batch_data_samples: | |
text_prompts.append(data_samples.text) | |
if "caption_prompt" in data_samples: | |
enhanced_text_prompts.append(data_samples.caption_prompt) | |
else: | |
enhanced_text_prompts.append(None) | |
tokens_positives.append(data_samples.get("tokens_positive", None)) | |
if "custom_entities" in batch_data_samples[0]: | |
# Assuming that the `custom_entities` flag | |
# inside a batch is always the same. For single image inference | |
custom_entities = batch_data_samples[0].custom_entities | |
else: | |
custom_entities = False | |
if self.track_text_dict is not None and self.track_text_prompt == text_prompts: | |
# text feature map layer | |
is_rec_tasks = [] | |
for i, data_samples in enumerate(batch_data_samples): | |
if self.token_positive_maps[i] is not None: | |
is_rec_tasks.append(False) | |
else: | |
is_rec_tasks.append(True) | |
data_samples.token_positive_map = self.token_positive_maps[i] | |
visual_feats = detection_features | |
head_inputs_dict = self.forward_transformer( | |
visual_feats, self.track_text_dict, batch_data_samples | |
) | |
results_list = self.bbox_head.predict( | |
**head_inputs_dict, | |
rescale=rescale, | |
batch_data_samples=batch_data_samples, | |
) | |
entities = self.track_entities | |
else: | |
self.track_text_prompt = text_prompts | |
if len(text_prompts) == 1: | |
# All the text prompts are the same, | |
# so there is no need to calculate them multiple times. | |
_positive_maps_and_prompts = [ | |
self.get_tokens_positive_and_prompts( | |
text_prompts[0], | |
custom_entities, | |
enhanced_text_prompts[0], | |
tokens_positives[0], | |
) | |
] * len(batch_inputs) | |
else: | |
_positive_maps_and_prompts = [ | |
self.get_tokens_positive_and_prompts( | |
text_prompt, | |
custom_entities, | |
enhanced_text_prompt, | |
tokens_positive, | |
) | |
for text_prompt, enhanced_text_prompt, tokens_positive in zip( | |
text_prompts, enhanced_text_prompts, tokens_positives | |
) | |
] | |
token_positive_maps, text_prompts, _, entities = zip( | |
*_positive_maps_and_prompts | |
) | |
self.token_positive_maps = token_positive_maps | |
self.track_entities = entities | |
# image feature extraction | |
visual_feats = detection_features | |
if isinstance(text_prompts[0], list): | |
# chunked text prompts, only bs=1 is supported | |
assert len(batch_inputs) == 1 | |
count = 0 | |
results_list = [] | |
entities = [[item for lst in entities[0] for item in lst]] | |
for b in range(len(text_prompts[0])): | |
text_prompts_once = [text_prompts[0][b]] | |
token_positive_maps_once = token_positive_maps[0][b] | |
text_dict = self.language_model(text_prompts_once) | |
# text feature map layer | |
if self.text_feat_map is not None: | |
text_dict["embedded"] = self.text_feat_map( | |
text_dict["embedded"] | |
) | |
batch_data_samples[0].token_positive_map = token_positive_maps_once | |
head_inputs_dict = self.forward_transformer( | |
copy.deepcopy(visual_feats), text_dict, batch_data_samples | |
) | |
pred_instances = self.bbox_head.predict( | |
**head_inputs_dict, | |
rescale=rescale, | |
batch_data_samples=batch_data_samples, | |
)[0] | |
if len(pred_instances) > 0: | |
pred_instances.labels += count | |
count += len(token_positive_maps_once) | |
results_list.append(pred_instances) | |
results_list = [results_list[0].cat(results_list)] | |
is_rec_tasks = [False] * len(results_list) | |
else: | |
# extract text feats | |
text_dict = self.language_model(list(text_prompts)) | |
# text feature map layer | |
if self.text_feat_map is not None: | |
text_dict["embedded"] = self.text_feat_map(text_dict["embedded"]) | |
is_rec_tasks = [] | |
for i, data_samples in enumerate(batch_data_samples): | |
if token_positive_maps[i] is not None: | |
is_rec_tasks.append(False) | |
else: | |
is_rec_tasks.append(True) | |
data_samples.token_positive_map = token_positive_maps[i] | |
if self.track_text_dict is None: | |
self.track_text_dict = text_dict | |
head_inputs_dict = self.forward_transformer( | |
visual_feats, text_dict, batch_data_samples | |
) | |
results_list = self.bbox_head.predict( | |
**head_inputs_dict, | |
rescale=rescale, | |
batch_data_samples=batch_data_samples, | |
) | |
for data_sample, pred_instances, entity, is_rec_task in zip( | |
batch_data_samples, results_list, entities, is_rec_tasks | |
): | |
if len(pred_instances) > 0: | |
label_names = [] | |
for labels in pred_instances.labels: | |
if is_rec_task: | |
label_names.append(entity) | |
continue | |
if labels >= len(entity): | |
warnings.warn( | |
"The unexpected output indicates an issue with " | |
"named entity recognition. You can try " | |
"setting custom_entities=True and running " | |
"again to see if it helps." | |
) | |
label_names.append("unobject") | |
else: | |
label_names.append(entity[labels]) | |
# for visualization | |
pred_instances.label_names = label_names | |
data_sample.pred_instances = pred_instances | |
return batch_data_samples | |