import copy
from typing import List

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops.boxes import nms
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast

from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import (
    NestedTensor,
    accuracy,
    get_world_size,
    interpolate,
    inverse_sigmoid,
    is_dist_avail_and_initialized,
    nested_tensor_from_tensor_list,
)
from groundingdino.util.utils import get_phrases_from_posmap
from groundingdino.util.visualizer import COCOVisualizer
from groundingdino.util.vl_utils import create_positive_map_from_span

from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone
from .bertwarper import (
    BertModelWarper,
    generate_masks_with_special_tokens,
    generate_masks_with_special_tokens_and_transfer_map,
)
from .transformer import build_transformer
from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss


class GroundingDINO(nn.Module):
    """This is the Cross-Attention Detector module that performs open-set object detection."""

    def __init__(
        self,
        backbone,
        transformer,
        num_queries,
        aux_loss=False,
        iter_update=False,
        query_dim=2,
        num_feature_levels=1,
        nheads=8,
        # two stage
        two_stage_type="no",  # ["no", "standard"]
        dec_pred_bbox_embed_share=True,
        two_stage_class_embed_share=True,
        two_stage_bbox_embed_share=True,
        # denoising hyper-parameters
        num_patterns=0,
        dn_number=100,
        dn_box_noise_scale=0.4,
        dn_label_noise_ratio=0.5,
        dn_labelbook_size=100,
        # text encoder
        text_encoder_type="bert-base-uncased",
        sub_sentence_present=True,
        max_text_len=256,
    ):
        """Initializes the model.

        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_queries: number of object queries, i.e. detection slots. This is the maximal
                number of objects the model can detect in a single image.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            two_stage_type: "standard" to generate proposals from encoder features, "no" otherwise.
            text_encoder_type: name of the HuggingFace text encoder (e.g. "bert-base-uncased").
            sub_sentence_present: if True, restrict text self-attention to within each phrase
                (sub-sentence) and reset position ids per phrase before feeding the text encoder.
            max_text_len: maximum number of text tokens kept from the tokenized captions.
        """

        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.hidden_dim = hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.nheads = nheads
        self.max_text_len = max_text_len
        self.sub_sentence_present = sub_sentence_present

        # query reference points are (cx, cy, w, h)
        self.query_dim = query_dim
        assert query_dim == 4

        # denoising hyper-parameters (stored for training configs; not used in this forward)
        self.num_patterns = num_patterns
        self.dn_number = dn_number
        self.dn_box_noise_scale = dn_box_noise_scale
        self.dn_label_noise_ratio = dn_label_noise_ratio
        self.dn_labelbook_size = dn_labelbook_size

        # text encoder: tokenizer + BERT with a frozen pooler, wrapped by BertModelWarper
        # so it can be fed the custom attention masks and position ids built in forward()
        self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
        self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
        self.bert.pooler.dense.weight.requires_grad_(False)
        self.bert.pooler.dense.bias.requires_grad_(False)
        self.bert = BertModelWarper(bert_model=self.bert)

        # project text features to the transformer dimension
        self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
        nn.init.constant_(self.feat_map.bias.data, 0)
        nn.init.xavier_uniform_(self.feat_map.weight.data)

        # special tokens used to split captions into phrases
        self.special_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])

        # input projections: map each backbone feature level to hidden_dim; extra levels
        # beyond the backbone outputs are produced by strided 3x3 convolutions
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.num_channels)
            input_proj_list = []
            for i in range(num_backbone_outs):
                in_channels = backbone.num_channels[i]
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
            self.input_proj = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                ]
            )

        self.backbone = backbone
        self.aux_loss = aux_loss
        self.box_pred_damping = box_pred_damping = None

        self.iter_update = iter_update
        assert iter_update, "iter_update must be True: boxes are refined layer by layer in the decoder."

        # prediction heads: a text-conditioned contrastive classification head and a
        # 3-layer MLP box head, optionally shared across decoder layers
        self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share

        _class_embed = ContrastiveEmbed()

        _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)

        if dec_pred_bbox_embed_share:
            box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
        else:
            box_embed_layerlist = [
                copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)
            ]
        class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
        self.bbox_embed = nn.ModuleList(box_embed_layerlist)
        self.class_embed = nn.ModuleList(class_embed_layerlist)
        self.transformer.decoder.bbox_embed = self.bbox_embed
        self.transformer.decoder.class_embed = self.class_embed

        # two-stage: optionally reuse the decoder heads for the encoder output proposals
        self.two_stage_type = two_stage_type
        assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
            two_stage_type
        )
        if two_stage_type != "no":
            if two_stage_bbox_embed_share:
                assert dec_pred_bbox_embed_share
                self.transformer.enc_out_bbox_embed = _bbox_embed
            else:
                self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)

            if two_stage_class_embed_share:
                assert dec_pred_bbox_embed_share
                self.transformer.enc_out_class_embed = _class_embed
            else:
                self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)

        self.refpoint_embed = None

        self._reset_parameters()

    def _reset_parameters(self):
        # init input projections
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def set_image_tensor(self, samples: NestedTensor):
        """Run the backbone once and cache its features and position encodings."""
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        self.features, self.poss = self.backbone(samples)

    def unset_image_tensor(self):
        """Drop the cached backbone features."""
        if hasattr(self, "features"):
            del self.features
        if hasattr(self, "poss"):
            del self.poss

    def set_image_features(self, features, poss):
        """Reuse externally computed backbone features (e.g. for several captions per image)."""
        self.features = features
        self.poss = poss

    def init_ref_points(self, use_num_queries):
        self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)

    def forward(self, samples: NestedTensor, targets: List = None, **kw):
        """The forward expects a NestedTensor, which consists of:
           - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
           - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

        Captions come either from `targets` (one dict with a "caption" key per image) or from
        the keyword argument `captions`.

        It returns a dict with the following elements:
           - "pred_logits": the token-level classification logits for all queries, i.e. the
                            similarity of each query to each text token of the caption.
                            Shape = [batch_size x num_queries x max_text_len]
           - "pred_boxes": the normalized box coordinates for all queries, represented as
                           (center_x, center_y, width, height). These values are normalized
                           in [0, 1], relative to the size of each individual image
                           (disregarding possible padding). See PostProcess for information
                           on how to retrieve the unnormalized bounding box.
           - "aux_outputs": optional, only returned when auxiliary losses are activated. It is
                            a list of dictionaries containing the two above keys for each
                            decoder layer.
        """

        # gather captions
        if targets is None:
            captions = kw["captions"]
        else:
            captions = [t["caption"] for t in targets]

        # tokenize captions and build phrase-level self-attention masks: tokens attend only
        # within their own phrase (phrases are delimited by the special tokens), and
        # position ids restart at each phrase
        tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(
            samples.device
        )
        (
            text_self_attention_masks,
            position_ids,
            cate_to_token_mask_list,
        ) = generate_masks_with_special_tokens_and_transfer_map(
            tokenized, self.special_tokens, self.tokenizer
        )

        # truncate over-long captions to max_text_len tokens
        if text_self_attention_masks.shape[1] > self.max_text_len:
            text_self_attention_masks = text_self_attention_masks[
                :, : self.max_text_len, : self.max_text_len
            ]
            position_ids = position_ids[:, : self.max_text_len]
            tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len]
            tokenized["attention_mask"] = tokenized["attention_mask"][:, : self.max_text_len]
            tokenized["token_type_ids"] = tokenized["token_type_ids"][:, : self.max_text_len]

        # encode text with BERT, using either the per-phrase masks or the plain attention mask
        if self.sub_sentence_present:
            tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
            tokenized_for_encoder["attention_mask"] = text_self_attention_masks
            tokenized_for_encoder["position_ids"] = position_ids
        else:
            tokenized_for_encoder = tokenized

        bert_output = self.bert(**tokenized_for_encoder)

        encoded_text = self.feat_map(bert_output["last_hidden_state"])  # bs, n_text, d_model
        text_token_mask = tokenized.attention_mask.bool()  # bs, n_text; True on real tokens

        if encoded_text.shape[1] > self.max_text_len:
            encoded_text = encoded_text[:, : self.max_text_len, :]
            text_token_mask = text_token_mask[:, : self.max_text_len]
            position_ids = position_ids[:, : self.max_text_len]
            text_self_attention_masks = text_self_attention_masks[
                :, : self.max_text_len, : self.max_text_len
            ]

        text_dict = {
            "encoded_text": encoded_text,
            "text_token_mask": text_token_mask,
            "position_ids": position_ids,
            "text_self_attention_masks": text_self_attention_masks,
        }
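        # text_dict summarizes the text branch for the transformer:
        #   encoded_text              (bs, n_text, d_model)  projected BERT features
        #   text_token_mask           (bs, n_text)           True on real (non-padding) tokens
        #   position_ids              (bs, n_text)           per-phrase position ids
        #   text_self_attention_masks (bs, n_text, n_text)   phrase-level attention mask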

        # image branch: run (or reuse cached) backbone features, project each level to
        # hidden_dim, and synthesize any extra feature levels from the last one
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        if not hasattr(self, "features") or not hasattr(self, "poss"):
            self.set_image_tensor(samples)

        srcs = []
        masks = []
        for l, feat in enumerate(self.features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](self.features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                self.poss.append(pos_l)

        # no denoising queries are used in this forward pass
        input_query_bbox = input_query_label = attn_mask = dn_meta = None
        hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
            srcs, masks, input_query_bbox, self.poss, input_query_label, attn_mask, text_dict
        )

        # iterative box refinement: each decoder layer predicts a delta in unsigmoided space
        # on top of the previous layer's reference points
        outputs_coord_list = []
        for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
            zip(reference[:-1], self.bbox_embed, hs)
        ):
            layer_delta_unsig = layer_bbox_embed(layer_hs)
            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
            layer_outputs_unsig = layer_outputs_unsig.sigmoid()
            outputs_coord_list.append(layer_outputs_unsig)
        outputs_coord_list = torch.stack(outputs_coord_list)

        # score each query against the text tokens (ContrastiveEmbed)
        outputs_class = torch.stack(
            [
                layer_cls_embed(layer_hs, text_dict)
                for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
            ]
        )
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}

        # optionally free the cached image features so the next call recomputes them
        unset_image_tensor = kw.get("unset_image_tensor", True)
        if unset_image_tensor:
            self.unset_image_tensor()
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionaries with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_boxes": b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
        ]


@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
def build_groundingdino(args):
    """Build a GroundingDINO model from an args/config namespace."""
    backbone = build_backbone(args)
    transformer = build_transformer(args)

    dn_labelbook_size = args.dn_labelbook_size
    dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share
    sub_sentence_present = args.sub_sentence_present

    model = GroundingDINO(
        backbone,
        transformer,
        num_queries=args.num_queries,
        aux_loss=True,
        iter_update=True,
        query_dim=4,
        num_feature_levels=args.num_feature_levels,
        nheads=args.nheads,
        dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
        two_stage_type=args.two_stage_type,
        two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
        two_stage_class_embed_share=args.two_stage_class_embed_share,
        num_patterns=args.num_patterns,
        dn_number=0,
        dn_box_noise_scale=args.dn_box_noise_scale,
        dn_label_noise_ratio=args.dn_label_noise_ratio,
        dn_labelbook_size=dn_labelbook_size,
        text_encoder_type=args.text_encoder_type,
        sub_sentence_present=sub_sentence_present,
        max_text_len=args.max_text_len,
    )

    return model
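
# A minimal end-to-end sketch (illustrative; the config path, checkpoint path, and the
# "model" checkpoint key are assumptions based on the released configs/weights, not part
# of this module):
#
#   from groundingdino.util.slconfig import SLConfig
#
#   args = SLConfig.fromfile("groundingdino/config/GroundingDINO_SwinT_OGC.py")
#   model = build_groundingdino(args)
#   state = torch.load("groundingdino_swint_ogc.pth", map_location="cpu")
#   model.load_state_dict(state["model"], strict=False)
#   model.eval()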