# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import Dict, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from ..layers import SinePositionalEncoding
from ..layers.transformer.grounding_dino_layers import (
    GroundingDinoTransformerDecoder, GroundingDinoTransformerEncoder)
from .dino import DINO
from .glip import (create_positive_map, create_positive_map_label_to_token,
                   run_ner)
@MODELS.register_module()
class GroundingDINO(DINO):
    """Implementation of `Grounding DINO: Marrying DINO with Grounded Pre-
    Training for Open-Set Object Detection.

    <https://arxiv.org/abs/2303.05499>`_

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/GroundingDINO>`_.
    """

    def __init__(self, language_model, *args, **kwargs) -> None:
        self.language_model_cfg = language_model
        self._special_tokens = '. '
        super().__init__(*args, **kwargs)
    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = GroundingDinoTransformerEncoder(**self.encoder)
        self.decoder = GroundingDinoTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly 2 times of num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

        self.level_embed = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
        self.memory_trans_norm = nn.LayerNorm(self.embed_dims)
        # text modules
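        # Map the hidden size of the language backbone (e.g. 768 for a
        # BERT-base backbone) to the visual embed_dims (e.g. 256) so that
        # text and image features share the same space before
        # cross-modality fusion.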
        self.language_model = MODELS.build(self.language_model_cfg)
        self.text_feat_map = nn.Linear(
            self.language_model.language_backbone.body.language_dim,
            self.embed_dims,
            bias=True)
    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super().init_weights()
        nn.init.constant_(self.text_feat_map.bias.data, 0)
        nn.init.xavier_uniform_(self.text_feat_map.weight.data)
    def get_tokens_and_prompts(
            self,
            original_caption: Union[str, list, tuple],
            custom_entities: bool = False) -> Tuple[dict, str, list, list]:
        """Get the tokens positive and prompts for the caption."""
        if isinstance(original_caption, (list, tuple)) or custom_entities:
            if custom_entities and isinstance(original_caption, str):
                original_caption = original_caption.strip(
                    self._special_tokens)
                original_caption = original_caption.split(
                    self._special_tokens)
                original_caption = list(
                    filter(lambda x: len(x) > 0, original_caption))

            caption_string = ''
            tokens_positive = []
            for idx, word in enumerate(original_caption):
                tokens_positive.append(
                    [[len(caption_string),
                      len(caption_string) + len(word)]])
                caption_string += word
                caption_string += self._special_tokens
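            # e.g. ['bench', 'car'] -> caption_string 'bench. car. ' with
            # tokens_positive [[[0, 5]], [[7, 10]]], i.e. the character
            # spans of each entity inside the caption.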
            # NOTE: Tokenizer in Grounding DINO is different from
            # that in GLIP. The tokenizer in GLIP will pad the
            # caption_string to max_length, while the tokenizer
            # in Grounding DINO will not.
            tokenized = self.language_model.tokenizer(
                [caption_string],
                padding='max_length'
                if self.language_model.pad_to_max else 'longest',
                return_tensors='pt')
            entities = original_caption
        else:
            if not original_caption.endswith('.'):
                original_caption = original_caption + self._special_tokens
            # NOTE: Tokenizer in Grounding DINO is different from
            # that in GLIP. The tokenizer in GLIP will pad the
            # caption_string to max_length, while the tokenizer
            # in Grounding DINO will not.
            tokenized = self.language_model.tokenizer(
                [original_caption],
                padding='max_length'
                if self.language_model.pad_to_max else 'longest',
                return_tensors='pt')
            tokens_positive, noun_phrases = run_ner(original_caption)
            entities = noun_phrases
            caption_string = original_caption

        return tokenized, caption_string, tokens_positive, entities
    def get_positive_map(self, tokenized, tokens_positive):
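        """Build the positive map between entities and text tokens.

        Returns a binary map of shape (num_entities, max_text_len) marking
        which tokenized text positions belong to each entity, together with
        a dict mapping every label id (numbered from 1) to its positive
        token indices.
        """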
        positive_map = create_positive_map(tokenized, tokens_positive)
        positive_map_label_to_token = create_positive_map_label_to_token(
            positive_map, plus=1)
        return positive_map_label_to_token, positive_map
    def get_tokens_positive_and_prompts(
            self,
            original_caption: Union[str, list, tuple],
            custom_entities: bool = False) -> Tuple[dict, str, Tensor, list]:
        """Get the tokens positive and prompts for the caption.

        Args:
            original_caption (str): The original caption, e.g. 'bench . car .'
            custom_entities (bool, optional): Whether to use custom entities.
                If ``True``, the ``original_caption`` should be a list of
                strings, each of which is a word. Defaults to False.

        Returns:
            Tuple[dict, str, Tensor, list]: The dict maps each entity id,
            numbered from 1, to its positive token ids. The str is the
            caption prompt. The Tensor is the positive map between entities
            and text tokens, and the list contains the entity names.
        """
        tokenized, caption_string, tokens_positive, entities = \
            self.get_tokens_and_prompts(
                original_caption, custom_entities)
        positive_map_label_to_token, positive_map = self.get_positive_map(
            tokenized, tokens_positive)
        return positive_map_label_to_token, caption_string, \
            positive_map, entities
    def forward_transformer(
        self,
        img_feats: Tuple[Tensor],
        text_dict: Dict,
        batch_data_samples: OptSampleList = None,
    ) -> Dict:
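        """Forward process of the Transformer.

        Compared with the image-only pipeline of ``DINO``, the text features
        in ``text_dict`` are additionally passed to the encoder for
        image-text fusion and forwarded, via ``pre_decoder``, to the decoder
        and the bbox head.
        """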
        encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
            img_feats, batch_data_samples)

        encoder_outputs_dict = self.forward_encoder(
            **encoder_inputs_dict, text_dict=text_dict)

        tmp_dec_in, head_inputs_dict = self.pre_decoder(
            **encoder_outputs_dict, batch_data_samples=batch_data_samples)
        decoder_inputs_dict.update(tmp_dec_in)

        decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
        head_inputs_dict.update(decoder_outputs_dict)
        return head_inputs_dict
    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor, spatial_shapes: Tensor,
                        level_start_index: Tensor, valid_ratios: Tensor,
                        text_dict: Dict) -> Dict:
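        """Run the image encoder together with the text encoder.

        The flattened multi-level image features and the text embeddings are
        fused inside ``GroundingDinoTransformerEncoder``; both the updated
        image memory and the updated text memory are returned for the
        decoder and the bbox head.
        """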
        text_token_mask = text_dict['text_token_mask']
        memory, memory_text = self.encoder(
            query=feat,
            query_pos=feat_pos,
            key_padding_mask=feat_mask,  # for self_attn
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            # for text encoder
            memory_text=text_dict['embedded'],
            text_attention_mask=~text_token_mask,
            position_ids=text_dict['position_ids'],
            text_self_attention_masks=text_dict['masks'])
        encoder_outputs_dict = dict(
            memory=memory,
            memory_mask=feat_mask,
            spatial_shapes=spatial_shapes,
            memory_text=memory_text,
            text_token_mask=text_token_mask)
        return encoder_outputs_dict
    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        memory_text: Tensor,
        text_token_mask: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict]:
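        """Prepare intermediate variables before entering the decoder.

        Proposals are generated from the encoder memory, scored with the
        language-aware classification branch, and the top-k proposals are
        selected as the initial reference points. During training, the
        denoising queries from ``dn_query_generator`` are prepended to the
        matching queries.
        """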
        bs, _, c = memory.shape
        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)

        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers](output_memory, memory_text,
                                     text_token_mask)
        cls_out_features = self.bbox_head.cls_branches[
            self.decoder.num_layers].max_text_len
        enc_outputs_coord_unact = self.bbox_head.reg_branches[
            self.decoder.num_layers](output_memory) + output_proposals

        # NOTE: DINO selects the top-k proposals according to the scores of
        # multi-class classification, while Deformable DETR selects them
        # according to the scores of binary classification, i.e.
        # `enc_outputs_class[..., 0]`.
        topk_indices = torch.topk(
            enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1]

        topk_score = torch.gather(
            enc_outputs_class, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features))
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, 4))
        topk_coords = topk_coords_unact.sigmoid()
        topk_coords_unact = topk_coords_unact.detach()
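        # `topk_coords` (still attached to the graph) supervises the encoder
        # through the auxiliary encoder loss, while the detached copy only
        # serves as the initial reference points for the decoder.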
        query = self.query_embedding.weight[:, None, :]
        query = query.repeat(1, bs, 1).transpose(0, 1)
        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
                self.dn_query_generator(batch_data_samples)
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                         dim=1)
        else:
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None
        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask,
            memory_text=memory_text,
            text_attention_mask=~text_token_mask,
        )
        # NOTE: DINO calculates the encoder losses on the scores and
        # coordinates of the selected top-k encoder queries, while
        # Deformable DETR uses all encoder queries.
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            dn_meta=dn_meta) if self.training else dict()
        # append text_feats to head_inputs_dict
        head_inputs_dict['memory_text'] = memory_text
        head_inputs_dict['text_token_mask'] = text_token_mask
        return decoder_inputs_dict, head_inputs_dict
    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
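        """Calculate losses from a batch of inputs and data samples.

        Each data sample is expected to carry a ``text`` prompt. The prompt
        is tokenized, the ground-truth labels are turned into per-instance
        positive maps over the text tokens, and the fused image-text
        features are forwarded to the bbox head to compute the losses.

        Returns:
            dict: A dictionary of loss components.
        """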
        # TODO: Only open vocabulary tasks are supported for training now.
        text_prompts = [
            data_samples.text for data_samples in batch_data_samples
        ]

        gt_labels = [
            data_samples.gt_instances.labels
            for data_samples in batch_data_samples
        ]

        new_text_prompts = []
        positive_maps = []
        if len(set(text_prompts)) == 1:
            # All the text prompts are the same,
            # so there is no need to calculate them multiple times.
            tokenized, caption_string, tokens_positive, _ = \
                self.get_tokens_and_prompts(
                    text_prompts[0], True)
            new_text_prompts = [caption_string] * len(batch_inputs)
            for gt_label in gt_labels:
                new_tokens_positive = [
                    tokens_positive[label] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive)
                positive_maps.append(positive_map)
        else:
            for text_prompt, gt_label in zip(text_prompts, gt_labels):
                tokenized, caption_string, tokens_positive, _ = \
                    self.get_tokens_and_prompts(
                        text_prompt, True)
                new_tokens_positive = [
                    tokens_positive[label] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive)
                positive_maps.append(positive_map)
                new_text_prompts.append(caption_string)

        text_dict = self.language_model(new_text_prompts)
        if self.text_feat_map is not None:
            text_dict['embedded'] = self.text_feat_map(text_dict['embedded'])

        for i, data_samples in enumerate(batch_data_samples):
            positive_map = positive_maps[i].to(
                batch_inputs.device).bool().float()
            text_token_mask = text_dict['text_token_mask'][i]
            data_samples.gt_instances.positive_maps = positive_map
            data_samples.gt_instances.text_token_mask = \
                text_token_mask.unsqueeze(0).repeat(
                    len(positive_map), 1)

        visual_features = self.extract_feat(batch_inputs)

        head_inputs_dict = self.forward_transformer(visual_features,
                                                    text_dict,
                                                    batch_data_samples)

        losses = self.bbox_head.loss(
            **head_inputs_dict, batch_data_samples=batch_data_samples)
        return losses
    def predict(self, batch_inputs, batch_data_samples, rescale: bool = True):
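        """Predict results from a batch of inputs and data samples.

        Each data sample must provide a ``text`` prompt (and optionally a
        ``custom_entities`` flag). Predicted labels are mapped back to the
        corresponding phrase or entity names so that
        ``pred_instances.label_names`` can be used directly for
        visualization.

        Args:
            batch_inputs (Tensor): Input images of shape (bs, dim, H, W).
            batch_data_samples (SampleList): The batch data samples.
            rescale (bool): Whether to rescale the predicted boxes back to
                the original image scale. Defaults to True.

        Returns:
            SampleList: Detection results attached to each data sample.
        """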
        text_prompts = [
            data_samples.text for data_samples in batch_data_samples
        ]
        if 'custom_entities' in batch_data_samples[0]:
            # Assume the `custom_entities` flag is the same for every sample
            # in the batch (always true for single-image inference).
            custom_entities = batch_data_samples[0].custom_entities
        else:
            custom_entities = False
        if len(text_prompts) == 1:
            # There is only one prompt in the batch,
            # so it only needs to be processed once.
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(text_prompts[0],
                                                     custom_entities)
            ] * len(batch_inputs)
        else:
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(text_prompt,
                                                     custom_entities)
                for text_prompt in text_prompts
            ]

        token_positive_maps, text_prompts, _, entities = zip(
            *_positive_maps_and_prompts)
        # extract text feats
        text_dict = self.language_model(list(text_prompts))
        # text feature map layer
        if self.text_feat_map is not None:
            text_dict['embedded'] = self.text_feat_map(text_dict['embedded'])

        for i, data_samples in enumerate(batch_data_samples):
            data_samples.token_positive_map = token_positive_maps[i]

        # image feature extraction
        visual_feats = self.extract_feat(batch_inputs)

        head_inputs_dict = self.forward_transformer(visual_feats, text_dict,
                                                    batch_data_samples)
        results_list = self.bbox_head.predict(
            **head_inputs_dict,
            rescale=rescale,
            batch_data_samples=batch_data_samples)
        for data_sample, pred_instances, entity in zip(batch_data_samples,
                                                       results_list,
                                                       entities):
            if len(pred_instances) > 0:
                label_names = []
                for labels in pred_instances.labels:
                    if labels >= len(entity):
                        warnings.warn(
                            'The unexpected output indicates an issue with '
                            'named entity recognition. You can try '
                            'setting custom_entities=True and running '
                            'again to see if it helps.')
                        label_names.append('unobject')
                    else:
                        label_names.append(entity[labels])
                # for visualization
                pred_instances.label_names = label_names
            data_sample.pred_instances = pred_instances
        return batch_data_samples
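

# A minimal inference sketch (assumption: an mmdet 3.x environment where this
# model is registered via the metafile and a matching checkpoint is
# available; the config alias and the ``texts`` argument of ``DetInferencer``
# may differ between versions, so treat the snippet below as illustrative
# only):
#
#     from mmdet.apis import DetInferencer
#
#     inferencer = DetInferencer(
#         model='grounding_dino_swin-t_pretrain_obj365_goldg_cap4m')
#     inferencer('demo.jpg', texts='bench . car .')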