# ContextDet-Demo/models/contextdet_blip2.py
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from util.misc import (NestedTensor, inverse_sigmoid,
nested_tensor_from_tensor_list)
from .blip2_decoder import BLIP2Decoder
from .deformable_detr.backbone import build_backbone
from .deformable_detr.deformable_detr import DeformableDETR
from .transformer import build_ov_transformer
class ContextDET(DeformableDETR):
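    """ContextDET: a Deformable-DETR-style detector whose object queries are
    conditioned on hidden states from a BLIP-2/OPT language decoder
    (``llm_decoder``). The language side predicts masked tokens and start/end
    spans of object words; the pooled hidden states of those spans drive the
    visual decoder.
    """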
def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
aux_loss=True, with_box_refine=False, two_stage=False, llm_decoder=None):
super().__init__(backbone, transformer, num_classes, num_queries, num_feature_levels,
aux_loss, with_box_refine, two_stage)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm_decoder = llm_decoder
        hidden_dim = transformer.d_model
        # Project the OPT decoder's output features into the detector's hidden dim.
        out_size = self.llm_decoder.model.opt_proj.out_features
        self.llm_proj = nn.Linear(out_size, hidden_dim, device=self.device)
        # Per-token start/end scores for locating object-word spans in the caption.
        self.start_end_proj = nn.Linear(hidden_dim, 2)
        for layer in [self.llm_proj, self.start_end_proj]:
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(layer.bias)
        # word_embed_proj_dim = llm_decoder.model.opt_model.config.word_embed_proj_dim
        # Masked-language-modelling head over the OPT vocabulary.
        vocab_size = llm_decoder.model.opt_model.config.vocab_size
        self.fc_logits = nn.Linear(hidden_dim, vocab_size)
def forward(self, samples, blip2_samples, mask_infos=None, task_button=None, threshold=0.3):
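        """Run the BLIP-2/OPT decoder, then the conditioned visual detector.

        Argument semantics inferred from how they are used in this file:
            samples: batched images (tensor list or NestedTensor).
            blip2_samples: inputs for the BLIP-2/OPT language decoder.
            mask_infos: per-image dict mapping (start, end) token spans to object
                names, e.g. {(1, 1): 'dog', (4, 4): 'ball'} (hypothetical example);
                required for training, optional at inference.
            task_button: inference mode selector, e.g. 'Cloze Test'.
            threshold: sigmoid threshold on the predicted start/end scores.
        """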
        # Run the BLIP-2/OPT decoder; its hidden states condition the detector.
        logits, hidden_states, input_ids, output_text = self.llm_decoder.model.forward(
            blip2_samples, task_button=task_button)
        # Detach so gradients do not flow back into the language decoder, then
        # project its hidden states into the detector's feature space.
        hidden_states = hidden_states.detach()
        hidden_states = self.llm_proj(hidden_states)
if not isinstance(samples, NestedTensor):
samples = nested_tensor_from_tensor_list(samples)
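        # Multi-scale image features and positional encodings from the backbone.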
features, pos = self.backbone(samples)
srcs = []
masks = []
for l, feat in enumerate(features):
src, mask = feat.decompose()
srcs.append(self.input_proj[l](src))
masks.append(mask)
assert mask is not None
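        # If the transformer expects more feature levels than the backbone
        # provides, synthesize extra levels from the last map (Deformable DETR).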
if self.num_feature_levels > len(srcs):
_len_srcs = len(srcs)
for l in range(_len_srcs, self.num_feature_levels):
if l == _len_srcs:
src = self.input_proj[l](features[-1].tensors)
else:
src = self.input_proj[l](srcs[-1])
m = samples.mask
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
srcs.append(src)
masks.append(mask)
pos.append(pos_l)
out = {}
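        # Language-side heads: masked-token (MLM) logits plus per-token start/end
        # scores used to locate object-word spans in the caption.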
start_end_proj = self.start_end_proj(hidden_states)
out['pred_mlm_logits'] = self.fc_logits(hidden_states)
out['pred_start'] = start_end_proj[:, :, 0:1]
out['pred_end'] = start_end_proj[:, :, 1:2]
out['output_text'] = output_text
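        # Training: sample up to two annotated (start, end) spans per image and
        # mean-pool the LLM hidden states over each span as conditioning features.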
if self.training:
k = min([len(mask_info) for mask_info in mask_infos])
k = min(k, 2)
            select_ids = [random.sample(list(mask_info.keys()), k) for mask_info in mask_infos]
# select_ids = [random.choices(list(mask_info.keys()), k=4) for mask_info in mask_infos]
llm_feat = []
for b in range(len(select_ids)):
llm_feat_b = []
hidden_states_b = hidden_states[b, :, :]
for start, end in select_ids[b]:
llm_feat_b.append(hidden_states_b[start: end + 1].mean(dim=0, keepdim=True))
llm_feat.append(torch.cat(llm_feat_b)[None])
llm_feat = torch.cat(llm_feat)
query_embeds = None
if not self.two_stage:
query_embeds = self.query_embed.weight
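            # Run the deformable transformer conditioned on the k pooled LLM features.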
hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, anchors = (
self.transformer(srcs, masks, pos, query_embeds, llm_feat, k)
)
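            # Decode class logits and boxes at every decoder layer (kept for aux losses).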
outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
if lvl == 0:
reference = init_reference
else:
reference = inter_references[lvl - 1]
reference = inverse_sigmoid(reference)
outputs_class = self.class_embed[lvl](hs[lvl])
tmp = self.bbox_embed[lvl](hs[lvl])
if reference.shape[-1] == 4:
tmp += reference
else:
assert reference.shape[-1] == 2
tmp[..., :2] += reference
outputs_coord = tmp.sigmoid()
outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord)
outputs_class = torch.stack(outputs_classes)
outputs_coord = torch.stack(outputs_coords)
out.update({'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1],
'init_reference': init_reference})
out['select_ids'] = select_ids
if self.aux_loss:
out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
for temp in out["aux_outputs"]:
temp["select_ids"] = select_ids
if self.two_stage:
enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
out['enc_outputs'] = {
'pred_logits': enc_outputs_class,
'pred_boxes': enc_outputs_coord,
'anchors': anchors,
}
else:
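            # Inference: derive object spans either from <mask> tokens ('Cloze Test')
            # or from the predicted start/end scores, then run the detector per image
            # with the pooled span features.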
bs = len(samples.tensors)
mask_infos_pred = [{} for _ in range(bs)]
llm_feat = []
tokenizer = self.llm_decoder.model.opt_tokenizer
if mask_infos is None:
if task_button == 'Cloze Test':
mask_infos = []
output_texts = []
for b in range(bs):
mask_infos_b = {}
output_texts_b = []
for ind, token in enumerate(input_ids[b]):
                            if token == tokenizer.mask_token_id:
                                # Every <mask> position becomes a single-token span and is
                                # filled with the MLM head's argmax prediction.
                                mask_infos_b[(ind, ind)] = ''
                                pred_token = out['pred_mlm_logits'][b, ind:ind + 1, :]
                                pred_token = pred_token.argmax(1).item()
                                output_texts_b.append(pred_token)
                                # 1437 is assumed to be the tokenizer id of a plain space,
                                # keeping decoded predictions separated in the output text.
                                output_texts_b.append(1437)
                                input_ids[b, ind: ind + 1] = pred_token
                            else:
                                output_texts_b.append(token.item())
mask_infos.append(mask_infos_b)
output_texts.append(tokenizer.decode(output_texts_b[1:]))
out['output_text'] = output_texts
else:
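                    # Threshold the predicted start/end scores to find spans; fall back
                    # to the argmax position if nothing clears the threshold.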
mask_infos = []
for b in range(bs):
starts = (out['pred_start'][b, :, 0].sigmoid() > threshold).nonzero().squeeze(1)
ends = (out['pred_end'][b, :, 0].sigmoid() > threshold).nonzero().squeeze(1)
if len(starts) == 0:
starts = out['pred_start'][b, :].argmax(0)
if len(ends) == 0:
ends = out['pred_end'][b, :].argmax(0)
mask_infos_b = {}
for start, end in zip(starts, ends):
mask_infos_b[(int(start), int(end))] = ''
mask_infos.append(mask_infos_b)
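            # Mean-pool hidden states over each span; decode the span tokens as the
            # predicted object name.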
for b in range(bs):
llm_feat_b = []
hidden_states_b = hidden_states[b, :, :]
for start, end in mask_infos[b].keys():
llm_feat_b.append(hidden_states_b[start: end + 1].mean(dim=0, keepdim=True))
pred_name = tokenizer.decode(input_ids[b, start: end + 1]).strip()
mask_infos_pred[b][(int(start), int(end))] = pred_name
llm_feat.append(torch.cat(llm_feat_b)[None])
out['mask_infos_pred'] = mask_infos_pred
query_embeds = None
if not self.two_stage:
query_embeds = self.query_embed.weight
outputs_classes_list = []
outputs_coords_list = []
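            # Images can have different numbers of spans, so run the transformer once
            # per image, feeding at most 4 span features at a time.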
for b in range(bs):
srcs_b = [i[b: b + 1] for i in srcs]
masks_b = [i[b: b + 1] for i in masks]
pos_b = [i[b: b + 1] for i in pos]
k = len(mask_infos[b])
if k == 0:
outputs_classes_list.append(torch.zeros(0, 2).to(self.device))
outputs_coords_list.append(torch.zeros(0, 4).to(self.device))
continue
num_repeat = math.ceil(k / 4)
outputs_classes = []
outputs_coords = []
for ind in range(num_repeat):
llm_feat_b = llm_feat[b][:, ind * 4: (ind + 1) * 4]
hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact, anchors = (
self.transformer(srcs_b, masks_b, pos_b, query_embeds, llm_feat_b, llm_feat_b.shape[1])
)
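                    # Keep only the last decoder layer's predictions for this chunk.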
lvl = hs.shape[0] - 1
reference = inter_references[lvl - 1]
reference = inverse_sigmoid(reference)
outputs_class = self.class_embed[lvl](hs[lvl])
tmp = self.bbox_embed[lvl](hs[lvl])
if reference.shape[-1] == 4:
tmp += reference
else:
assert reference.shape[-1] == 2
tmp[..., :2] += reference
outputs_coord = tmp.sigmoid()
outputs_classes.append(outputs_class.flatten(0, 1))
outputs_coords.append(outputs_coord.flatten(0, 1))
outputs_classes = torch.cat(outputs_classes)[None]
outputs_coords = torch.cat(outputs_coords)[None]
outputs_classes_list.append(outputs_classes)
outputs_coords_list.append(outputs_coords)
out.update({'pred_logits': outputs_classes_list,
'pred_boxes': outputs_coords_list})
return out