# File: GLEE_demo / GLEE/glee/models/glee_model.py
# Last commit: wjf5203, "add video func support" (8468984)
# (Repository page residue: raw / history / blame / "No virus" / 22.4 kB)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
"""
import torch
import torch.nn.functional as F
from torch import nn
# from ..backbone import build_backbone, Backbone
# from ..body.encoder import build_encoder
# from ..body.decoder import build_decoder
from detectron2.modeling import build_backbone
from .pixel_decoder.maskdino_encoder import build_pixel_decoder
from .transformer_decoder.maskdino_decoder import build_transformer_decoder
import random
from transformers import AutoTokenizer
from collections import OrderedDict
from ..modules.point_features import point_sample
from timm.models.layers import trunc_normal_
from transformers import CLIPTokenizer,CLIPTextModel
from .vos_utils import masks_to_boxes, FeatureFuser
import numpy as np
import math
def rand_sample(x, max_len):
    """Randomly keep at most ``max_len`` entries along dimension 1 of ``x``.

    Args:
        x: tensor with at least two dimensions; sampling happens on dim 1.
        max_len: maximum number of entries to keep along dim 1.

    Returns:
        ``x`` itself (not a copy) when ``x.shape[1] <= max_len``; otherwise a
        new tensor holding ``max_len`` distinct, randomly chosen columns of
        ``x`` (selected with ``torch.randperm``, i.e. the global RNG).
    """
    num_entries = x.shape[1]
    if num_entries > max_len:
        # A random permutation truncated to max_len gives sampling without replacement.
        keep = torch.randperm(num_entries)[:max_len]
        return x[:, keep]
    return x
def agg_lang_feat(features, mask, pool_type="average"):
    """Pool per-token language features into one vector per sequence.

    Args:
        features: token features indexed as (bs, seq_len, C).
        mask: (bs, seq_len) validity mask; non-zero / True marks a valid token.
            May be boolean or integer (e.g. a tokenizer attention mask).
        pool_type: ``"average"`` for a masked mean, ``"max"`` for a masked
            element-wise maximum.

    Returns:
        Tensor of shape (bs, C).

    Raises:
        ValueError: if ``pool_type`` is neither ``"average"`` nor ``"max"``.
    """
    if pool_type == "average":
        # Zero out invalid token features before summing.
        embedded = features * mask.unsqueeze(-1).float()
        token_count = mask.sum(-1).unsqueeze(-1).float()
        # A sequence with no valid token would divide by zero and yield NaN;
        # its numerator is already all zeros, so dividing by 1 returns zeros.
        token_count = token_count.masked_fill(token_count == 0, 1.0)
        aggregate = embedded.sum(1) / token_count
    elif pool_type == "max":
        # Indexing with an integer 0/1 mask would *select rows 0 and 1* instead
        # of masking, so the mask is converted to bool first.
        valid = mask.bool()
        pooled = [
            torch.max(seq_feat[seq_valid], 0)[0]  # (L, C) -> (C,)
            for seq_feat, seq_valid in zip(features, valid)
        ]
        aggregate = torch.stack(pooled, dim=0)  # (bs, C)
    else:
        raise ValueError("pool_type should be average or max")
    return aggregate
class GLEE_Model(nn.Module):
    """
    Main class for mask classification semantic segmentation architectures.

    Wires together an image backbone, a pixel decoder and a transformer
    decoder, and conditions them on CLIP text embeddings (category names or
    grounding expressions) and/or visual prompts given as spatial masks.
    """

    def __init__(self, cfg, matcher, device, video_info, contras_mean):
        """Build all sub-modules and move the complete model to ``device``.

        Args:
            cfg: config node; this class reads ``MODEL.LANGUAGE_BACKBONE.*``,
                ``MODEL.DIM_PROJ``, ``MODEL.SEM_SEG_HEAD.CONVS_DIM`` and
                ``MODEL.TRACK_VERSION`` and forwards ``cfg`` to the builders.
            matcher: stored on the instance; not used inside this class.
            device: device the whole model is moved to.
            video_info: stored on the instance; not used inside this class.
            contras_mean: stored on the instance; not used inside this class.
        """
        super().__init__()
        self.cfg = cfg
        self.matcher = matcher
        self.backbone = build_backbone(cfg)

        output_channels = list(self.backbone._out_feature_channels.values())
        self.sot_fuser = FeatureFuser(output_channels[-3:], 256)

        self.tokenizer = CLIPTokenizer.from_pretrained('GLEE/clip_vit_base_patch32')
        self.tokenizer.add_special_tokens({'cls_token': self.tokenizer.eos_token})
        self.text_encoder = CLIPTextModel.from_pretrained('GLEE/clip_vit_base_patch32')
        self.lang_encoder = None
        self.lang_projection = nn.Parameter(
            torch.rand(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, cfg.MODEL.DIM_PROJ)
        )
        self.text_encode_type = 'clip_teacher'

        self.pixel_decoder = build_pixel_decoder(cfg, self.backbone.output_shape())
        transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        self.predictor = build_transformer_decoder(
            cfg,
            transformer_predictor_in_channels,
            lang_encoder=self.lang_encoder,
            mask_classification=True,
        )

        self.video_info = video_info
        self.contras_mean = contras_mean
        self.track_loss_version = cfg.MODEL.TRACK_VERSION
        self.no_mask_tasks = ['obj365', 'obj365_clip', 'openimage', 'openimage_clip', 'vg', 'grit', 'bdd_det', 'bdd_track_box']

        # for visual prompt
        hidden_dim = 256
        self.max_spatial_len = [512, 512, 512, 512]
        # NOTE: the misspelled attribute name is kept on purpose; it is the
        # parameter name under which these weights appear in state dicts.
        self.mask_sptial_embed = nn.ParameterList(
            [nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for _ in range(4)]
        )
        for spatial_embed in self.mask_sptial_embed:
            trunc_normal_(spatial_embed, std=.02)
        # learnable positive negative indicator
        self.pn_indicator = nn.Embedding(2, hidden_dim)

        # Side length of the template crop relative to sqrt(box area); see get_template.
        self.search_area_factor = 2

        # Move to the target device only after *every* parameter exists.
        # Doing it earlier left mask_sptial_embed / pn_indicator on the CPU.
        self.to(device)

    @property
    def device(self):
        """Device the model lives on.

        Uses a ``pixel_mean`` attribute when one has been attached to the
        instance (this class itself never defines it), otherwise the device of
        the first parameter.
        """
        pixel_mean = getattr(self, "pixel_mean", None)
        if pixel_mean is not None:
            return pixel_mean.device
        return next(self.parameters()).device

    @staticmethod
    def _image_tensor(images):
        """Return ``images`` if it is a tensor, else its ``.tensor`` attribute."""
        return images if isinstance(images, torch.Tensor) else images.tensor

    def _encode_class_names(self, batch_name_list, device):
        """Tokenize and CLIP-encode category names, then project and pool them.

        Returns:
            ``(token_x, attention_mask, class_embeddings)``: the projected
            per-token features, the tokenizer attention mask, and the masked
            average of the token features for each name.
        """
        lang_cfg = self.cfg.MODEL.LANGUAGE_BACKBONE
        tokenized = self.tokenizer.batch_encode_plus(
            batch_name_list,
            max_length=lang_cfg.MAX_QUERY_LEN,
            padding='max_length' if lang_cfg.PAD_MAX else "longest",
            return_special_tokens_mask=True,
            return_tensors='pt',
            truncation=True,
        ).to(device)
        attention_mask = tokenized['attention_mask']
        token_x = self.text_encoder(tokenized['input_ids'], attention_mask)['last_hidden_state']
        token_x = token_x @ self.lang_projection
        class_embeddings = agg_lang_feat(token_x, attention_mask, pool_type="average")
        return token_x, attention_mask, class_embeddings

    def _add_visual_prompt_tokens(self, spatial_masks, feature_levels, device, extra):
        """Sample per-level prompt tokens at the prompted pixels and store them in ``extra``.

        Writes four keys into ``extra``: ``spatial_query_pos_mask``,
        ``spatial_query_neg_mask``, ``visual_prompt_tokens`` and
        ``visual_prompt_nonzero_mask`` (the latter two are lists with one entry
        per element of ``feature_levels``).

        The mean positive/negative query that this code path used to compute
        was never consumed and is no longer computed; as a consequence one
        fewer ``rand_sample`` pass is drawn from the global RNG per call.

        Args:
            spatial_masks: list of prompt masks, one per image; each is
                unpacked as a 3-D tensor ``(_, H, W)``.
            feature_levels: list of 4-D feature maps to sample from; level
                ``i`` uses ``mask_sptial_embed[i]`` and ``max_spatial_len[i]``.
            device: device on which the coordinate divisor is created.
            extra: dict updated in place.
        """
        pos_masks = spatial_masks
        # No negative prompts are supplied: `p & False` is an all-False mask.
        neg_masks = [p & False for p in spatial_masks]
        extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})

        _, h, w = pos_masks[0].shape
        divisor = torch.tensor([h, w], device=device)[None,]

        # nonzero() on an [num_inst, H, W] mask yields [num_point, 3]; column 0
        # is the instance index and [:, 1:] keeps the (row, col) coordinates.
        # Dropping column 0 pools the points of all prompted instances of an
        # image together, so sampling does not distinguish between instances.
        # The coordinates do not depend on the feature level, so they are
        # computed once here.
        pos_coords = [(m.nonzero()[:, 1:] / divisor).t() for m in pos_masks]
        neg_coords = [(m.nonzero()[:, 1:] / divisor).t() for m in neg_masks]

        src_spatial_queries = []
        src_spatial_maskings = []
        for i, level_feat in enumerate(feature_levels):
            src_mask_features = level_feat.permute(2, 3, 0, 1) @ self.mask_sptial_embed[i]

            max_len = self.max_spatial_len[i]
            points_pos = [rand_sample(c, max_len).t() for c in pos_coords]
            points_neg = [rand_sample(c, max_len).t() for c in neg_coords]
            query_points = [torch.cat([p, n], dim=0) for p, n in zip(points_pos, points_neg)]

            # +1 marks a positive point, -1 a negative one, 0 is padding.
            pos_neg_indicator = [
                torch.cat([torch.ones(p.shape[0], device=p.device), -torch.ones(n.shape[0], device=n.device)])
                for p, n in zip(points_pos, points_neg)
            ]
            pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)

            # Pad the per-image point lists with -1 and stack them batch-first.
            query_points = nn.utils.rnn.pad_sequence(query_points, padding_value=-1).permute(1, 0, 2)
            # Padded entries are the ones whose coordinate sum is negative.
            padding_mask = (query_points.sum(dim=-1) < 0)
            query_points[padding_mask] = 0

            # Coordinates are flipped from (row, col) to (x, y) for point_sample.
            spatial_tokens = point_sample(
                src_mask_features.permute(2, 3, 0, 1),
                query_points.flip(dims=(2,)).type(src_mask_features.dtype),
                align_corners=True,
            ).permute(2, 0, 1)
            spatial_tokens[pos_neg_indicator == 1] += self.pn_indicator.weight[0:1]
            spatial_tokens[pos_neg_indicator == -1] += self.pn_indicator.weight[1:2]

            src_spatial_queries += [spatial_tokens]
            src_spatial_maskings += [padding_mask]

        extra['visual_prompt_tokens'] = src_spatial_queries  # [len,bz,C]
        extra['visual_prompt_nonzero_mask'] = src_spatial_maskings  # [bz,len]

    def forward(self, images, prompts, task, targets=None, batch_name_list=None, is_train = True, visual_prompt_type='scribble'):
        """Run the model on ``images`` with optional text and/or spatial prompts.

        Args:
            images: image tensor, or an object exposing ``.tensor`` and ``.device``.
            prompts: dict; the keys ``'grounding'`` (text expressions) and
                ``'spatial'`` (list of prompt masks) are recognised.
            task: task name; for tasks other than ``'grounding'`` / ``'rvos'``
                the category names in ``batch_name_list`` are encoded.
            targets: forwarded to the transformer decoder.
            batch_name_list: category names; required (asserted) when ``task``
                is not ``'grounding'`` / ``'rvos'``.
            is_train: unused here.
            visual_prompt_type: forwarded to ``get_template`` as ``prompt_mode``.

        Returns:
            Whatever ``self.predictor`` returns.
        """
        extra = {}
        early_semantic = None

        if self.text_encode_type == "clip_teacher" and task not in ['grounding', 'rvos']:
            assert batch_name_list
            token_x, attention_mask, class_embeddings = self._encode_class_names(batch_name_list, images.device)
            extra['class_embeddings'] = class_embeddings

            # Early fusion: gather every valid token of every class name into
            # one sequence and share it across all images of the batch.
            gather_all_classtoken = token_x.flatten(0, 1)[attention_mask.flatten(0, 1) > 0]
            gather_all_classtoken = gather_all_classtoken.unsqueeze(0).repeat(len(images), 1, 1)  # [bs,L,C]
            gather_all_classtoken_mask = torch.ones_like(gather_all_classtoken[:, :, 0]) > 0  # [bs,L]
            early_semantic = {"hidden": gather_all_classtoken.float(), "masks": gather_all_classtoken_mask}

        if 'grounding' in prompts:
            if self.text_encode_type in ('clip_frozen', 'clip_teacher'):
                tokens = self.tokenizer(
                    prompts['grounding'],
                    padding='max_length',
                    truncation=True,
                    max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
                    return_tensors='pt',
                )
                tokens = {key: value.to(images.device) for key, value in tokens.items()}

                attention_mask = tokens['attention_mask']
                token_x = self.text_encoder(tokens['input_ids'], attention_mask)['last_hidden_state']
                token_x = token_x @ self.lang_projection

                extra['grounding_tokens'] = token_x.permute(1, 0, 2)  # [len,bz,C]
                extra['grounding_nonzero_mask'] = ~attention_mask.bool()  # [bz,len]
                extra['grounding_class'] = agg_lang_feat(token_x, attention_mask, pool_type="average")  # [bz,C]

                early_semantic = {"hidden": token_x.float(), "masks": attention_mask > 0}

        features = self.backbone(self._image_tensor(images))

        if 'spatial' in prompts:
            # Encode a crop around the prompted object as an extra
            # early-fusion sequence.
            key_images = [images]  # bz*[1,3,H,W]
            key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']]  # bz*[1,1,H,W]
            ref_feats, ref_masks = self.get_template(key_images, key_promptmasks, visual_prompt_type)
            early_fusion = {"hidden": ref_feats, "masks": ref_masks}
            if early_semantic is None:
                early_semantic = early_fusion
            else:
                early_semantic["hidden"] = torch.cat([early_semantic["hidden"], early_fusion["hidden"]], dim=1)
                early_semantic["masks"] = torch.cat([early_semantic["masks"], early_fusion["masks"]], dim=1)

        mask_features, _, multi_scale_features, _ = self.pixel_decoder.forward_features(
            features, masks=None, early_fusion=early_semantic
        )

        if 'spatial' in prompts:
            self._add_visual_prompt_tokens(prompts['spatial'], multi_scale_features, mask_features.device, extra)

        outputs = self.predictor(multi_scale_features, mask_features, extra=extra, task=task, masks=None, targets=targets)
        return outputs

    def vos_step1(self, previous_image, prompts, task, targets=None, batch_name_list=None, is_train = False):
        """Encode the visual prompt given on a previous frame.

        Args:
            previous_image: image tensor, or an object exposing ``.tensor``.
            prompts: dict that must contain ``'spatial'`` (list of prompt masks).
            task, targets, batch_name_list, is_train: unused here.

        Returns:
            ``(early_fusion, extra)``: the template features/masks dict
            produced from ``get_template`` and the dict of visual prompt
            tokens, in the form ``vos_step2`` accepts as
            ``language_dict_features`` and ``last_extra``.
        """
        extra = {}
        features = self.backbone(self._image_tensor(previous_image))

        key_images = [previous_image]  # bz*[1,3,H,W]
        key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']]  # bz*[1,1,H,W]
        ref_feats, ref_masks = self.get_template(key_images, key_promptmasks)
        early_fusion = {"hidden": ref_feats, "masks": ref_masks}

        mask_features, _, multi_scale_features, _ = self.pixel_decoder.forward_features(
            features, masks=None, early_fusion=early_fusion
        )

        # Unlike forward(), the mask features are appended as an additional level.
        prompt_multi_scale_features = multi_scale_features + [mask_features]
        self._add_visual_prompt_tokens(prompts['spatial'], prompt_multi_scale_features, mask_features.device, extra)

        return early_fusion, extra

    def vos_step2(self, images, task, language_dict_features, last_extra, targets=None, batch_name_list=None, is_train = False):
        """Decode ``images`` using the prompt encoded by ``vos_step1``.

        Args:
            images: image tensor, or an object exposing ``.tensor`` and ``.device``.
            task: task name; for tasks other than ``'grounding'`` / ``'rvos'``
                the category names in ``batch_name_list`` are encoded.
            language_dict_features: early-fusion dict passed to the pixel decoder.
            last_extra: dict passed to the decoder as ``extra``. It is updated
                in place with ``'class_embeddings'`` when class names are encoded.
            targets: forwarded to the transformer decoder.
            batch_name_list: category names; required (asserted) when ``task``
                is not ``'grounding'`` / ``'rvos'``.
            is_train: unused here.

        Returns:
            Whatever ``self.predictor`` returns.
        """
        extra = last_extra

        if task not in ['grounding', 'rvos']:
            assert batch_name_list
            # Tokens follow the input's device (as in forward) rather than a
            # hard-coded "cuda", so CPU / non-default-GPU inference works too.
            _, _, class_embeddings = self._encode_class_names(batch_name_list, images.device)
            extra['class_embeddings'] = class_embeddings

        features = self.backbone(self._image_tensor(images))

        mask_features, _, multi_scale_features, _ = self.pixel_decoder.forward_features(
            features, masks=None, early_fusion=language_dict_features
        )
        outputs = self.predictor(multi_scale_features, mask_features, extra=extra, task=task, masks=None, targets=targets)
        return outputs

    def get_template(self, imgs, pad_masks, prompt_mode='scribble'):
        """Crop each image around its prompt mask and encode the crop as template tokens.

        img: (N, 3, H, W), mask: (N, 1, H, W), bbox: (1, 4)

        Args:
            imgs: list of image tensors.
            pad_masks: list of prompt masks, paired with ``imgs`` by ``zip``
                (surplus entries of the longer list are ignored).
            prompt_mode: for ``'scribble'`` / ``'point'`` the mask is added
                onto the image before cropping; otherwise the image is used as is.

        Returns:
            ``(ref_feats, ref_masks)``: fused backbone features flattened to
            (bs, L, C) and an all-True mask of shape (bs, L).
        """
        croped_img_with_mask = []
        for image_i, mask_i in zip(imgs, pad_masks):
            if prompt_mode in ['scribble', 'point']:
                image_with_mask = image_i + mask_i.to(image_i)
            else:
                image_with_mask = image_i

            box_i = masks_to_boxes(mask_i[0])  # [xyxy]
            box_i[:, 2:] = box_i[:, 2:] - box_i[:, :2]  # xywh
            x, y, w, h = box_i[0].long().tolist()

            # Square crop centred on the box; its side is search_area_factor
            # times sqrt(box area). The top-left corner is clamped at 0.
            crop_sz = math.ceil(math.sqrt(w * h) * self.search_area_factor)
            x1 = max(0, round(x + 0.5 * w - crop_sz * 0.5))
            x2 = x1 + crop_sz
            y1 = max(0, round(y + 0.5 * h - crop_sz * 0.5))
            y2 = y1 + crop_sz

            im_crop = image_with_mask[:, :, y1:y2, x1:x2]
            # Fall back to the full image when the crop is empty.
            if im_crop.shape[-1] == 0 or im_crop.shape[-2] == 0:
                im_crop = image_with_mask
            # resize
            im_crop = F.interpolate(im_crop, (256, 256), mode='bilinear', align_corners=False)
            croped_img_with_mask.append(im_crop)

        croped_img_with_mask = torch.cat(croped_img_with_mask, dim=0)  # [bz,3,256,256]
        # The template pass through the backbone is excluded from autograd.
        with torch.no_grad():
            ref_srcs = self.backbone(croped_img_with_mask.contiguous())
            ref_srcs = list(ref_srcs.values())
            ref_feats = self.sot_fuser(ref_srcs[1:]).float()  # [bz,256,32,32]

        ref_feats = ref_feats.flatten(-2).permute(0, 2, 1)  # (bs, L, C)
        ref_masks = torch.ones_like(ref_feats[:, :, 0]) > 0  # [bs,L]

        return ref_feats, ref_masks