import numpy as np

from utils.image import one_hot_mask

from networks.layers.basic import seq_to_2d
from networks.engines.aot_engine import AOTEngine, AOTInferEngine


class DeAOTEngine(AOTEngine):
    """AOTEngine subclass with its own short-term memory update.

    The update fuses the current identity embedding into every layer's
    identity key/value through that layer's ``fuse_key_value_id`` before
    the memories are stored.
    """

    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 layer_loss_scaling_ratio=2.,
                 max_len_long_term=9999):
        """Forward the shared options to ``AOTEngine`` and keep the ratio.

        ``layer_loss_scaling_ratio`` is only stored on the instance here;
        whatever consumes it is not visible in this file.
        """
        super().__init__(aot_model, gpu_id, long_term_mem_gap,
                         short_term_mem_skip, max_len_long_term)
        self.layer_loss_scaling_ratio = layer_loss_scaling_ratio

    def update_short_term_memory(self,
                                 curr_mask,
                                 curr_id_emb=None,
                                 skip_long_term_update=False):
        """Refresh the short-term memory from the latest LSTT output.

        Args:
            curr_mask: Mask of the current frame. Only consulted when
                ``curr_id_emb`` is None; it is converted with
                ``one_hot_mask`` if it has 3 dimensions or a leading
                dimension of size 1, and used as-is otherwise.
            curr_id_emb: Optional precomputed identity embedding. When
                None it is derived from ``curr_mask`` via
                ``self.assign_identity``.
            skip_long_term_update: When True, the long-term memory is not
                updated even if the gap has elapsed. ``last_mem_step`` is
                still advanced in that case.
        """
        if curr_id_emb is None:
            mask_shape = curr_mask.size()
            if len(mask_shape) == 3 or mask_shape[0] == 1:
                identity_source = one_hot_mask(curr_mask, self.max_obj_num)
            else:
                # NOTE(review): presumably already one-hot here -- confirm
                # against callers.
                identity_source = curr_mask
            curr_id_emb = self.assign_identity(identity_source)

        layer_memories = self.curr_lstt_output[1]
        layer_memories_2d = []
        for idx, memory in enumerate(layer_memories):
            key, value, id_key, id_value = memory
            id_key, id_value = self.AOT.LSTT.layers[idx].fuse_key_value_id(
                id_key, id_value, curr_id_emb)

            # Write the fused identity key/value back in place so that the
            # long-term update below receives them as well.
            memory[2] = id_key
            memory[3] = id_value

            # Only the identity key is guarded against None.
            if id_key is None:
                id_key_2d = None
            else:
                id_key_2d = seq_to_2d(id_key, self.enc_size_2d)
            id_value_2d = seq_to_2d(id_value, self.enc_size_2d)
            key_2d = seq_to_2d(key, self.enc_size_2d)
            value_2d = seq_to_2d(value, self.enc_size_2d)

            layer_memories_2d.append(
                [key_2d, value_2d, id_key_2d, id_value_2d])

        # Keep only the trailing `short_term_mem_skip` entries; the oldest
        # one retained becomes the active short-term memory.
        history = self.short_term_memories_list
        history.append(layer_memories_2d)
        self.short_term_memories_list = history[-self.short_term_mem_skip:]
        self.short_term_memories = self.short_term_memories_list[0]

        steps_since_update = self.frame_step - self.last_mem_step
        if steps_since_update >= self.long_term_mem_gap:
            # skip the update of long-term memory or not
            if not skip_long_term_update:
                self.update_long_term_memory(layer_memories)
            self.last_mem_step = self.frame_step


class DeAOTInferEngine(AOTInferEngine):
    """AOTInferEngine subclass that creates ``DeAOTEngine`` sub-engines."""

    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_aot_obj_num=None,
                 max_len_long_term=9999):
        """Pass every option through to ``AOTInferEngine`` unchanged."""
        super().__init__(aot_model, gpu_id, long_term_mem_gap,
                         short_term_mem_skip, max_aot_obj_num,
                         max_len_long_term)

    def add_reference_frame(self, img, mask, obj_nums, frame_step=-1):
        """Register a reference frame across the pool of sub-engines.

        Args:
            img: Reference image, forwarded to each sub-engine.
            mask: Mask for ``img``; split per engine by
                ``self.separate_mask``.
            obj_nums: Object count, or a list whose first element is the
                object count.
            frame_step: Forwarded to each sub-engine's
                ``add_reference_frame``.
        """
        if isinstance(obj_nums, list):
            obj_nums = obj_nums[0]
        self.obj_nums = obj_nums

        # At least one engine, and enough of them to hold
        # `max_aot_obj_num` objects each.
        required_engines = max(np.ceil(obj_nums / self.max_aot_obj_num), 1)
        while required_engines > len(self.aot_engines):
            engine = DeAOTEngine(self.AOT,
                                 self.gpu_id,
                                 self.long_term_mem_gap,
                                 self.short_term_mem_skip,
                                 max_len_long_term=self.max_len_long_term)
            engine.eval()
            self.aot_engines.append(engine)

        separated_masks, separated_obj_nums = self.separate_mask(
            mask, obj_nums)

        img_embs = None
        engine_inputs = zip(self.aot_engines, separated_masks,
                            separated_obj_nums)
        for engine, engine_mask, engine_obj_num in engine_inputs:
            needs_reference = (engine.obj_nums is None
                               or engine.obj_nums[0] < engine_obj_num)
            if needs_reference:
                engine.add_reference_frame(img,
                                           engine_mask,
                                           obj_nums=[engine_obj_num],
                                           frame_step=frame_step,
                                           img_embs=img_embs)
            else:
                engine.update_short_term_memory(engine_mask)

            if img_embs is None:  # reuse image embeddings
                img_embs = engine.curr_enc_embs

        self.update_size()