haotiz's picture
initial commit
708dec4
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized VL R-CNN framework
"""
import torch
from torch import nn
import torch.nn.functional as F
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from ..backbone import build_backbone
from ..rpn import build_rpn
from ..roi_heads import build_roi_heads
from ..language_backbone import build_language_backbone
from transformers import AutoTokenizer
import random
import timeit
import pdb
from copy import deepcopy
def random_word(input_ids, mask_token_id, vocabs, padding_token_id, greenlight_map):
"""
greenlight_map, batch_size x 256 (seq_len):
0 means this location cannot be calculated in the MLM loss
-1 means this location cannot be masked!!
1 means this location can be masked and can be calculated in the MLM loss
"""
output_label = deepcopy(input_ids)
for j in range(input_ids.size(0)):
for i in range(input_ids.size(1)):
prob = random.random()
# mask token with probability
ratio = 0.15
if greenlight_map is not None and greenlight_map[j,i] == -1:
output_label[j,i] = -100
continue
if (not input_ids[j,i] == padding_token_id) and prob < ratio:
prob /= ratio
# 80% randomly change token to mask token
if prob < 0.8:
input_ids[j,i] = mask_token_id
# 10% randomly change token to random token
elif prob < 0.9:
input_ids[j,i] = random.choice(vocabs)
else:
# no masking token (will be ignored by loss function later)
output_label[j,i] = -100
if greenlight_map is not None and greenlight_map[j,i] != 1:
output_label[j,i] = -100 # If this location should not be masked
return input_ids, output_label
class GeneralizedVLRCNN(nn.Module):
"""
Main class for Generalized R-CNN. Currently supports boxes and masks.
It consists of three main parts:
- backbone
- rpn
- heads: takes the features + the proposals from the RPN and computes
detections / masks from it.
"""
def __init__(self, cfg):
super(GeneralizedVLRCNN, self).__init__()
self.cfg = cfg
# visual encoder
self.backbone = build_backbone(cfg)
# language encoder
if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
# self.tokenizer = build_tokenizer("clip")
from transformers import CLIPTokenizerFast
if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
print("Reuse token 'ðŁĴij</w>' (token_id = 49404) for mask token!")
self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
from_slow=True, mask_token='ðŁĴij</w>')
else:
self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
from_slow=True)
else:
self.tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
self.tokenizer_vocab = self.tokenizer.get_vocab()
self.tokenizer_vocab_ids = [item for key, item in self.tokenizer_vocab.items()]
self.language_backbone = build_language_backbone(cfg)
self.rpn = build_rpn(cfg)
self.roi_heads = build_roi_heads(cfg)
self.DEBUG = cfg.MODEL.DEBUG
self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE
self.freeze_fpn = cfg.MODEL.FPN.FREEZE
self.freeze_rpn = cfg.MODEL.RPN.FREEZE
self.add_linear_layer = cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER
self.force_boxes = cfg.MODEL.RPN.FORCE_BOXES
if cfg.MODEL.LINEAR_PROB:
assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!"
if hasattr(self.backbone, 'fpn'):
assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
self.linear_prob = cfg.MODEL.LINEAR_PROB
self.freeze_cls_logits = cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS
if cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
# disable cls_logits
if hasattr(self.rpn.head, 'cls_logits'):
for p in self.rpn.head.cls_logits.parameters():
p.requires_grad = False
self.freeze_language_backbone = self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE
if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
for p in self.language_backbone.parameters():
p.requires_grad = False
self.use_mlm_loss = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS
self.mlm_loss_for_only_positives = cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS_FOR_ONLY_POSITIVES
if self.cfg.GLIPKNOW.KNOWLEDGE_FILE:
from maskrcnn_benchmark.data.datasets.tsv import load_from_yaml_file
self.class_name_to_knowledge = load_from_yaml_file(self.cfg.GLIPKNOW.KNOWLEDGE_FILE)
self.class_name_list = sorted([k for k in self.class_name_to_knowledge])
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(GeneralizedVLRCNN, self).train(mode)
if self.freeze_backbone:
self.backbone.body.eval()
for p in self.backbone.body.parameters():
p.requires_grad = False
if self.freeze_fpn:
self.backbone.fpn.eval()
for p in self.backbone.fpn.parameters():
p.requires_grad = False
if self.freeze_rpn:
if hasattr(self.rpn, 'head'):
self.rpn.head.eval()
for p in self.rpn.parameters():
p.requires_grad = False
if self.linear_prob:
if self.rpn is not None:
for key, value in self.rpn.named_parameters():
if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key or 'dot_product_projection_text' in key or 'head.log_scale' in key or 'head.bias_lang' in key or 'head.bias0' in key):
value.requires_grad = False
if self.roi_heads is not None:
for key, value in self.roi_heads.named_parameters():
if not ('bbox_pred' in key or 'cls_logits' in key or 'centerness' in key or 'cosine_scale' in key or 'dot_product_projection_text' in key or 'head.log_scale' in key or 'head.bias_lang' in key or 'head.bias0' in key):
value.requires_grad = False
if self.freeze_cls_logits:
if hasattr(self.rpn.head, 'cls_logits'):
self.rpn.head.cls_logits.eval()
for p in self.rpn.head.cls_logits.parameters():
p.requires_grad = False
if self.add_linear_layer:
if self.rpn is not None:
for key, p in self.rpn.named_parameters():
if 'tunable_linear' in key:
p.requires_grad = True
if self.freeze_language_backbone:
self.language_backbone.eval()
for p in self.language_backbone.parameters():
p.requires_grad = False
def forward(self,
images,
targets=None,
captions=None,
positive_map=None,
greenlight_map=None):
"""
Arguments:
images (list[Tensor] or ImageList): images to be processed
targets (list[BoxList]): ground-truth boxes present in the image (optional)
mask_black_list: batch x 256, indicates whether or not a certain token is maskable or not
Returns:
result (list[BoxList] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.
During testing, it returns list[BoxList] contains additional fields
like `scores`, `labels` and `mask` (for Mask R-CNN models).
"""
if self.training and targets is None:
raise ValueError("In training mode, targets should be passed")
images = to_image_list(images)
# batch_size = images.tensors.shape[0]
device = images.tensors.device
if self.cfg.GLIPKNOW.PARALLEL_LANGUAGE_INPUT:
language_dict_features, positive_map = self._forward_language_parallel(
captions=captions, targets=targets, device=device,
positive_map=positive_map)
else:
# language embedding
language_dict_features = {}
if captions is not None:
#print(captions[0])
tokenized = self.tokenizer.batch_encode_plus(captions,
max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
padding='max_length' if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest",
return_special_tokens_mask=True,
return_tensors='pt',
truncation=True).to(device)
if self.use_mlm_loss:
if not self.mlm_loss_for_only_positives:
greenlight_map = None
input_ids, mlm_labels = random_word(
input_ids=tokenized.input_ids,
mask_token_id=self.tokenizer.mask_token_id,
vocabs=self.tokenizer_vocab_ids,
padding_token_id=self.tokenizer.pad_token_id,
greenlight_map=greenlight_map)
else:
input_ids = tokenized.input_ids
mlm_labels = None
tokenizer_input = {"input_ids": input_ids,
"attention_mask": tokenized.attention_mask}
if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
with torch.no_grad():
language_dict_features = self.language_backbone(tokenizer_input)
else:
language_dict_features = self.language_backbone(tokenizer_input)
# ONE HOT
if self.cfg.DATASETS.ONE_HOT:
new_masks = torch.zeros_like(language_dict_features['masks'],
device=language_dict_features['masks'].device)
new_masks[:, :self.cfg.MODEL.DYHEAD.NUM_CLASSES] = 1
language_dict_features['masks'] = new_masks
# MASK ALL SPECIAL TOKENS
if self.cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL:
language_dict_features["masks"] = 1 - tokenized.special_tokens_mask
language_dict_features["mlm_labels"] = mlm_labels
# visual embedding
swint_feature_c4 = None
if 'vl' in self.cfg.MODEL.SWINT.VERSION:
# the backbone only updates the "hidden" field in language_dict_features
inputs = {"img": images.tensors, "lang": language_dict_features}
visual_features, language_dict_features, swint_feature_c4 = self.backbone(inputs)
else:
visual_features = self.backbone(images.tensors)
# rpn force boxes
if targets:
targets = [target.to(device)
for target in targets if target is not None]
if self.force_boxes:
proposals = []
for t in targets:
tb = t.copy_with_fields(["labels"])
tb.add_field("scores", torch.ones(tb.bbox.shape[0], dtype=torch.bool, device=tb.bbox.device))
proposals.append(tb)
if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
_, proposal_losses, fused_visual_features = self.rpn(
images, visual_features, targets, language_dict_features,
positive_map, captions, swint_feature_c4)
elif self.training:
null_loss = 0
for key, param in self.rpn.named_parameters():
null_loss += 0.0 * param.sum()
proposal_losses = {('rpn_null_loss', null_loss)}
else:
proposals, proposal_losses, fused_visual_features = self.rpn(images, visual_features, targets, language_dict_features, positive_map,
captions, swint_feature_c4)
if self.roi_heads:
if self.cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL"):
if self.training:
# "Only support VL mask head right now!!"
assert len(targets) == 1 and len(targets[0]) == len(positive_map), "shape match assert for mask head!!"
# Not necessary but as a safe guard:
# use the binary 0/1 positive map to replace the normalized positive map
targets[0].add_field("positive_map", positive_map)
# TODO: make sure that this use of language_dict_features is correct!! Its content should be changed in self.rpn
if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
x, result, detector_losses = self.roi_heads(
fused_visual_features, proposals, targets,
language_dict_features=language_dict_features,
positive_map_label_to_token=positive_map if not self.training else None
)
else:
x, result, detector_losses = self.roi_heads(
visual_features, proposals, targets,
language_dict_features=language_dict_features,
positive_map_label_to_token=positive_map if not self.training else None
)
else:
# RPN-only models don't have roi_heads
x = visual_features
result = proposals
detector_losses = {}
if self.training:
losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)
return losses
return result
def _forward_language_parallel(self, captions=None, targets=None,
device=None, positive_map=None):
ktype = self.cfg.GLIPKNOW.KNOWLEDGE_TYPE
def _construct_captions_from_class_names(class_names):
captions = []
for c in class_names:
try:
info = self.class_name_to_knowledge[c]
cap = info['clean_name']
# combine wiki and gpt3 knowledge
if self.cfg.GLIPKNOW.WIKI_AND_GPT3:
ktype = 'def_wiki'
know_seq = info[ktype]
ktype = 'gpt3'
if ktype == 'gpt3' or type(info[ktype]) == list:
know_seq += ' '.join([seq for seq in info[ktype][:self.cfg.GLIPKNOW.GPT3_NUM] ])
cap += ': ' + know_seq
# only one knoweldge source is used
else:
if ktype and ktype in info and info[ktype]:
if ktype == 'gpt3' or type(info[ktype]) == list:
know_seq = ' '.join([seq for seq in info[ktype][:self.cfg.GLIPKNOW.GPT3_NUM] ])
else:
know_seq = info[ktype]
cap += ': ' + know_seq
except:
cap = c
print(f'cap {cap}, c {c}')
captions.append(cap)
return captions
if self.training:
assert captions is None
assert targets is not None
max_classes_per_batch = self.cfg.GLIPKNOW.MAX_NUM_CLASSES_PER_BATCH_TRAIN
if max_classes_per_batch >= len(self.class_name_list):
shuffled_class_names = self.class_name_list.copy()
random.shuffle(shuffled_class_names)
if max_classes_per_batch > len(shuffled_class_names):
shuffled_class_names.extend(shuffled_class_names[:max_classes_per_batch
-len(shuffled_class_names)])
random.shuffle(shuffled_class_names)
else:
label_list = []
label_to_idx = {}
for target_per_im in targets:
labels_per_im = target_per_im.get_field('label_names')
for label in labels_per_im:
if label not in label_to_idx:
label_to_idx[label] = len(label_list)
label_list.append(label)
label_list = label_list[:max_classes_per_batch]
if len(label_list) < max_classes_per_batch:
all_neg_classes = [c for c in self.class_name_list if c not
in label_to_idx]
neg_label_list = random.sample(all_neg_classes,
max_classes_per_batch - len(label_list))
label_list.extend(neg_label_list)
random.shuffle(label_list)
shuffled_class_names = label_list
label_to_shuffled_idx = {l: i for i, l in
enumerate(shuffled_class_names)}
total_boxes = sum(len(t) for t in targets)
positive_map = torch.zeros((total_boxes, max_classes_per_batch+1),
device=device)
offset = 0
for target_per_im in targets:
labels_per_im = target_per_im.get_field('label_names')
for label in labels_per_im:
j = label_to_shuffled_idx.get(label, -1)
if j >= 0:
positive_map[offset, j] = 1
offset += 1
captions = _construct_captions_from_class_names(shuffled_class_names)
captions.append('') # onobj at the end, onedet/modeling/rpn/loss.py:719
batch_size = len(targets)
else:
assert captions is not None
batch_size = 1
assert len(captions) == 1
class_names = captions[0]
max_classes_per_batch = len(class_names)
captions = _construct_captions_from_class_names(class_names)
captions.append('') # onobj at the end, onedet/modeling/rpn/loss.py:719
tokenized = self.tokenizer.batch_encode_plus(captions,
max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
padding="longest",
return_special_tokens_mask=True,
return_tensors='pt',
truncation=True).to(device)
assert not self.use_mlm_loss
tokenizer_input = {"input_ids": tokenized.input_ids,
"attention_mask": tokenized.attention_mask}
if self.cfg.MODEL.LANGUAGE_BACKBONE.FREEZE:
with torch.no_grad():
language_dict_features = self.language_backbone(tokenizer_input)
else:
language_dict_features = self.language_backbone(tokenizer_input)
assert not self.cfg.DATASETS.ONE_HOT
assert not self.cfg.MODEL.LANGUAGE_BACKBONE.MASK_SPECIAL
agg_type = self.cfg.GLIPKNOW.LAN_FEATURE_AGG_TYPE
agg_feats = language_dict_features['hidden']
agg_emb = language_dict_features['embedded']
if agg_type == 'first':
agg_feats = agg_feats[:, 0, :]
agg_emb = agg_emb[:, 0, :]
elif agg_type == 'mean':
attn_mask = language_dict_features['masks']
seq_len = attn_mask.sum(-1).unsqueeze(-1).float()
agg_feats = agg_feats * attn_mask.unsqueeze(-1).float()
agg_feats = agg_feats.sum(1) / seq_len
agg_emb = agg_emb * attn_mask.unsqueeze(-1).float()
agg_emb = agg_emb.sum(1) / seq_len
else:
raise ValueError('not supported GLIPKNOW.LAN_FEATURE_AGG_TYPE: {}'.format(agg_type))
expanded_features = agg_feats.unsqueeze(0).repeat(batch_size, 1, 1)
expanded_embedding = agg_emb.unsqueeze(0).repeat(batch_size, 1, 1)
lang_dict = {}
lang_dict["mlm_labels"] = None
lang_dict["aggregate"] = None
lang_dict["embedded"] = expanded_embedding
lang_dict['hidden'] = expanded_features
lang_dict["masks"] = torch.ones((batch_size, max_classes_per_batch+1),
device=device, dtype=language_dict_features['masks'].dtype)
# in GLIP setting, the token at the end of seqence is usually [PAD], and is masked out
# if [noobj] is not masked out, the loss sum is very big, as most
# anchors are matched to [noobj]
lang_dict["masks"][:,-1] = 0
return lang_dict, positive_map