# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized VL R-CNN framework.
"""
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist

from ..backbone import build_backbone, build_fusion_backbone
from ..rpn import build_rpn
from ..roi_heads import build_roi_heads
from ..language_backbone import build_language_backbone

from transformers import AutoTokenizer

import random
import timeit
import pdb
from copy import deepcopy


def random_word(input_ids, mask_token_id, vocabs, padding_token_id, greenlight_map):
    """Apply BERT-style MLM corruption to a batch of token ids.

    ``input_ids`` is modified in place: each non-padding position is selected
    with probability 0.15 and then replaced by the mask token (80%), a random
    vocabulary token (10%), or left unchanged (10%).

    Args:
        input_ids: (batch, seq_len) token-id tensor; mutated in place.
        mask_token_id: id written into masked positions.
        vocabs: list of all vocabulary ids used for random replacement.
        padding_token_id: positions holding this id are never corrupted.
        greenlight_map: optional (batch, seq_len) map:
            0  means this location cannot be calculated in the MLM loss,
            -1 means this location cannot be masked!!
            1  means this location can be masked and can be calculated in the MLM loss.

    Returns:
        (input_ids, output_label) where ``output_label`` holds the original
        ids, with -100 at every position excluded from the MLM loss.
    """
    output_label = deepcopy(input_ids)
    for j in range(input_ids.size(0)):
        for i in range(input_ids.size(1)):
            prob = random.random()
            # mask token with probability
            ratio = 0.15
            if greenlight_map is not None and greenlight_map[j, i] == -1:
                output_label[j, i] = -100
                continue

            if (not input_ids[j, i] == padding_token_id) and prob < ratio:
                prob /= ratio

                # 80% randomly change token to mask token
                if prob < 0.8:
                    input_ids[j, i] = mask_token_id
                # 10% randomly change token to random token
                elif prob < 0.9:
                    input_ids[j, i] = random.choice(vocabs)
                # remaining 10%: keep the original token
            else:
                # no masking token (will be ignored by loss function later)
                output_label[j, i] = -100

            if greenlight_map is not None and greenlight_map[j, i] != 1:
                output_label[j, i] = -100  # If this location should not be masked

    return input_ids, output_label


def get_char_token_with_relaxation(tokenized, beg, end, batch_index=None):
    """Map a character span [beg, end) to a token span for slicing.

    Uses ``tokenized.char_to_token``; when a boundary character does not map
    to a token (e.g. it is whitespace), the boundary is relaxed inward by up
    to two characters before giving up.

    Returns:
        (start, stop) token indices (stop is exclusive), or (None, None) when
        either boundary could not be mapped.
    """
    beg_pos = tokenized.char_to_token(batch_index, beg)
    end_pos = tokenized.char_to_token(batch_index, end - 1)
    if beg_pos is None:
        try:
            beg_pos = tokenized.char_to_token(batch_index, beg + 1)
            if beg_pos is None:
                beg_pos = tokenized.char_to_token(batch_index, beg + 2)
        except Exception:  # relaxation ran past the text; treat as unmappable
            beg_pos = None
    if end_pos is None:
        try:
            end_pos = tokenized.char_to_token(batch_index, end - 2)
            if end_pos is None:
                end_pos = tokenized.char_to_token(batch_index, end - 3)
        except Exception:
            end_pos = None
    if beg_pos is None or end_pos is None:
        return None, None
    return beg_pos, end_pos + 1


class GeneralizedVLRCNN(nn.Module):
    """
    Main class for Generalized R-CNN. Currently supports boxes and masks.
    It consists of three main parts:
    - backbone
    - rpn
    - heads: takes the features + the proposals from the RPN and computes
        detections / masks from it.
    """

    # Substrings identifying the parameters that remain trainable under
    # linear probing; everything else in rpn / roi_heads is frozen.
    _LINEAR_PROB_TRAINABLE = (
        "bbox_pred",
        "cls_logits",
        "centerness",
        "cosine_scale",
        "dot_product_projection_text",
        "head.log_scale",
        "head.bias_lang",
        "head.bias0",
    )

    def __init__(self, cfg):
        super(GeneralizedVLRCNN, self).__init__()
        self.cfg = cfg
        # "fusion" SwinT fuses vision and language inside the backbone itself.
        self.fusion_in_backbone = cfg.MODEL.SWINT.VERSION == "fusion"

        # visual encoder
        backbone = build_backbone(cfg)

        # language encoder / tokenizer
        if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
            from transformers import CLIPTokenizerFast
            if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
                # CLIP has no [MASK] token; reuse a rarely-used token for MLM.
                print("Reuse token 'ðŁĴij' (token_id = 49404) for mask token!")
                self.tokenizer = CLIPTokenizerFast.from_pretrained(
                    "openai/clip-vit-base-patch32", from_slow=True, mask_token="ðŁĴij"
                )
            else:
                self.tokenizer = CLIPTokenizerFast.from_pretrained(
                    "openai/clip-vit-base-patch32", from_slow=True
                )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
        self.tokenizer_vocab = self.tokenizer.get_vocab()
        self.tokenizer_vocab_ids = [item for key, item in self.tokenizer_vocab.items()]

        language_backbone = build_language_backbone(cfg)
        if self.fusion_in_backbone:
            # vision + language are wrapped together in one fusion backbone
            self.fusion_backbone = build_fusion_backbone(
                backbone,
                language_backbone,
                cfg.MODEL.BACKBONE.FUSION_VERSION,
                add_linear_layer=cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER,
            )
        else:
            self.backbone = backbone
            self.language_backbone = language_backbone

        self.rpn = build_rpn(cfg)
        self.roi_heads = build_roi_heads(cfg)
        self.DEBUG = cfg.MODEL.DEBUG

        self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE
        self.freeze_fpn = cfg.MODEL.FPN.FREEZE
        self.freeze_rpn = cfg.MODEL.RPN.FREEZE
        self.add_linear_layer = cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER

        self.force_boxes = cfg.MODEL.RPN.FORCE_BOXES

        if cfg.MODEL.LINEAR_PROB:
            assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!"
            if self.fusion_in_backbone:
                if hasattr(self.fusion_backbone.backbone, "fpn"):
                    assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
            else:
                if hasattr(self.backbone, "fpn"):
                    assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
        self.linear_prob = cfg.MODEL.LINEAR_PROB

        self.freeze_cls_logits = cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS
        if cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            # disable cls_logits: the dot-product token loss replaces them
            if hasattr(self.rpn.head, "cls_logits"):
                for p in self.rpn.head.cls_logits.parameters():
                    p.requires_grad = False

        self.freeze_language_backbone = self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE
        if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
            if self.fusion_in_backbone:
                for p in self.fusion_backbone.language_backbone.parameters():
                    p.requires_grad = False
            else:
                for p in self.language_backbone.parameters():
                    p.requires_grad = False

        self.use_mlm_loss = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS
        self.mlm_loss_for_only_positives = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_FOR_ONLY_POSITIVES

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER and not self.fusion_in_backbone:
            # prompt-tuning layer added onto the language embeddings; zero-init
            # so it starts as a no-op. For fusion backbones it lives inside
            # the fusion backbone instead.
            self.tunable_linear = torch.nn.Linear(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, 1000, bias=False)
            self.tunable_linear.weight.data.fill_(0.0)

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed."""
        super(GeneralizedVLRCNN, self).train(mode)
        if self.freeze_backbone:
            if self.fusion_in_backbone:
                self.fusion_backbone.backbone.body.eval()
                for p in self.fusion_backbone.backbone.body.parameters():
                    p.requires_grad = False
            else:
                self.backbone.body.eval()
                for p in self.backbone.body.parameters():
                    p.requires_grad = False
        if self.freeze_fpn:
            if self.fusion_in_backbone:
                self.fusion_backbone.backbone.fpn.eval()
                for p in self.fusion_backbone.backbone.fpn.parameters():
                    p.requires_grad = False
            else:
                self.backbone.fpn.eval()
                for p in self.backbone.fpn.parameters():
                    p.requires_grad = False
        if self.freeze_rpn:
            if hasattr(self.rpn, "head"):
                self.rpn.head.eval()
            for p in self.rpn.parameters():
                p.requires_grad = False
        if self.linear_prob:
            # freeze everything except the small set of prediction layers
            if self.rpn is not None:
                for key, value in self.rpn.named_parameters():
                    if not any(token in key for token in self._LINEAR_PROB_TRAINABLE):
                        value.requires_grad = False
            if self.roi_heads is not None:
                for key, value in self.roi_heads.named_parameters():
                    if not any(token in key for token in self._LINEAR_PROB_TRAINABLE):
                        value.requires_grad = False
        if self.freeze_cls_logits:
            if hasattr(self.rpn.head, "cls_logits"):
                self.rpn.head.cls_logits.eval()
                for p in self.rpn.head.cls_logits.parameters():
                    p.requires_grad = False
        if self.add_linear_layer:
            # the prompt-tuning linear layer always stays trainable
            if not self.fusion_in_backbone:
                if self.rpn is not None:
                    for key, p in self.rpn.named_parameters():
                        if "tunable_linear" in key:
                            p.requires_grad = True
            else:
                for key, p in self.fusion_backbone.named_parameters():
                    if "tunable_linear" in key:
                        p.requires_grad = True
        if self.freeze_language_backbone:
            if self.fusion_in_backbone:
                self.fusion_backbone.language_backbone.eval()
                for p in self.fusion_backbone.language_backbone.parameters():
                    p.requires_grad = False
            else:
                self.language_backbone.eval()
                for p in self.language_backbone.parameters():
                    p.requires_grad = False

    def forward(self, images, targets=None, captions=None, positive_map=None,
                greenlight_map=None, spans=None, span_map=None):
        """
        Arguments:
            images (list[Tensor] or ImageList): images to be processed
            targets (list[BoxList]): ground-truth boxes present in the image (optional)
            captions: list of caption strings paired with the images
            positive_map: box-to-token alignment map
            greenlight_map: batch x 256, indicates whether or not a certain
                token is maskable or not (see random_word)
            spans / span_map: optional precomputed span annotations for the
                span-prediction variants (SPAN_VERSION)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        images = to_image_list(images)
        device = images.tensors.device

        # If we use the advanced span-prediction version ("v2*"), each caption
        # is cut into per-span sub-captions before tokenization, and the
        # resulting features are re-assembled after the language backbone.
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SPAN_VERSION is not None and \
                self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SPAN_VERSION.startswith("v2"):
            if spans is None:
                # if we did not pass the spans explicitly, read them from targets
                spans = [i.extra_fields['spans'] if "spans" in i.extra_fields else [] for i in targets]
            assert (len(spans) == len(captions))
            new_captions = []
            mapping_batch_span_to_caption_num = {}  # (batch_num, start, end) -> caption_num
            mapping_batch_to_caption_num = {}       # batch_num -> caption_num (span-less instances)
            corrected_spans = deepcopy(spans)       # spans re-based onto the sub-captions
            for i in range(len(captions)):
                if len(spans[i]) == 0:
                    # this instance has no spans: keep the full caption
                    mapping_batch_to_caption_num[i] = len(new_captions)
                    new_captions.append(captions[i])
                    continue
                for j in range(len(spans[i])):
                    # spans[i][j] example: [[230, 241], [241, 254], [254, 269], [269, 298]]
                    # entries with start == -1 are padding and are skipped
                    valid_spans = [k for k in spans[i][j] if k[0] != -1]
                    if "independent" in self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SPAN_VERSION:
                        # each span becomes its own caption
                        for k, span_i_j_k in enumerate(valid_spans):
                            mapping_batch_span_to_caption_num[(i, span_i_j_k[0], span_i_j_k[1])] = len(new_captions)
                            corrected_spans[i][j][k] = (0, span_i_j_k[1] - span_i_j_k[0])
                            new_captions.append(captions[i][span_i_j_k[0]:span_i_j_k[1]])
                    else:
                        # one caption covering the whole span group; re-base
                        # the char offsets onto that excerpt
                        start = valid_spans[0][0]
                        end = valid_spans[-1][-1]
                        corrected_spans[i][j] = [(k[0] - start, k[1] - start) for k in spans[i][j]]
                        for k in valid_spans:
                            mapping_batch_span_to_caption_num[(i, k[0], k[1])] = len(new_captions)
                        new_captions.append(captions[i][start:end])
            captions = new_captions
            padding_method = "longest"
        else:
            mapping_batch_span_to_caption_num = None
            padding_method = "max_length" if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest"

        # language embedding
        language_dict_features = {}
        if captions is not None:
            tokenized = self.tokenizer.batch_encode_plus(
                captions,
                max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
                padding=padding_method,
                return_special_tokens_mask=True,
                return_tensors="pt",
                truncation=True,
            ).to(device)
            if self.use_mlm_loss:
                if not self.mlm_loss_for_only_positives:
                    greenlight_map = None
                input_ids, mlm_labels = random_word(
                    input_ids=tokenized.input_ids,
                    mask_token_id=self.tokenizer.mask_token_id,
                    vocabs=self.tokenizer_vocab_ids,
                    padding_token_id=self.tokenizer.pad_token_id,
                    greenlight_map=greenlight_map,
                )
            else:
                input_ids = tokenized.input_ids
                mlm_labels = None

            tokenizer_input = {"input_ids": input_ids, "attention_mask": tokenized.attention_mask}

            if not self.fusion_in_backbone:
                if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
                    with torch.no_grad():
                        language_dict_features = self.language_backbone(tokenizer_input)
                else:
                    language_dict_features = self.language_backbone(tokenizer_input)
                # ONE HOT: only the first NUM_CLASSES token slots are valid
                if self.cfg.DATASETS.ONE_HOT:
                    new_masks = torch.zeros_like(
                        language_dict_features["masks"], device=language_dict_features["masks"].device
                    )
                    new_masks[:, : self.cfg.MODEL.DYHEAD.NUM_CLASSES] = 1
                    language_dict_features["masks"] = new_masks
                # MASK ALL SPECIAL TOKENS
                if self.cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL:
                    language_dict_features["masks"] = 1 - tokenized.special_tokens_mask
                language_dict_features["mlm_labels"] = mlm_labels

        # visual embedding
        if not self.fusion_in_backbone:
            swint_feature_c4 = None
            if "vl" in self.cfg.MODEL.SWINT.VERSION:
                # the backbone only updates the "hidden" field in language_dict_features
                inputs = {"img": images.tensors, "lang": language_dict_features}
                visual_features, language_dict_features, swint_feature_c4 = self.backbone(inputs)
            else:
                visual_features = self.backbone(images.tensors)
        else:
            visual_features, language_dict_features, swint_feature_c4 = \
                self.fusion_backbone(tokenizer_input, images)
            language_dict_features["mlm_labels"] = mlm_labels

        # add the prompt tuning linear layer if not fusion, for fusion do it inside the backbone
        if not self.fusion_in_backbone:
            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER:
                embedding = language_dict_features["embedded"]
                embedding = self.tunable_linear.weight[: embedding.size(1), :].unsqueeze(0) + embedding
                language_dict_features["embedded"] = embedding
                language_dict_features["hidden"] = (
                    self.tunable_linear.weight[: embedding.size(1), :].unsqueeze(0)
                    + language_dict_features["hidden"]
                )

        # if we do span prediction: pool token features into per-span features
        # and swap them in for the token-level language features
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SPAN_VERSION is not None:
            loss_version = self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SPAN_VERSION.split(".")[0]  # currently unused here
            pooling_version = self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SPAN_VERSION.split(".")[-1]

            # Step 1. get the spans
            embedding = language_dict_features["hidden"]
            if spans is None:
                spans = [i.extra_fields['spans'] if "spans" in i.extra_fields else [] for i in targets]

            if mapping_batch_span_to_caption_num is not None:
                # need to do a remapping: flatten the per-group span lists
                flatterned_spans = []
                for i in spans:
                    _ = []
                    for j in i:
                        _.extend(j)
                    flatterned_spans.append(_)
                spans = flatterned_spans
                # flattern corrected spans the same way
                flatterned_corrected_spans = []
                for i in corrected_spans:
                    _ = []
                    for j in i:
                        _.extend(j)
                    flatterned_corrected_spans.append(_)
                corrected_spans = flatterned_corrected_spans

                max_span_num = max([len(i) for i in spans])
                # go over the batch; an instance without spans contributes its
                # full token sequence, so the span dimension must cover it
                for i, spans_i in enumerate(spans):
                    if len(spans_i) == 0:
                        text_length = sum(tokenized.attention_mask[mapping_batch_to_caption_num[i]])
                        max_span_num = max(text_length, max_span_num)  # override

                # Step 2. Get the Masks
                span_masks = torch.zeros((len(spans), max_span_num), device=embedding.device, dtype=torch.long)
                for i, spans_i in enumerate(spans):
                    if len(spans_i) == 0:
                        # this would be the text masks
                        text_mask_i = tokenized.attention_mask[mapping_batch_to_caption_num[i]]
                        text_length = sum(text_mask_i)
                        span_masks[i, :text_length] = text_mask_i[:text_length]
                    else:
                        span_masks[i, : len(spans_i)] = 1

                # Step 3. get the span features
                # the complexity is just batch x span_num; acceptable for a for loop
                span_features = torch.zeros(
                    (len(spans), max_span_num, embedding.size(2)),
                    device=embedding.device, dtype=embedding.dtype,
                )
                for i, spans_i in enumerate(spans):
                    if len(spans_i) == 0:
                        # directly override with the token embedding
                        __len = min(max_span_num, embedding[mapping_batch_to_caption_num[i]].size(0))
                        span_features[i, :__len, :] = embedding[mapping_batch_to_caption_num[i], :__len, :]
                    else:
                        for j, span in enumerate(spans_i):
                            # map this (original) span to its sub-caption, then
                            # convert the re-based char span to token indices
                            mapped_sentence_index = mapping_batch_span_to_caption_num[(i, span[0], span[1])]
                            start, end = get_char_token_with_relaxation(
                                tokenized, corrected_spans[i][j][0], corrected_spans[i][j][1],
                                batch_index=mapped_sentence_index,
                            )
                            if start is None or end is None:
                                # mark this span as invalid (NOTE: pooling below
                                # then runs over a full slice, but the mask
                                # zeroes this slot downstream)
                                span_masks[i, j] = 0
                            if pooling_version == "mean":
                                span_rep_i_j = torch.mean(embedding[mapped_sentence_index, start:end, :], dim=0)
                            elif pooling_version == "max":
                                span_rep_i_j = torch.max(embedding[mapped_sentence_index, start:end, :], dim=0)[0]
                            span_features[i, j, :] = span_rep_i_j
            else:
                # the non-remapped (v1) span path is disabled
                assert (0)

            # Step 4. Rewrite the labels (?)
            # we need to rewrite targets, positive_map, text_masks, text_embeddings
            if span_map is None:
                span_map = torch.zeros((positive_map.size(0), max_span_num), device=embedding.device, dtype=torch.float)
                _all_span_map_flattern = []  # box_num x span_num
                for target_i in targets:
                    if "span_map" in target_i.extra_fields:
                        _all_span_map_flattern.extend([j for j in target_i.extra_fields["span_map"]])
                    else:
                        # if not, create a list of empty lists, one per box
                        num_box = target_i.bbox.size(0)
                        _all_span_map_flattern.extend([[]] * num_box)
                # very important
                assert (len(_all_span_map_flattern) == positive_map.size(0))
                for i, span_map_i in enumerate(_all_span_map_flattern):
                    if len(span_map_i) == 0:
                        seq_len = min(max_span_num, positive_map.size(1))
                        # use the original positive map in this case!
                        span_map[i, :seq_len] = positive_map[i, :seq_len]
                    else:
                        span_map[i, :len(span_map_i)] = span_map_i

            # Step 5. Override the token-level structures with span-level ones
            positive_map = span_map
            language_dict_features["masks"] = span_masks
            language_dict_features["embedded"] = span_features
            language_dict_features["hidden"] = span_features
            if targets is not None:
                for i in targets:
                    if "span_map" in i.extra_fields:
                        i.extra_fields["positive_map"] = i.extra_fields["span_map"]  # override if span

        # rpn force boxes
        if targets:
            targets = [target.to(device) for target in targets if target is not None]

        if self.force_boxes:
            # use the ground-truth boxes as proposals instead of running RPN
            proposals = []
            for t in targets:
                tb = t.copy_with_fields(["labels"])
                tb.add_field("scores", torch.ones(tb.bbox.shape[0], dtype=torch.bool, device=tb.bbox.device))
                proposals.append(tb)
            if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
                _, proposal_losses, fused_visual_features = self.rpn(
                    images, visual_features, targets, language_dict_features,
                    positive_map, captions, swint_feature_c4
                )
            elif self.training:
                # keep every RPN parameter in the graph with a zero-valued loss
                null_loss = 0
                for key, param in self.rpn.named_parameters():
                    null_loss += 0.0 * param.sum()
                proposal_losses = {"rpn_null_loss": null_loss}
        else:
            proposals, proposal_losses, fused_visual_features = self.rpn(
                images, visual_features, targets, language_dict_features,
                positive_map, captions, swint_feature_c4
            )

        if self.roi_heads:
            if not self.training:
                assert len(proposals) == 1, "Evaluation batch size per GPU should be 1!"
                if len(proposals[0]) == 0:
                    return proposals
            if self.cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL"):
                if self.training:
                    # "Only support VL mask head right now!!"
                    assert len(targets) == 1 and len(targets[0]) == len(
                        positive_map
                    ), "shape match assert for mask head!!"
                    # Not necessary but as a safe guard:
                    # use the binary 0/1 positive map to replace the normalized positive map
                    targets[0].add_field("positive_map", positive_map)
            # TODO: make sure that this use of language_dict_features is correct!!
            # Its content should be changed in self.rpn
            if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
                x, result, detector_losses = self.roi_heads(
                    fused_visual_features,
                    proposals,
                    targets,
                    language_dict_features=language_dict_features,
                    positive_map_label_to_token=positive_map if not self.training else None,
                )
            else:
                x, result, detector_losses = self.roi_heads(
                    visual_features,
                    proposals,
                    targets,
                    language_dict_features=language_dict_features,
                    positive_map_label_to_token=positive_map if not self.training else None,
                )
        else:
            # RPN-only models don't have roi_heads
            x = visual_features
            result = proposals
            detector_losses = {}

        if self.training:
            losses = {}
            losses.update(detector_losses)
            losses.update(proposal_losses)
            return losses

        return result