import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from utils.math import generate_permute_matrix
from utils.image import one_hot_mask

from networks.layers.basic import seq_to_2d


class AOTEngine(nn.Module):
    """Stateful wrapper that drives an AOT model frame by frame.

    The engine keeps the per-sequence state (current frame index, positional
    embedding, long-term and short-term memories, cached encoder outputs) and
    exposes the steps used for training (``forward``) and for step-wise
    propagation (``add_reference_frame``, ``match_propogate_one_frame``,
    ``update_short_term_memory``).
    """

    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_len_long_term=9999):
        """Store the wrapped model and memory settings, then reset state.

        Args:
            aot_model: model exposing ``cfg``, ``max_obj_num`` and the
                ``encode_image`` / ``LSTT_forward`` / ``decode_id_logits``
                style methods used below.
            gpu_id: forwarded to ``generate_permute_matrix``.
            long_term_mem_gap: minimum number of frame steps between two
                long-term memory updates.
            short_term_mem_skip: number of most recent short-term memories
                kept in ``short_term_memories_list``.
            max_len_long_term: cap (in frames) applied to the long-term
                memory before a new entry is prepended.
        """
        super().__init__()

        self.cfg = aot_model.cfg
        self.align_corners = aot_model.cfg.MODEL_ALIGN_CORNERS
        self.AOT = aot_model

        self.max_obj_num = aot_model.max_obj_num
        self.gpu_id = gpu_id
        self.long_term_mem_gap = long_term_mem_gap
        self.short_term_mem_skip = short_term_mem_skip
        self.max_len_long_term = max_len_long_term
        self.losses = None

        self.restart_engine()

    def forward(self,
                all_frames,
                all_masks,
                batch_size,
                obj_nums,
                step=0,
                tf_board=False,
                use_prev_pred=False,
                enable_prev_frame=False,
                use_prev_prob=False):  # only used for training
        """Run one training pass over a clip and return the loss.

        ``restart_engine`` is expected to have been called with the right
        batch size beforehand: this method reads ``self.batch_size`` (through
        ``offline_encoder``) and does not use the ``batch_size`` argument.

        Args:
            all_frames: frames of the whole clip stacked along dim 0; they
                are split into per-frame chunks of ``self.batch_size``.
            all_masks: ground-truth masks stacked the same way.
            batch_size: unused here, kept for interface compatibility.
            obj_nums: per-sample object counts.
            step: current training step, used to decay the auxiliary loss
                weight and passed on to the losses.
            tf_board: unused here, kept for interface compatibility.
            use_prev_pred: if true, identities for the memory update come
                from the prediction and identity embeddings are detached.
            enable_prev_frame: if true, frame 1 is treated as a second
                ground-truth frame (auxiliary loss) instead of a predicted
                frame.
            use_prev_prob: if true, the memory update receives the
                predicted probabilities rather than the arg-max mask.

        Returns:
            ``(loss, all_pred_mask, all_frame_loss, boards)`` where ``boards``
            is an empty ``{'image': {}, 'scalar': {}}`` dict.
        """
        if self.losses is None:
            self._init_losses()

        self.freeze_id = bool(use_prev_pred)
        # Linearly decay the auxiliary weight to zero over ``aux_step`` steps.
        aux_weight = self.aux_weight * max(self.aux_step - step,
                                           0.) / self.aux_step

        self.offline_encoder(all_frames, all_masks)

        self.add_reference_frame(frame_step=0, obj_nums=obj_nums)

        # No gradient is needed for the auxiliary branch once its weight is 0.
        grad_state = torch.no_grad if aux_weight == 0 else torch.enable_grad
        with grad_state():
            ref_aux_loss, ref_aux_mask = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step)

        aux_losses = [ref_aux_loss]
        aux_masks = [ref_aux_mask]

        curr_losses, curr_masks = [], []
        if enable_prev_frame:
            self.set_prev_frame(frame_step=1)
            with grad_state():
                prev_aux_loss, prev_aux_mask = self.generate_loss_mask(
                    self.offline_masks[self.frame_step], step)
            aux_losses.append(prev_aux_loss)
            aux_masks.append(prev_aux_mask)
        else:
            self.match_propogate_one_frame()
            curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step, return_prob=True)
            self.update_short_term_memory(
                curr_mask if not use_prev_prob else curr_prob,
                None if use_prev_pred else self.assign_identity(
                    self.offline_one_hot_masks[self.frame_step]))
            curr_losses.append(curr_loss)
            curr_masks.append(curr_mask)

        self.match_propogate_one_frame()
        curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
            self.offline_masks[self.frame_step], step, return_prob=True)
        curr_losses.append(curr_loss)
        curr_masks.append(curr_mask)

        # Three frames (0, 1, 2) have been handled above; process the rest.
        for _ in range(self.total_offline_frame_num - 3):
            self.update_short_term_memory(
                curr_mask if not use_prev_prob else curr_prob,
                None if use_prev_pred else self.assign_identity(
                    self.offline_one_hot_masks[self.frame_step]))
            self.match_propogate_one_frame()
            curr_loss, curr_mask, curr_prob = self.generate_loss_mask(
                self.offline_masks[self.frame_step], step, return_prob=True)
            curr_losses.append(curr_loss)
            curr_masks.append(curr_mask)

        aux_loss = torch.cat(aux_losses, dim=0).mean(dim=0)
        pred_loss = torch.cat(curr_losses, dim=0).mean(dim=0)

        loss = aux_weight * aux_loss + pred_loss

        all_pred_mask = aux_masks + curr_masks

        all_frame_loss = aux_losses + curr_losses

        boards = {'image': {}, 'scalar': {}}

        return loss, all_pred_mask, all_frame_loss, boards

    def _init_losses(self):
        """Build the training losses and auxiliary-loss schedule from cfg."""
        cfg = self.cfg

        # Imported lazily so that inference-only use does not need the losses.
        from networks.layers.loss import CrossEntropyLoss, SoftJaccordLoss
        bce_loss = CrossEntropyLoss(
            cfg.TRAIN_TOP_K_PERCENT_PIXELS,
            cfg.TRAIN_HARD_MINING_RATIO * cfg.TRAIN_TOTAL_STEPS)
        iou_loss = SoftJaccordLoss()

        self.losses = nn.ModuleList([bce_loss, iou_loss])
        self.loss_weights = [0.5, 0.5]

        self.aux_weight = cfg.TRAIN_AUX_LOSS_WEIGHT
        # The small epsilon keeps the division in ``forward`` well defined.
        self.aux_step = (cfg.TRAIN_TOTAL_STEPS * cfg.TRAIN_AUX_LOSS_RATIO +
                         1e-5)

    def encode_one_img_mask(self, img=None, mask=None, frame_step=-1):
        """Return ``(encoder embeddings, one-hot mask)`` for one frame.

        Cached offline results are used when offline encoding is enabled;
        otherwise ``img`` is encoded on the fly. Either element of the
        result may be ``None`` when the corresponding input is unavailable.
        ``frame_step == -1`` means "the engine's current frame step".
        """
        if frame_step == -1:
            frame_step = self.frame_step

        if self.enable_offline_enc:
            curr_enc_embs = self.offline_enc_embs[frame_step]
        elif img is None:
            curr_enc_embs = None
        else:
            curr_enc_embs = self.AOT.encode_image(img)

        # An explicitly given mask takes priority over the cached one.
        if mask is not None:
            curr_one_hot_mask = one_hot_mask(mask, self.max_obj_num)
        elif self.enable_offline_enc:
            curr_one_hot_mask = self.offline_one_hot_masks[frame_step]
        else:
            curr_one_hot_mask = None

        return curr_enc_embs, curr_one_hot_mask

    def offline_encoder(self, all_frames, all_masks=None):
        """Encode a whole clip at once and cache the per-frame results.

        ``all_frames`` (and ``all_masks`` when given) are split along dim 0
        into chunks of ``self.batch_size``, one chunk per frame.
        """
        self.enable_offline_enc = True
        self.offline_frames = all_frames.size(0) // self.batch_size

        # extract backbone features
        self.offline_enc_embs = self.split_frames(
            self.AOT.encode_image(all_frames), self.batch_size)
        self.total_offline_frame_num = len(self.offline_enc_embs)

        if all_masks is not None:
            # extract mask embeddings
            offline_one_hot_masks = one_hot_mask(all_masks, self.max_obj_num)
            self.offline_masks = list(
                torch.split(all_masks, self.batch_size, dim=0))
            self.offline_one_hot_masks = list(
                torch.split(offline_one_hot_masks, self.batch_size, dim=0))

        if self.input_size_2d is None:
            self.update_size(all_frames.size()[2:],
                             self.offline_enc_embs[0][-1].size()[2:])

    def assign_identity(self, one_hot_mask):
        """Turn a one-hot mask into an identity embedding sequence.

        When id shuffling is enabled the object channels are permuted with
        ``self.id_shuffle_matrix`` first. The embedding is reshaped to
        ``(enc_hw, batch_size, -1)`` and detached while training with
        ``freeze_id`` set.
        """
        if self.enable_id_shuffle:
            one_hot_mask = torch.einsum('bohw,bot->bthw', one_hot_mask,
                                        self.id_shuffle_matrix)

        id_emb = self.AOT.get_id_emb(one_hot_mask).view(
            self.batch_size, -1, self.enc_hw).permute(2, 0, 1)

        if self.training and self.freeze_id:
            id_emb = id_emb.detach()

        return id_emb

    def split_frames(self, xs, chunk_size):
        """Split every tensor in ``xs`` along dim 0 and regroup per chunk.

        Returns a list with one tuple per chunk; each tuple holds the
        corresponding chunk of every tensor in ``xs``.
        """
        per_tensor_chunks = [
            list(torch.split(x, chunk_size, dim=0)) for x in xs
        ]
        return list(zip(*per_tensor_chunks))

    def add_reference_frame(self,
                            img=None,
                            mask=None,
                            frame_step=-1,
                            obj_nums=None,
                            img_embs=None):
        """Register a frame with a known mask and initialise the memories.

        Args:
            img: frame to encode (ignored when offline encodings or
                ``img_embs`` are available, except for reading its spatial
                size the first time sizes are set).
            mask: mask of the frame; falls back to the cached offline
                one-hot mask when ``None``.
            frame_step: frame index, ``-1`` for the current step.
            obj_nums: per-sample object counts; required on the first call.
            img_embs: precomputed encoder embeddings to reuse.

        Raises:
            ValueError: if no object count, image embedding, or mask is
                available for the reference frame.
        """
        if self.obj_nums is None and obj_nums is None:
            raise ValueError('No objects for reference frame!')
        elif obj_nums is not None:
            self.obj_nums = obj_nums

        if frame_step == -1:
            frame_step = self.frame_step

        if img_embs is None:
            curr_enc_embs, curr_one_hot_mask = self.encode_one_img_mask(
                img, mask, frame_step)
        else:
            # Embeddings are supplied by the caller; only the mask is needed.
            _, curr_one_hot_mask = self.encode_one_img_mask(
                None, mask, frame_step)
            curr_enc_embs = img_embs

        if curr_enc_embs is None:
            raise ValueError('No image for reference frame!')

        if curr_one_hot_mask is None:
            raise ValueError('No mask for reference frame!')

        if self.input_size_2d is None:
            self.update_size(img.size()[2:], curr_enc_embs[-1].size()[2:])

        self.curr_enc_embs = curr_enc_embs
        self.curr_one_hot_mask = curr_one_hot_mask

        if self.pos_emb is None:
            self.pos_emb = self.AOT.get_pos_emb(curr_enc_embs[-1]).expand(
                self.batch_size, -1, -1,
                -1).view(self.batch_size, -1, self.enc_hw).permute(2, 0, 1)

        curr_id_emb = self.assign_identity(curr_one_hot_mask)
        self.curr_id_embs = curr_id_emb

        # self matching and propagation
        self.curr_lstt_output = self.AOT.LSTT_forward(
            curr_enc_embs,
            None,
            None,
            curr_id_emb,
            pos_emb=self.pos_emb,
            size_2d=self.enc_size_2d)

        (lstt_embs, lstt_curr_memories, lstt_long_memories,
         lstt_short_memories) = self.curr_lstt_output

        if self.long_term_memories is None:
            self.long_term_memories = lstt_long_memories
        else:
            self.update_long_term_memory(lstt_long_memories)

        self.last_mem_step = self.frame_step

        self.short_term_memories_list = [lstt_short_memories]
        self.short_term_memories = lstt_short_memories

    def set_prev_frame(self, img=None, mask=None, frame_step=1):
        """Jump to ``frame_step`` and treat it as another ground-truth frame.

        Mirrors ``add_reference_frame`` but also moves ``self.frame_step``
        and does not touch object counts, sizes, or the positional embedding.

        Raises:
            ValueError: if no image embedding or mask is available.
        """
        self.frame_step = frame_step
        curr_enc_embs, curr_one_hot_mask = self.encode_one_img_mask(
            img, mask, frame_step)

        if curr_enc_embs is None:
            raise ValueError('No image for previous frame!')

        if curr_one_hot_mask is None:
            raise ValueError('No mask for previous frame!')

        self.curr_enc_embs = curr_enc_embs
        self.curr_one_hot_mask = curr_one_hot_mask

        curr_id_emb = self.assign_identity(curr_one_hot_mask)
        self.curr_id_embs = curr_id_emb

        # self matching and propagation
        self.curr_lstt_output = self.AOT.LSTT_forward(
            curr_enc_embs,
            None,
            None,
            curr_id_emb,
            pos_emb=self.pos_emb,
            size_2d=self.enc_size_2d)

        (lstt_embs, lstt_curr_memories, lstt_long_memories,
         lstt_short_memories) = self.curr_lstt_output

        if self.long_term_memories is None:
            self.long_term_memories = lstt_long_memories
        else:
            self.update_long_term_memory(lstt_long_memories)
        self.last_mem_step = frame_step

        self.short_term_memories_list = [lstt_short_memories]
        self.short_term_memories = lstt_short_memories

    def update_long_term_memory(self, new_long_term_memories):
        """Prepend new entries to the long-term memory, capping its length.

        The memory is a nested list (outer list zipped with
        ``new_long_term_memories``, inner elements tensors or ``None``).
        Before concatenation along dim 0 the stored tensor is truncated to
        ``(max_len_long_term - 1)`` frames worth of tokens, so the entries at
        the end of the tensor are the ones dropped.
        """
        if self.long_term_memories is None:
            # Nothing stored yet: the new memories become the memory as-is.
            # Returning here avoids concatenating them with themselves.
            self.long_term_memories = new_long_term_memories
            return

        # Number of tokens contributed by a single frame.
        token_num = new_long_term_memories[0][0].shape[0]
        max_tokens = self.max_len_long_term * token_num
        kept_tokens = (self.max_len_long_term - 1) * token_num

        updated_long_term_memories = []
        for new_long_term_memory, last_long_term_memory in zip(
                new_long_term_memories, self.long_term_memories):
            updated_e = []
            for new_e, last_e in zip(new_long_term_memory,
                                     last_long_term_memory):
                if new_e is None or last_e is None:
                    updated_e.append(None)
                    continue
                if last_e.shape[0] >= max_tokens:
                    last_e = last_e[:kept_tokens]
                updated_e.append(torch.cat([new_e, last_e], dim=0))
            updated_long_term_memories.append(updated_e)
        self.long_term_memories = updated_long_term_memories

    def update_short_term_memory(self,
                                 curr_mask,
                                 curr_id_emb=None,
                                 skip_long_term_update=False):
        """Fuse identities into the current memories and store them.

        Args:
            curr_mask: mask for the current frame. A 3-D tensor, or a tensor
                whose first dimension is 1, is converted with
                ``one_hot_mask``; anything else is used as given. Only read
                when ``curr_id_emb`` is ``None``.
            curr_id_emb: precomputed identity embedding, if available.
            skip_long_term_update: if true, the long-term memory is left
                untouched even when the update gap has elapsed.
        """
        if curr_id_emb is None:
            if len(curr_mask.size()) == 3 or curr_mask.size()[0] == 1:
                curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num)
            else:
                curr_one_hot_mask = curr_mask
            curr_id_emb = self.assign_identity(curr_one_hot_mask)

        lstt_curr_memories = self.curr_lstt_output[1]
        lstt_curr_memories_2d = []
        for layer_idx, layer_memory in enumerate(lstt_curr_memories):
            curr_k, curr_v = layer_memory[0], layer_memory[1]
            curr_k, curr_v = self.AOT.LSTT.layers[
                layer_idx].fuse_key_value_id(curr_k, curr_v, curr_id_emb)
            # Written back in place: ``self.curr_lstt_output[1]`` now holds
            # the identity-fused key/value pair.
            layer_memory[0], layer_memory[1] = curr_k, curr_v
            lstt_curr_memories_2d.append([
                seq_to_2d(curr_k, self.enc_size_2d),
                seq_to_2d(curr_v, self.enc_size_2d)
            ])

        # Keep only the last ``short_term_mem_skip`` entries and read from
        # the oldest of them.
        self.short_term_memories_list.append(lstt_curr_memories_2d)
        self.short_term_memories_list = self.short_term_memories_list[
            -self.short_term_mem_skip:]
        self.short_term_memories = self.short_term_memories_list[0]

        if self.frame_step - self.last_mem_step >= self.long_term_mem_gap:
            # skip the update of long-term memory or not
            if not skip_long_term_update:
                self.update_long_term_memory(lstt_curr_memories)
            # NOTE(review): the step marker advances even when the update is
            # skipped — confirm this nesting against the upstream file.
            self.last_mem_step = self.frame_step

    def match_propogate_one_frame(self, img=None, img_embs=None):
        """Advance one frame and run the model against the stored memories."""
        self.frame_step += 1
        if img_embs is None:
            curr_enc_embs, _ = self.encode_one_img_mask(
                img, None, self.frame_step)
        else:
            curr_enc_embs = img_embs
        self.curr_enc_embs = curr_enc_embs

        self.curr_lstt_output = self.AOT.LSTT_forward(
            curr_enc_embs,
            self.long_term_memories,
            self.short_term_memories,
            None,
            pos_emb=self.pos_emb,
            size_2d=self.enc_size_2d)

    def decode_current_logits(self, output_size=None):
        """Decode identity logits for the current frame.

        The un-resized logits are cached in ``self.pred_id_logits``; the
        returned tensor is bilinearly resized when ``output_size`` is given.
        """
        curr_enc_embs = self.curr_enc_embs
        curr_lstt_embs = self.curr_lstt_output[0]

        pred_id_logits = self.AOT.decode_id_logits(curr_lstt_embs,
                                                   curr_enc_embs)

        if self.enable_id_shuffle:  # reverse shuffle
            pred_id_logits = torch.einsum('bohw,bto->bthw', pred_id_logits,
                                          self.id_shuffle_matrix)

        # remove unused identities
        # A smaller magnitude is used for non-float32 dtypes.
        if pred_id_logits.dtype == torch.float32:
            unused_fill = -1e+10
        else:
            unused_fill = -1e+4
        for batch_idx, obj_num in enumerate(self.obj_nums):
            pred_id_logits[batch_idx, (obj_num + 1):] = unused_fill

        self.pred_id_logits = pred_id_logits

        if output_size is not None:
            pred_id_logits = F.interpolate(pred_id_logits,
                                           size=output_size,
                                           mode="bilinear",
                                           align_corners=self.align_corners)

        return pred_id_logits

    def predict_current_mask(self, output_size=None, return_prob=False):
        """Return the arg-max mask (and optionally softmax probabilities).

        Uses the logits cached by ``decode_current_logits``, resized to
        ``output_size`` (default: the input size).
        """
        if output_size is None:
            output_size = self.input_size_2d

        pred_id_logits = F.interpolate(self.pred_id_logits,
                                       size=output_size,
                                       mode="bilinear",
                                       align_corners=self.align_corners)
        pred_mask = torch.argmax(pred_id_logits, dim=1)

        if not return_prob:
            return pred_mask

        pred_prob = torch.softmax(pred_id_logits, dim=1)
        return pred_mask, pred_prob

    def calculate_current_loss(self, gt_mask, step):
        """Compute the weighted sum of the losses for the current frame.

        Logits are resized to the spatial size of ``gt_mask`` and, per
        sample, cut down to the background plus that sample's objects.
        """
        pred_id_logits = F.interpolate(self.pred_id_logits,
                                       size=gt_mask.size()[-2:],
                                       mode="bilinear",
                                       align_corners=self.align_corners)

        label_list = []
        logit_list = []
        for batch_idx, obj_num in enumerate(self.obj_nums):
            label_list.append(gt_mask[batch_idx].long())
            logit_list.append(
                pred_id_logits[batch_idx, :(obj_num + 1)].unsqueeze(0))

        total_loss = 0
        for loss, loss_weight in zip(self.losses, self.loss_weights):
            total_loss = total_loss + loss_weight * loss(
                logit_list, label_list, step)

        return total_loss

    def generate_loss_mask(self, gt_mask, step, return_prob=False):
        """Decode the current frame and return ``(loss, mask[, prob])``."""
        self.decode_current_logits()
        loss = self.calculate_current_loss(gt_mask, step)
        if return_prob:
            mask, prob = self.predict_current_mask(return_prob=True)
            return loss, mask, prob

        mask = self.predict_current_mask()
        return loss, mask

    def keep_gt_mask(self, pred_mask, keep_prob=0.2):
        """Randomly replace whole samples of ``pred_mask`` with ground truth.

        One random draw is made per batch element; with probability
        ``keep_prob`` that element's prediction is swapped for the cached
        ground-truth mask of the current frame.
        """
        pred_mask = pred_mask.float()
        gt_mask = self.offline_masks[self.frame_step].float().squeeze(1)

        # One value per sample, broadcast over the remaining dimensions.
        shape = [1] * pred_mask.ndim
        shape[0] = self.batch_size
        random_tensor = keep_prob + torch.rand(
            shape, dtype=pred_mask.dtype, device=pred_mask.device)
        random_tensor.floor_()  # binarize

        pred_mask = pred_mask * (1 - random_tensor) + gt_mask * random_tensor

        return pred_mask

    def restart_engine(self, batch_size=1, enable_id_shuffle=False):
        """Reset all per-sequence state for a new clip."""
        self.batch_size = batch_size
        self.frame_step = 0
        self.last_mem_step = -1
        self.enable_id_shuffle = enable_id_shuffle
        self.freeze_id = False

        self.obj_nums = None
        self.pos_emb = None
        self.enc_size_2d = None
        self.enc_hw = None
        self.input_size_2d = None

        self.long_term_memories = None
        self.short_term_memories_list = []
        self.short_term_memories = None

        self.enable_offline_enc = False
        self.offline_enc_embs = None
        self.offline_one_hot_masks = None
        self.offline_frames = -1
        self.total_offline_frame_num = 0

        self.curr_enc_embs = None
        self.curr_memories = None
        self.curr_id_embs = None

        if enable_id_shuffle:
            self.id_shuffle_matrix = generate_permute_matrix(
                self.max_obj_num + 1, batch_size, gpu_id=self.gpu_id)
        else:
            self.id_shuffle_matrix = None

    def update_size(self, input_size, enc_size):
        """Record the input and encoder spatial sizes."""
        self.input_size_2d = input_size
        self.enc_size_2d = enc_size
        self.enc_hw = self.enc_size_2d[0] * self.enc_size_2d[1]


