# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import mmengine
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class VQAGenerationHead(BaseModule):
    """Generation head for multi-modal pre-trained tasks, adapted from BLIP.

    Normally used for the open-set QA generation task.

    Args:
        decoder (dict): Decoder for decoding answers.
        inference_method (str): Inference method. One of 'rank', 'generate'.
            - If 'rank', the model will return answers with the highest
              probability from the answer list.
            - If 'generate', the model will generate answers.
            - Only for test, not for train / val.
        num_beams (int): Number of beams for beam search. 1 means no beam
            search. Only supported when ``inference_method=='generate'``.
            Defaults to 3.
        num_ans_candidates (int): Number of answer candidates, used to filter
            out answers with low probability. Only supported when
            ``inference_method=='rank'``. Defaults to 128.
        loss (dict or nn.Module): Config of loss or module of loss. Defaults
            to ``nn.CrossEntropyLoss(reduction='none', ignore_index=-100)``.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
        answer_list_path (str, optional): Path to ``answer_list.json`` (a json
            file containing the answer list). Required when
            ``inference_method=='rank'``.

    TODO: ``mmcls.LabelSmoothLoss`` does not support the ``ignore_index``
    param yet. ``nn.CrossEntropyLoss`` (without label smoothing) is used here
    to keep compatibility with torch < 1.10.0.
    """

    def __init__(
        self,
        decoder: dict,
        inference_method: str = 'generate',
        num_beams: int = 3,
        num_ans_candidates: int = 128,
        loss: Union[dict, nn.Module] = nn.CrossEntropyLoss(
            reduction='none', ignore_index=-100),
        init_cfg: Optional[dict] = None,
        answer_list_path: Optional[str] = None,
    ) -> None:
        super(VQAGenerationHead, self).__init__(init_cfg=init_cfg)
        self.decoder = MODELS.build(decoder)

        if inference_method == 'generate':
            assert isinstance(num_beams, int), \
                'for VQA `generate` mode, `num_beams` must be an int.'
            self.num_beams = num_beams
            self.num_ans_candidates = None
            self.answer_list = None

        elif inference_method == 'rank':
            assert isinstance(num_ans_candidates, int), \
                'for VQA `rank` mode, `num_ans_candidates` must be an int.'
            assert isinstance(answer_list_path, str), \
                'for VQA `rank` mode, `answer_list_path` must be set as ' \
                'the path to `answer_list.json`.'
            self.num_beams = None
            self.answer_list = mmengine.load(answer_list_path)
            if isinstance(self.answer_list, dict):
                self.answer_list = list(self.answer_list.keys())
            assert isinstance(self.answer_list, list) and all(
                isinstance(item, str) for item in self.answer_list), \
                'for VQA `rank` mode, `answer_list.json` must be a list of str'
            self.num_ans_candidates = min(num_ans_candidates,
                                          len(self.answer_list))

        else:
            raise AssertionError(
                'for VQA, `inference_method` must be "generate" or "rank", '
                'got {}.'.format(inference_method))

        self.inference_method = inference_method

        if not isinstance(loss, nn.Module):
            loss = MODELS.build(loss)
        self.loss_module = loss

    def forward(self, feats: dict):
        """Forward the decoder and directly return the prediction logits."""
        prediction_logits = self.decoder(
            feats['answer_input_ids'],
            attention_mask=feats['answer_attention_mask'],
            encoder_hidden_states=feats['question_states'],
            encoder_attention_mask=feats['question_atts'],
            labels=feats['answer_targets'],
            return_dict=True,
            return_logits=True,  # directly return logits, not computing loss
            reduction='none',
        )
        return prediction_logits

    def loss(self, feats: dict, data_samples=None):
        """Calculate losses from the extracted features.

        Args:
            feats (dict): The features extracted from the backbone.
            data_samples (List[BaseDataElement]): The annotation data of
                every sample.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        shifted_prediction_scores = self(feats)
        labels = feats['answer_targets']
        lm_loss = None

        # we are doing next-token prediction;
        # shift prediction scores and input ids by one
        labels = labels[:, 1:].contiguous()
        lm_loss = self.loss_module(
            shifted_prediction_scores.view(
                -1, self.decoder.med_config.vocab_size), labels.view(-1))
        lm_loss = lm_loss.view(shifted_prediction_scores.size(0), -1).sum(1)

        # compute weighted loss
        losses = dict()
        loss = feats['answer_weight'] * lm_loss
        loss = loss.sum() / feats['batch_size']
        losses['vqa_loss'] = loss

        return losses

    def predict_rank(self, feats: dict, data_samples=None):
        """Rank answers from a closed-set answer list."""
        question_states = feats['multimodal_embeds']
        question_atts = feats['question_atts']
        answer_candidates = feats['answer_candidates']
        assert answer_candidates is not None

        answer_ids = answer_candidates.input_ids
        answer_atts = answer_candidates.attention_mask
        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.decoder(
            start_ids,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            return_dict=True,
            reduction='none',
        )
        logits = start_output.logits[:, 0, :]  # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(
            logits, dim=1).index_select(
                dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(
            self.num_ans_candidates, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)

        # ignore padded positions when computing the language-modeling loss
        targets_ids = input_ids.masked_fill(
            input_ids == feats['pad_token_id'], -100)

        def tile(x, dim, n_tile):
            # interleave-repeat ``x`` along ``dim`` so that each question is
            # followed by its ``n_tile`` candidate answers
            init_dim = x.size(dim)
            repeat_idx = [1] * x.dim()
            repeat_idx[dim] = n_tile
            x = x.repeat(*(repeat_idx))
            order_index = torch.LongTensor(
                np.concatenate([
                    init_dim * np.arange(n_tile) + i for i in range(init_dim)
                ]))
            return torch.index_select(x, dim, order_index.to(x.device))

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, self.num_ans_candidates)
        question_atts = tile(question_atts, 0, self.num_ans_candidates)

        output = self.decoder(
            input_ids,
            attention_mask=input_atts,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            labels=targets_ids,
            return_dict=True,
            reduction='none',
        )

        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques, self.num_ans_candidates)

        max_topk_ids = log_probs_sum.argmax(dim=1)
        max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]

        answers = [self.answer_list[max_id] for max_id in max_ids]

        return answers

    def predict_generate(self, feats: dict, data_samples=None):
        """Predict answers in a generation manner."""
        device = feats['multimodal_embeds'].device
        question_states = feats['multimodal_embeds']
        question_atts = torch.ones(
            question_states.size()[:-1], dtype=torch.long).to(device)
        model_kwargs = {
            'encoder_hidden_states': question_states,
            'encoder_attention_mask': question_atts
        }

        bos_ids = torch.full((feats['multimodal_embeds'].shape[0], 1),
                             fill_value=feats['bos_token_id'],
                             device=device)

        outputs = self.decoder.generate(
            input_ids=bos_ids,
            max_length=10,
            min_length=1,
            num_beams=self.num_beams,
            eos_token_id=feats['sep_token_id'],
            pad_token_id=feats['pad_token_id'],
            **model_kwargs)

        return outputs

    def predict(self, feats: dict, data_samples=None):
        """Predict results from the extracted features."""
        if self.inference_method == 'generate':
            return self.predict_generate(feats, data_samples)
        elif self.inference_method == 'rank':
            return self.predict_rank(feats, data_samples)
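# Illustrative usage (a minimal sketch, not part of the library code): the
# head is normally built from a config dict through the MODELS registry as
# part of a BLIP VQA model, rather than instantiated directly. The decoder
# type and the paths below are assumptions for illustration only; substitute
# the decoder config and answer list that your model actually uses.
#
#     from mmpretrain.registry import MODELS
#
#     head = MODELS.build(
#         dict(
#             type='VQAGenerationHead',
#             # hypothetical decoder config; replace with your BLIP decoder
#             decoder=dict(type='XBERTLMHeadDecoder', med_config=...),
#             inference_method='rank',
#             num_ans_candidates=128,
#             # hypothetical path to the closed-set answer list
#             answer_list_path='path/to/answer_list.json'))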