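# Hugging Face file-view header (scraped along with the code; commented out
# so the module parses):
#   uploaded by: Schrodingers
#   commit: ffbe0b4 ("Upload folder using huggingface_hub")
#   size: 4.44 kB; Hugging Face scan result: "No virus"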
import numpy as np
from utils.image import one_hot_mask
from networks.layers.basic import seq_to_2d
from networks.engines.aot_engine import AOTEngine, AOTInferEngine
class DeAOTEngine(AOTEngine):
    """``AOTEngine`` variant for DeAOT models.

    Differs from the base engine in how short-term memory is built: each LSTT
    layer fuses the identity embedding into its ID key/value memories via
    ``fuse_key_value_id``, and all four per-layer memories (key, value, ID key,
    ID value) are stored in 2-D form.
    """

    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 layer_loss_scaling_ratio=2.,
                 max_len_long_term=9999):
        """Create the engine.

        All arguments except ``layer_loss_scaling_ratio`` are forwarded to
        ``AOTEngine.__init__``. ``layer_loss_scaling_ratio`` is only stored on
        the instance; nothing in this class reads it.
        """
        super().__init__(aot_model, gpu_id, long_term_mem_gap,
                         short_term_mem_skip, max_len_long_term)
        self.layer_loss_scaling_ratio = layer_loss_scaling_ratio

    def update_short_term_memory(self, curr_mask, curr_id_emb=None,
                                 skip_long_term_update=False):
        """Store the current frame's LSTT memories as short-term memory.

        Reads the per-layer memories from ``self.curr_lstt_output[1]``, fuses
        the identity embedding into each layer's ID key/value, and appends the
        2-D versions to ``self.short_term_memories_list``. Once at least
        ``self.long_term_mem_gap`` frames have passed since
        ``self.last_mem_step``, the (sequence-form) memories are also handed
        to ``self.update_long_term_memory``, unless ``skip_long_term_update``
        is set.

        Args:
            curr_mask: mask of the current frame. If it has 3 dimensions or a
                leading dimension of 1 it is converted with ``one_hot_mask``;
                otherwise it is used as-is (assumed already one-hot — TODO
                confirm against callers).
            curr_id_emb: precomputed identity embedding; when ``None`` it is
                derived from ``curr_mask`` via ``self.assign_identity``.
            skip_long_term_update: if true, never write long-term memory on
                this call; ``self.last_mem_step`` is then left unchanged.
        """
        if curr_id_emb is None:
            if len(curr_mask.size()) == 3 or curr_mask.size()[0] == 1:
                curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num)
            else:
                curr_one_hot_mask = curr_mask
            curr_id_emb = self.assign_identity(curr_one_hot_mask)

        lstt_curr_memories = self.curr_lstt_output[1]
        lstt_curr_memories_2d = []
        for layer_idx, layer_memories in enumerate(lstt_curr_memories):
            curr_k, curr_v, curr_id_k, curr_id_v = layer_memories
            curr_id_k, curr_id_v = self.AOT.LSTT.layers[
                layer_idx].fuse_key_value_id(curr_id_k, curr_id_v, curr_id_emb)
            # Write the fused ID key/value back in place: this same list is
            # what update_long_term_memory receives below.
            layer_memories[2], layer_memories[3] = curr_id_k, curr_id_v

            # The ID key may be absent; the ID value is always converted.
            local_curr_id_k = (seq_to_2d(curr_id_k, self.enc_size_2d)
                               if curr_id_k is not None else None)
            local_curr_id_v = seq_to_2d(curr_id_v, self.enc_size_2d)
            lstt_curr_memories_2d.append([
                seq_to_2d(curr_k, self.enc_size_2d),
                seq_to_2d(curr_v, self.enc_size_2d),
                local_curr_id_k,
                local_curr_id_v,
            ])

        self.short_term_memories_list.append(lstt_curr_memories_2d)
        # Keep only the newest ``short_term_mem_skip`` entries; the oldest one
        # retained becomes the active short-term memory.
        self.short_term_memories_list = self.short_term_memories_list[
            -self.short_term_mem_skip:]
        self.short_term_memories = self.short_term_memories_list[0]

        if self.frame_step - self.last_mem_step >= self.long_term_mem_gap:
            # skip the update of long-term memory or not
            if not skip_long_term_update:
                self.update_long_term_memory(lstt_curr_memories)
                self.last_mem_step = self.frame_step
class DeAOTInferEngine(AOTInferEngine):
    """``AOTInferEngine`` that drives a pool of ``DeAOTEngine`` instances.

    The only behavioral difference from the base class visible here is that
    new per-group engines are created as ``DeAOTEngine``.
    """

    def __init__(self,
                 aot_model,
                 gpu_id=0,
                 long_term_mem_gap=9999,
                 short_term_mem_skip=1,
                 max_aot_obj_num=None,
                 max_len_long_term=9999):
        """Forward every argument unchanged to ``AOTInferEngine.__init__``."""
        super().__init__(aot_model,
                         gpu_id,
                         long_term_mem_gap,
                         short_term_mem_skip,
                         max_aot_obj_num,
                         max_len_long_term)

    def add_reference_frame(self, img, mask, obj_nums, frame_step=-1):
        """Feed a reference frame to every engine in the pool.

        Enough engines are kept so that each one handles at most
        ``self.max_aot_obj_num`` objects, with at least one engine. Missing
        engines are created and switched to eval mode; existing engines are
        never removed here.

        Args:
            img: current frame, forwarded unchanged to the engines.
            mask: mask for the frame; split across engines by
                ``self.separate_mask``.
            obj_nums: object count, either a number or a list whose first
                element is used.
            frame_step: forwarded to ``add_reference_frame`` of the engines.
        """
        total_objs = obj_nums[0] if isinstance(obj_nums, list) else obj_nums
        self.obj_nums = total_objs

        required_engines = max(np.ceil(total_objs / self.max_aot_obj_num), 1)
        while len(self.aot_engines) < required_engines:
            engine = DeAOTEngine(self.AOT,
                                 self.gpu_id,
                                 self.long_term_mem_gap,
                                 self.short_term_mem_skip,
                                 max_len_long_term=self.max_len_long_term)
            engine.eval()
            self.aot_engines.append(engine)

        masks_per_engine, objs_per_engine = self.separate_mask(
            mask, total_objs)

        # Filled from the first engine and handed to the remaining ones so
        # the image embeddings are reused rather than recomputed (presumably
        # add_reference_frame skips encoding when img_embs is given).
        shared_embs = None
        for engine, engine_mask, engine_obj_num in zip(
                self.aot_engines, masks_per_engine, objs_per_engine):
            # A full reference frame is needed when the engine has no object
            # count yet or its recorded count is below the new one; otherwise
            # only its short-term memory is refreshed.
            needs_reference = (engine.obj_nums is None
                               or engine.obj_nums[0] < engine_obj_num)
            if needs_reference:
                engine.add_reference_frame(img,
                                           engine_mask,
                                           obj_nums=[engine_obj_num],
                                           frame_step=frame_step,
                                           img_embs=shared_embs)
            else:
                engine.update_short_term_memory(engine_mask)

            if shared_embs is None:
                shared_embs = engine.curr_enc_embs

        self.update_size()