class AOTInferEngine(nn.Module):
    """Inference wrapper that spreads objects over several ``AOTEngine``s.

    Each underlying engine handles at most ``max_aot_obj_num`` objects;
    engines are created on demand in ``add_reference_frame`` and their
    logits are merged in ``decode_current_logits``.
    """

    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_aot_obj_num=None,
                 max_len_long_term=9999):
        """Store settings shared by all engines and reset the engine list.

        ``max_aot_obj_num`` is clamped to ``aot_model.max_obj_num`` and
        defaults to it when ``None``.
        """
        super().__init__()

        self.cfg = aot_model.cfg
        self.AOT = aot_model

        if (max_aot_obj_num is None
                or max_aot_obj_num > aot_model.max_obj_num):
            self.max_aot_obj_num = aot_model.max_obj_num
        else:
            self.max_aot_obj_num = max_aot_obj_num

        self.gpu_id = gpu_id
        self.long_term_mem_gap = long_term_mem_gap
        self.short_term_mem_skip = short_term_mem_skip
        self.max_len_long_term = max_len_long_term
        self.aot_engines = []

        self.restart_engine()

    def restart_engine(self):
        """Drop all engines and forget the object count."""
        del (self.aot_engines)
        self.aot_engines = []
        self.obj_nums = None

    def separate_mask(self, mask, obj_nums):
        """Split a mask into one mask per engine.

        Args:
            mask: ``None``, a label mask (3-D, or first dimension of size 1)
                or a probability tensor.
            obj_nums: total number of objects.

        Returns:
            ``(separated_masks, separated_obj_nums)``, two lists with one
            entry per engine. With a ``None`` mask every mask entry is
            ``None``.
        """
        num_engines = len(self.aot_engines)

        if num_engines == 1:
            separated_obj_nums = [obj_nums]
        else:
            separated_obj_nums = [self.max_aot_obj_num] * num_engines
            # The last engine gets the remainder when the objects do not
            # fill it completely.
            if obj_nums % self.max_aot_obj_num > 0:
                separated_obj_nums[-1] = obj_nums % self.max_aot_obj_num

        if mask is None:
            # Always return a pair: every caller unpacks two values.
            return [None] * num_engines, separated_obj_nums
        if num_engines == 1:
            return [mask], separated_obj_nums

        if len(mask.size()) == 3 or mask.size()[0] == 1:
            separated_masks = []
            for idx in range(num_engines):
                start_id = idx * self.max_aot_obj_num + 1
                end_id = (idx + 1) * self.max_aot_obj_num
                fg_mask = ((mask >= start_id) & (mask <= end_id)).float()
                # Re-index this engine's labels to start from 1; everything
                # else becomes background (0).
                separated_mask = (fg_mask * mask - start_id + 1) * fg_mask
                separated_masks.append(separated_mask)
            return separated_masks, separated_obj_nums

        prob = mask
        separated_probs = []
        for idx in range(num_engines):
            start_id = idx * self.max_aot_obj_num + 1
            end_id = (idx + 1) * self.max_aot_obj_num
            # NOTE(review): this slices dim 0 while the sum/cat below act on
            # dim 1; a channel slice ``prob[:, start_id:end_id + 1]`` looks
            # intended — confirm the expected layout of ``prob``.
            fg_prob = prob[start_id:(end_id + 1)]
            bg_prob = 1. - torch.sum(fg_prob, dim=1, keepdim=True)
            separated_probs.append(torch.cat([bg_prob, fg_prob], dim=1))
        return separated_probs, separated_obj_nums

    def min_logit_aggregation(self, all_logits):
        """Merge per-engine logits using the minimum background logit."""
        if len(all_logits) == 1:
            return all_logits[0]

        fg_logits = []
        bg_logits = []

        for logit in all_logits:
            bg_logits.append(logit[:, 0:1])
            fg_logits.append(logit[:, 1:1 + self.max_aot_obj_num])

        bg_logit, _ = torch.min(torch.cat(bg_logits, dim=1),
                                dim=1,
                                keepdim=True)
        merged_logit = torch.cat([bg_logit] + fg_logits, dim=1)

        return merged_logit

    def soft_logit_aggregation(self, all_logits):
        """Merge per-engine logits through their softmax probabilities.

        The merged background probability is the product of the engines'
        background probabilities; the result is clamped to
        ``[1e-5, 1 - 1e-5]`` and mapped back with ``torch.logit``.
        """
        if len(all_logits) == 1:
            return all_logits[0]

        fg_probs = []
        bg_probs = []

        for logit in all_logits:
            prob = torch.softmax(logit, dim=1)
            bg_probs.append(prob[:, 0:1])
            fg_probs.append(prob[:, 1:1 + self.max_aot_obj_num])

        bg_prob = torch.prod(torch.cat(bg_probs, dim=1), dim=1, keepdim=True)
        merged_prob = torch.cat([bg_prob] + fg_probs,
                                dim=1).clamp(1e-5, 1 - 1e-5)
        merged_logit = torch.logit(merged_prob)

        return merged_logit

    def add_reference_frame(self, img, mask, obj_nums, frame_step=-1):
        """Add a reference frame, creating engines as needed.

        ``obj_nums`` may be an int or a list, in which case its first
        element is used. Image embeddings computed by the first engine are
        reused by the others.
        """
        if isinstance(obj_nums, list):
            obj_nums = obj_nums[0]
        self.obj_nums = obj_nums
        # At least one engine is needed even when there are no objects.
        aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1)
        while aot_num > len(self.aot_engines):
            new_engine = AOTEngine(self.AOT, self.gpu_id,
                                   self.long_term_mem_gap,
                                   self.short_term_mem_skip,
                                   self.max_len_long_term)
            new_engine.eval()
            self.aot_engines.append(new_engine)

        separated_masks, separated_obj_nums = self.separate_mask(
            mask, obj_nums)
        img_embs = None
        for aot_engine, separated_mask, separated_obj_num in zip(
                self.aot_engines, separated_masks, separated_obj_nums):
            aot_engine.add_reference_frame(img,
                                           separated_mask,
                                           obj_nums=[separated_obj_num],
                                           frame_step=frame_step,
                                           img_embs=img_embs)
            if img_embs is None:  # reuse image embeddings
                img_embs = aot_engine.curr_enc_embs

        self.update_size()

    def match_propogate_one_frame(self, img=None):
        """Propagate every engine by one frame, sharing image embeddings."""
        img_embs = None
        for aot_engine in self.aot_engines:
            aot_engine.match_propogate_one_frame(img, img_embs=img_embs)
            if img_embs is None:  # reuse image embeddings
                img_embs = aot_engine.curr_enc_embs

    def decode_current_logits(self, output_size=None):
        """Decode logits from every engine and merge them softly."""
        all_logits = [
            aot_engine.decode_current_logits(output_size)
            for aot_engine in self.aot_engines
        ]
        pred_id_logits = self.soft_logit_aggregation(all_logits)
        return pred_id_logits

    def update_memory(self, curr_mask, skip_long_term_update=False):
        """Resize ``curr_mask`` to the input size and update every engine."""
        _curr_mask = F.interpolate(curr_mask, self.input_size_2d)
        separated_masks, _ = self.separate_mask(_curr_mask, self.obj_nums)
        for aot_engine, separated_mask in zip(self.aot_engines,
                                              separated_masks):
            aot_engine.update_short_term_memory(
                separated_mask, skip_long_term_update=skip_long_term_update)

    def update_size(self):
        """Copy the size information from the first engine."""
        self.input_size_2d = self.aot_engines[0].input_size_2d
        self.enc_size_2d = self.aot_engines[0].enc_size_2d
        self.enc_hw = self.aot_engines[0].enc_hw