# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. """ Implements the Generalized VL R-CNN framework """ import torch from torch import nn import torch.nn.functional as F from maskrcnn_benchmark.structures.image_list import to_image_list from maskrcnn_benchmark.structures.bounding_box import BoxList from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist from ..backbone import build_backbone from ..rpn import build_rpn from ..roi_heads import build_roi_heads from ..language_backbone import build_language_backbone from transformers import AutoTokenizer import random import timeit import pdb from copy import deepcopy def random_word(input_ids, mask_token_id, vocabs, padding_token_id, greenlight_map): """ greenlight_map, batch_size x 256 (seq_len): 0 means this location cannot be calculated in the MLM loss -1 means this location cannot be masked!! 1 means this location can be masked and can be calculated in the MLM loss """ output_label = deepcopy(input_ids) for j in range(input_ids.size(0)): for i in range(input_ids.size(1)): prob = random.random() # mask token with probability ratio = 0.15 if greenlight_map is not None and greenlight_map[j,i] == -1: output_label[j,i] = -100 continue if (not input_ids[j,i] == padding_token_id) and prob < ratio: prob /= ratio # 80% randomly change token to mask token if prob < 0.8: input_ids[j,i] = mask_token_id # 10% randomly change token to random token elif prob < 0.9: input_ids[j,i] = random.choice(vocabs) else: # no masking token (will be ignored by loss function later) output_label[j,i] = -100 if greenlight_map is not None and greenlight_map[j,i] != 1: output_label[j,i] = -100 # If this location should not be masked return input_ids, output_label class GeneralizedVLRCNN(nn.Module): """ Main class for Generalized R-CNN. Currently supports boxes and masks. It consists of three main parts: - backbone - rpn - heads: takes the features + the proposals from the RPN and computes detections / masks from it. """ def __init__(self, cfg): super(GeneralizedVLRCNN, self).__init__() self.cfg = cfg # visual encoder self.backbone = build_backbone(cfg) # language encoder if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": # self.tokenizer = build_tokenizer("clip") from transformers import CLIPTokenizerFast if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS: print("Reuse token 'ðŁĴij' (token_id = 49404) for mask token!") self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True, mask_token='ðŁĴij') else: self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True) else: self.tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE) self.tokenizer_vocab = self.tokenizer.get_vocab() self.tokenizer_vocab_ids = [item for key, item in self.tokenizer_vocab.items()] self.language_backbone = build_language_backbone(cfg) self.rpn = build_rpn(cfg) self.roi_heads = build_roi_heads(cfg) self.DEBUG = cfg.MODEL.DEBUG self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE self.freeze_fpn = cfg.MODEL.FPN.FREEZE self.freeze_rpn = cfg.MODEL.RPN.FREEZE self.add_linear_layer = cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER self.force_boxes = cfg.MODEL.RPN.FORCE_BOXES if cfg.MODEL.LINEAR_PROB: assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!" if hasattr(self.backbone, 'fpn'): assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!" self.linear_prob = cfg.MODEL.LINEAR_PROB self.freeze_cls_logits = cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS if cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS: # disable cls_logits if hasattr(self.rpn.head, 'cls_logits'): for p in self.rpn.head.cls_logits.parameters(): p.requires_grad = False self.freeze_language_backbone = self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE: for p in self.language_backbone.parameters(): p.requires_grad = False self.use_mlm_loss = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS self.mlm_loss_for_only_positives = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_FOR_ONLY_POSITIVES if self.cfg.GLIPKNOW.KNOWLEDGE_FILE: from maskrcnn_benchmark.data.datasets.tsv import load_from_yaml_file self.class_name_to_knowledge = load_from_yaml_file(self.cfg.GLIPKNOW.KNOWLEDGE_FILE) self.class_name_list = sorted([k for k in self.class_name_to_knowledge]) def train(self, mode=True): """Convert the model into training mode while keep layers freezed.""" super(GeneralizedVLRCNN, self).train(mode) if self.freeze_backbone: self.backbone.body.eval() for p in self.backbone.body.parameters(): p.requires_grad = False if self.freeze_fpn: self.backbone.fpn.eval() for p in self.backbone.fpn.parameters(): p.requires_grad = False if self.freeze_rpn: if hasattr(self.rpn, 'head'): self.rpn.head.eval() for p in self.rpn.parameters(): p.requires_grad = False if self.linear_prob: if self.rpn is not None: for key, value in self.rpn.named_parameters(): if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key or 'dot_product_projection_text' in key or 'head.log_scale' in key or 'head.bias_lang' in key or 'head.bias0' in key): value.requires_grad = False if self.roi_heads is not None: for key, value in self.roi_heads.named_parameters(): if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key or 'dot_product_projection_text' in key or 'head.log_scale' in key or 'head.bias_lang' in key or 'head.bias0' in key): value.requires_grad = False if self.freeze_cls_logits: if hasattr(self.rpn.head, 'cls_logits'): self.rpn.head.cls_logits.eval() for p in self.rpn.head.cls_logits.parameters(): p.requires_grad = False if self.add_linear_layer: if self.rpn is not None: for key, p in self.rpn.named_parameters(): if 'tunable_linear' in key: p.requires_grad = True if self.freeze_language_backbone: self.language_backbone.eval() for p in self.language_backbone.parameters(): p.requires_grad = False def forward(self, images, targets=None, captions=None, positive_map=None, greenlight_map=None): """ Arguments: images (list[Tensor] or ImageList): images to be processed targets (list[BoxList]): ground-truth boxes present in the image (optional) mask_black_list: batch x 256, indicates whether or not a certain token is maskable or not Returns: result (list[BoxList] or dict[Tensor]): the output from the model. During training, it returns a dict[Tensor] which contains the losses. During testing, it returns list[BoxList] contains additional fields like `scores`, `labels` and `mask` (for Mask R-CNN models). """ if self.training and targets is None: raise ValueError("In training mode, targets should be passed") images = to_image_list(images) # batch_size = images.tensors.shape[0] device = images.tensors.device if self.cfg.GLIPKNOW.PARALLEL_LANGUAGE_INPUT: language_dict_features, positive_map = self._forward_language_parallel( captions=captions, targets=targets, device=device, positive_map=positive_map) else: # language embedding language_dict_features = {} if captions is not None: #print(captions[0]) tokenized = self.tokenizer.batch_encode_plus(captions, max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, padding='max_length' if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest", return_special_tokens_mask=True, return_tensors='pt', truncation=True).to(device) if self.use_mlm_loss: if not self.mlm_loss_for_only_positives: greenlight_map = None input_ids, mlm_labels = random_word( input_ids=tokenized.input_ids, mask_token_id=self.tokenizer.mask_token_id, vocabs=self.tokenizer_vocab_ids, padding_token_id=self.tokenizer.pad_token_id, greenlight_map=greenlight_map) else: input_ids = tokenized.input_ids mlm_labels = None tokenizer_input = {"input_ids": input_ids, "attention_mask": tokenized.attention_mask} if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE: with torch.no_grad(): language_dict_features = self.language_backbone(tokenizer_input) else: language_dict_features = self.language_backbone(tokenizer_input) # ONE HOT if self.cfg.DATASETS.ONE_HOT: new_masks = torch.zeros_like(language_dict_features['masks'], device=language_dict_features['masks'].device) new_masks[:, :self.cfg.MODEL.DYHEAD.NUM_CLASSES] = 1 language_dict_features['masks'] = new_masks # MASK ALL SPECIAL TOKENS if self.cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL: language_dict_features["masks"] = 1 - tokenized.special_tokens_mask language_dict_features["mlm_labels"] = mlm_labels # visual embedding swint_feature_c4 = None if 'vl' in self.cfg.MODEL.SWINT.VERSION: # the backbone only updates the "hidden" field in language_dict_features inputs = {"img": images.tensors, "lang": language_dict_features} visual_features, language_dict_features, swint_feature_c4 = self.backbone(inputs) else: visual_features = self.backbone(images.tensors) # rpn force boxes if targets: targets = [target.to(device) for target in targets if target is not None] if self.force_boxes: proposals = [] for t in targets: tb = t.copy_with_fields(["labels"]) tb.add_field("scores", torch.ones(tb.bbox.shape[0], dtype=torch.bool, device=tb.bbox.device)) proposals.append(tb) if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES: _, proposal_losses, fused_visual_features = self.rpn( images, visual_features, targets, language_dict_features, positive_map, captions, swint_feature_c4) elif self.training: null_loss = 0 for key, param in self.rpn.named_parameters(): null_loss += 0.0 * param.sum() proposal_losses = {('rpn_null_loss', null_loss)} else: proposals, proposal_losses, fused_visual_features = self.rpn(images, visual_features, targets, language_dict_features, positive_map, captions, swint_feature_c4) if self.roi_heads: if self.cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL"): if self.training: # "Only support VL mask head right now!!" assert len(targets) == 1 and len(targets[0]) == len(positive_map), "shape match assert for mask head!!" # Not necessary but as a safe guard: # use the binary 0/1 positive map to replace the normalized positive map targets[0].add_field("positive_map", positive_map) # TODO: make sure that this use of language_dict_features is correct!! Its content should be changed in self.rpn if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES: x, result, detector_losses = self.roi_heads( fused_visual_features, proposals, targets, language_dict_features=language_dict_features, positive_map_label_to_token=positive_map if not self.training else None ) else: x, result, detector_losses = self.roi_heads( visual_features, proposals, targets, language_dict_features=language_dict_features, positive_map_label_to_token=positive_map if not self.training else None ) else: # RPN-only models don't have roi_heads x = visual_features result = proposals detector_losses = {} if self.training: losses = {} losses.update(detector_losses) losses.update(proposal_losses) return losses return result def _forward_language_parallel(self, captions=None, targets=None, device=None, positive_map=None): ktype = self.cfg.GLIPKNOW.KNOWLEDGE_TYPE def _construct_captions_from_class_names(class_names): captions = [] for c in class_names: try: info = self.class_name_to_knowledge[c] cap = info['clean_name'] # combine wiki and gpt3 knowledge if self.cfg.GLIPKNOW.WIKI_AND_GPT3: ktype = 'def_wiki' know_seq = info[ktype] ktype = 'gpt3' if ktype == 'gpt3' or type(info[ktype]) == list: know_seq += ' '.join([seq for seq in info[ktype][:self.cfg.GLIPKNOW.GPT3_NUM] ]) cap += ': ' + know_seq # only one knoweldge source is used else: if ktype and ktype in info and info[ktype]: if ktype == 'gpt3' or type(info[ktype]) == list: know_seq = ' '.join([seq for seq in info[ktype][:self.cfg.GLIPKNOW.GPT3_NUM] ]) else: know_seq = info[ktype] cap += ': ' + know_seq except: cap = c print(f'cap {cap}, c {c}') captions.append(cap) return captions if self.training: assert captions is None assert targets is not None max_classes_per_batch = self.cfg.GLIPKNOW.MAX_NUM_CLASSES_PER_BATCH_TRAIN if max_classes_per_batch >= len(self.class_name_list): shuffled_class_names = self.class_name_list.copy() random.shuffle(shuffled_class_names) if max_classes_per_batch > len(shuffled_class_names): shuffled_class_names.extend(shuffled_class_names[:max_classes_per_batch -len(shuffled_class_names)]) random.shuffle(shuffled_class_names) else: label_list = [] label_to_idx = {} for target_per_im in targets: labels_per_im = target_per_im.get_field('label_names') for label in labels_per_im: if label not in label_to_idx: label_to_idx[label] = len(label_list) label_list.append(label) label_list = label_list[:max_classes_per_batch] if len(label_list) < max_classes_per_batch: all_neg_classes = [c for c in self.class_name_list if c not in label_to_idx] neg_label_list = random.sample(all_neg_classes, max_classes_per_batch - len(label_list)) label_list.extend(neg_label_list) random.shuffle(label_list) shuffled_class_names = label_list label_to_shuffled_idx = {l: i for i, l in enumerate(shuffled_class_names)} total_boxes = sum(len(t) for t in targets) positive_map = torch.zeros((total_boxes, max_classes_per_batch+1), device=device) offset = 0 for target_per_im in targets: labels_per_im = target_per_im.get_field('label_names') for label in labels_per_im: j = label_to_shuffled_idx.get(label, -1) if j >= 0: positive_map[offset, j] = 1 offset += 1 captions = _construct_captions_from_class_names(shuffled_class_names) captions.append('') # onobj at the end, onedet/modeling/rpn/loss.py:719 batch_size = len(targets) else: assert captions is not None batch_size = 1 assert len(captions) == 1 class_names = captions[0] max_classes_per_batch = len(class_names) captions = _construct_captions_from_class_names(class_names) captions.append('') # onobj at the end, onedet/modeling/rpn/loss.py:719 tokenized = self.tokenizer.batch_encode_plus(captions, max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, padding="longest", return_special_tokens_mask=True, return_tensors='pt', truncation=True).to(device) assert not self.use_mlm_loss tokenizer_input = {"input_ids": tokenized.input_ids, "attention_mask": tokenized.attention_mask} if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE: with torch.no_grad(): language_dict_features = self.language_backbone(tokenizer_input) else: language_dict_features = self.language_backbone(tokenizer_input) assert not self.cfg.DATASETS.ONE_HOT assert not self.cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL agg_type = self.cfg.GLIPKNOW.LAN_FEATURE_AGG_TYPE agg_feats = language_dict_features['hidden'] agg_emb = language_dict_features['embedded'] if agg_type == 'first': agg_feats = agg_feats[:, 0, :] agg_emb = agg_emb[:, 0, :] elif agg_type == 'mean': attn_mask = language_dict_features['masks'] seq_len = attn_mask.sum(-1).unsqueeze(-1).float() agg_feats = agg_feats * attn_mask.unsqueeze(-1).float() agg_feats = agg_feats.sum(1) / seq_len agg_emb = agg_emb * attn_mask.unsqueeze(-1).float() agg_emb = agg_emb.sum(1) / seq_len else: raise ValueError('not supported GLIPKNOW.LAN_FEATURE_AGG_TYPE: {}'.format(agg_type)) expanded_features = agg_feats.unsqueeze(0).repeat(batch_size, 1, 1) expanded_embedding = agg_emb.unsqueeze(0).repeat(batch_size, 1, 1) lang_dict = {} lang_dict["mlm_labels"] = None lang_dict["aggregate"] = None lang_dict["embedded"] = expanded_embedding lang_dict['hidden'] = expanded_features lang_dict["masks"] = torch.ones((batch_size, max_classes_per_batch+1), device=device, dtype=language_dict_features['masks'].dtype) # in GLIP setting, the token at the end of seqence is usually [PAD], and is masked out # if [noobj] is not masked out, the loss sum is very big, as most # anchors are matched to [noobj] lang_dict["masks"][:,-1] = 0 return lang_dict, positive_map