# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmpretrain.models.utils.box_utils import (box_cxcywh_to_xyxy,
                                               generalized_box_iou)
from mmpretrain.registry import MODELS, TOKENIZER


class GroundingHead(BaseModule):
    """Bbox coordinate generation head for multi-modal pre-training tasks,
    adapted from BLIP. Normally used for visual grounding.

    Args:
        decoder (dict, optional): The config to build the decoder.
            Defaults to None.
        tokenizer (dict, optional): The config to build the tokenizer, or an
            already-built tokenizer instance. Defaults to None.
        box_l1_loss_coeff (float): Weight of the L1 box loss used to
            reweight the token loss. Defaults to 4.0.
        box_giou_loss_coeff (float): Weight of the generalized IoU box loss
            used to reweight the token loss. Defaults to 2.0.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
    """
    def __init__(
        self,
        decoder: dict = None,
        tokenizer: dict = None,
        box_l1_loss_coeff=4.0,
        box_giou_loss_coeff=2.0,
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        # Build the decoder from its config (e.g. a BERT-style LM decoder).
        self.decoder = None
        if decoder:
            self.decoder = MODELS.build(decoder)

        self.loss_fn = torch.nn.CrossEntropyLoss(
            reduction='none', ignore_index=-100)
        self.box_l1_loss_coeff = box_l1_loss_coeff
        self.box_giou_loss_coeff = box_giou_loss_coeff

        if isinstance(tokenizer, dict):
            self.tokenizer = TOKENIZER.build(tokenizer)
        else:
            self.tokenizer = tokenizer

        # Coordinates are quantized into ``image_res + 1`` bins, each mapped
        # to a reserved vocabulary token: '[unused339]' is the prefix token,
        # and '[unused340]'..'[unused980]' encode bin indices 0..640.
        self.image_res = 640
        prefix_ids = torch.tensor(
            self.tokenizer.convert_tokens_to_ids(['[unused339]']))
        target_ids = torch.tensor(
            self.tokenizer.convert_tokens_to_ids(
                [f'[unused{340+_}]' for _ in range(self.image_res + 1)]))
        self.register_buffer('prefix_ids', prefix_ids)
        self.register_buffer('target_ids', target_ids)

        # Mask that suppresses every vocabulary entry outside the bin-token
        # range by adding a large negative value to its logit.
        bbox_prob_mask = torch.zeros(len(self.tokenizer))
        bbox_prob_mask[self.target_ids[0]:self.target_ids[-1] + 1] = 1
        bbox_prob_mask = (1.0 - bbox_prob_mask) * -10000.0
        self.register_buffer('bbox_prob_mask', bbox_prob_mask)
        self.bin_start_idx = self.target_ids[0]
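
    # A worked example of the binning above (illustrative numbers only):
    # a normalized coordinate 0.55 maps to bin int(0.55 * 640) = 352, i.e.
    # token id ``self.target_ids[352]`` ('[unused692]'); decoding reverses
    # this via ``(token_id - self.bin_start_idx) / self.image_res`` = 0.55.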

    def forward(self, text_embedding, text_embedding_mask,
                encoder_hidden_states, encoder_attention_mask):
        # Concatenate image features and text embeddings so the decoder can
        # cross-attend to both modalities at once.
        merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding],
                                     1)
        merge_att_mask = torch.cat(
            [encoder_attention_mask, text_embedding_mask], 1)

        # Expand the learnable localization prompt to the batch size.
        loc_prompt = self.prompt.weight.T
        loc_prompt = torch.repeat_interleave(loc_prompt,
                                             merge_att_mask.shape[0],
                                             0).unsqueeze(1)
        loc_prompt_mask = torch.ones(loc_prompt.shape[:-1]).long().to(
            loc_prompt.device)

        decoder_out = self.decoder(
            inputs_embeds=loc_prompt,
            attention_mask=loc_prompt_mask,
            encoder_hidden_states=merged_encode_hs,
            encoder_attention_mask=merge_att_mask,
            output_hidden_states=True,
            labels=None,
        )
        # Regress the box from the prompt position of the last hidden layer.
        decoder_hs = decoder_out.hidden_states[-1][:, 0, :]
        box_pred = self.box_head(decoder_hs)
        return decoder_out, decoder_hs, box_pred
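
    # NOTE: ``forward`` references ``self.prompt`` and ``self.box_head``,
    # neither of which is created in ``__init__``; they are expected to be
    # provided elsewhere (e.g. by a subclass). As an assumption only, shapes
    # consistent with the code above would be:
    #
    #     self.prompt = nn.Linear(1, hidden_size, bias=False)
    #     # so that ``self.prompt.weight.T`` has shape (1, hidden_size)
    #     self.box_head = nn.Linear(hidden_size, 4)  # (cx, cy, w, h)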

    def loss(self,
             text_embedding,
             text_embedding_mask,
             encoder_hidden_states,
             encoder_attention_mask,
             decoder_targets,
             return_scores=False):
        """Calculate losses from the extracted features.

        Args:
            text_embedding (torch.Tensor): Embedding of the grounding text.
            text_embedding_mask (torch.Tensor): Attention mask of the text
                embedding.
            encoder_hidden_states (torch.Tensor): Hidden states of the image
                encoder.
            encoder_attention_mask (torch.Tensor): Attention mask of the
                image encoder.
            decoder_targets (torch.Tensor): Ground-truth boxes in normalized
                (cx, cy, w, h) format, shape (B, 4).
            return_scores (bool): Whether to return the prediction scores.
                Defaults to False.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding],
                                     1)
        merge_att_mask = torch.cat(
            [encoder_attention_mask, text_embedding_mask], 1)

        # Quantize the normalized box coordinates into bin-token ids.
        answer_targets = (decoder_targets *
                          self.image_res).long() + self.bin_start_idx

        # Prepend the prefix token to form the decoder input sequence.
        prefix_ids = torch.repeat_interleave(self.prefix_ids,
                                             merge_att_mask.shape[0],
                                             0).unsqueeze(-1)
        prefix_ids = torch.cat([prefix_ids, answer_targets], dim=1)

        answer_output = self.decoder(
            prefix_ids,
            encoder_hidden_states=merged_encode_hs,
            encoder_attention_mask=merge_att_mask,
            labels=None,
            return_dict=True,
        )

        # Restrict predictions to the bin-token range of the vocabulary.
        prob_mask = self.bbox_prob_mask.view(1, 1,
                                             self.bbox_prob_mask.shape[-1])
        prediction_scores = answer_output.logits + prob_mask

        # Shift so that tokens < n predict token n.
        shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
        labels = prefix_ids[:, 1:].contiguous()
        vocab_size = len(self.tokenizer)
        loss_seq_init = self.loss_fn(
            shifted_prediction_scores.view(-1, vocab_size), labels.view(-1))

        # Reweight the per-token cross-entropy by the box regression quality
        # (L1 and generalized IoU) of the greedy prediction.
        with torch.no_grad():
            pred_box = (torch.argmax(
                prediction_scores[:, :-1, :].contiguous(), dim=-1) -
                        self.bin_start_idx) / self.image_res
            weight_bbox = F.l1_loss(
                pred_box, decoder_targets, reduction='none').clamp(
                    0, 5) * self.box_l1_loss_coeff
            weight_giou = (1 - torch.diag(
                generalized_box_iou(
                    box_cxcywh_to_xyxy(pred_box),
                    box_cxcywh_to_xyxy(decoder_targets)))
                           ) * self.box_giou_loss_coeff

        bs = text_embedding.shape[0]
        loss_seq = loss_seq_init.view(bs, -1, 4)
        loss_seq = loss_seq * weight_bbox
        loss_seq = loss_seq * weight_giou.unsqueeze(1)
        loss_seq = loss_seq.mean()

        losses = {
            'loss_seq': loss_seq,
            'loss_seq_init': loss_seq_init.mean(),
            'loss': loss_seq,
            'box_l1': weight_bbox.mean(-1).mean().detach(),
            'box_giou': weight_giou.mean().detach()
        }
        return losses
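
    # In effect, the weighted objective above is roughly (an illustrative
    # summary of the code, not a formula taken from a paper):
    #
    #     loss_seq = mean( CE(logits, target_tokens)
    #                      * l1_coeff * |pred_box - gt_box|
    #                      * giou_coeff * (1 - GIoU(pred_box, gt_box)) )
    #
    # so coordinate tokens that already decode to accurate boxes contribute
    # a smaller gradient than badly localized ones.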

    def predict(
        self,
        text_embedding,
        text_embedding_mask,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Generate the bbox coordinates at inference time."""
        merged_encode_hs = torch.cat([encoder_hidden_states, text_embedding],
                                     1)
        merge_att_mask = torch.cat(
            [encoder_attention_mask, text_embedding_mask], 1)

        prefix_ids = torch.repeat_interleave(self.prefix_ids,
                                             merge_att_mask.shape[0],
                                             0).unsqueeze(-1)

        # Greedily decode the four coordinate tokens, one per step.
        for _ in range(4):
            decoder_output = self.decoder(
                prefix_ids,
                encoder_hidden_states=merged_encode_hs,
                encoder_attention_mask=merge_att_mask,
                labels=None,
                return_dict=True,
            )
            prob_mask = self.bbox_prob_mask.view(
                1, 1, self.bbox_prob_mask.shape[-1])
            prediction_scores = decoder_output.logits + prob_mask
            prefix_ids = torch.cat([
                prefix_ids,
                torch.argmax(prediction_scores[:, -1, :], dim=-1).unsqueeze(1)
            ],
                                   dim=1)

        # Drop the prefix token; convert normalized (cx, cy, w, h) tokens
        # to normalized (x1, y1, x2, y2) boxes.
        pred_box = self.process_bbox(prefix_ids[:, 1:])
        return pred_box

    def process_bbox(self, bbox):
        """Convert predicted bin tokens to normalized xyxy boxes."""
        # Map token ids back to bin indices, then to [0, 1] coordinates.
        bbox = bbox - self.bin_start_idx
        bbox = torch.true_divide(bbox, self.image_res)
        # (cx, cy, w, h) -> (x1, y1, x2, y2), clipped to the image bounds.
        bbox = box_cxcywh_to_xyxy(bbox)
        bbox = torch.clip(bbox, 0, 1)
        assert torch.all(bbox <= 1)
        return bbox
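

# Minimal usage sketch (the config values below are assumptions for
# illustration, not exact mmpretrain configs):
#
#     head = GroundingHead(
#         decoder=dict(type='XBertLMHeadDecoder', med_config=...),
#         tokenizer=dict(type='BertTokenizer', name_or_path=...))
#     # text_emb: (B, L, C), img_emb: (B, N, C), with matching 2-D masks.
#     boxes = head.predict(text_emb, text_mask, img_emb, img_mask)
#     # boxes: (B, 4) normalized (x1, y1, x2, y2)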