from inference.memory_manager import MemoryManager
from model.network import XMem
from model.aggregate import aggregate

from tracker.util.tensor_util import pad_divide_by, unpad


class InferenceCore:
    """Frame-by-frame segmentation driver around an XMem network.

    Each call to :meth:`step` handles one frame: it encodes the frame, segments
    it by reading from memory when needed, merges any user-provided mask, and
    writes the frame into memory when it is scheduled as a memory frame.
    """

    def __init__(self, network: XMem, config):
        self.config = config
        self.network = network
        self._load_config(config)

        self.clear_memory()
        self.all_labels = None

    def _load_config(self, config):
        """Read the scheduling options used by step() from ``config``."""
        self.mem_every = config["mem_every"]
        self.deep_update_every = config["deep_update_every"]
        self.enable_long_term = config["enable_long_term"]

        # if deep_update_every < 0, synchronize deep update with memory frame
        self.deep_update_sync = self.deep_update_every < 0

    def clear_memory(self):
        """Reset the frame counters and replace the memory with an empty one."""
        self.curr_ti = -1
        self.last_mem_ti = 0
        # Always define last_deep_update_ti, even in synchronized mode where
        # step() never reads it: update_config() may switch to unsynchronized
        # deep updates without calling clear_memory(), and step() would then
        # fail with AttributeError.
        if self.deep_update_sync:
            self.last_deep_update_ti = 0
        else:
            self.last_deep_update_ti = -self.deep_update_every
        self.memory = MemoryManager(config=self.config)

    def update_config(self, config):
        """Apply new scheduling options and forward ``config`` to the memory.

        NOTE(review): ``self.config`` is not replaced here, so a later
        clear_memory() builds its MemoryManager from the config given to
        __init__ -- confirm that this is intended.
        """
        self._load_config(config)
        self.memory.update_config(config)

    def set_all_labels(self, all_labels):
        """Store the labels of all tracked objects.

        step() uses ``len(all_labels)`` as the number of objects.
        """
        self.all_labels = all_labels

    def step(self, image, mask=None, valid_labels=None, end=False):
        """Process one frame and return its segmentation.

        Args:
            image: 3*H*W frame.
            mask: num_objects*H*W input mask, or None.
            valid_labels: 1-based object labels actually provided by ``mask``;
                objects whose label is not listed keep the predicted mask.
                None means ``mask`` (if any) is used for every object.
            end: if True, the frame is not written to memory and neither the
                normal nor the deep hidden-state update is performed.

        Returns:
            ``(prob, logits)``. ``prob`` is the probability map including the
            background channel, with the padding added by ``pad_divide_by``
            removed. ``logits`` is the network's prediction (background
            included, unpadded), or None when the frame was not segmented
            (the first frame, or ``valid_labels`` has as many entries as
            ``all_labels``). When ``mask`` is given, ``logits`` still reflects
            only the network prediction, not the input mask.
            NOTE(review): if the frame is neither segmented nor given a mask,
            ``prob`` is None and ``unpad`` is called on it -- callers appear
            to always pass a mask in that case; confirm.
        """
        self.curr_ti += 1
        image, self.pad = pad_divide_by(image, 16)
        image = image.unsqueeze(0)  # add the batch dimension

        is_mem_frame = (
            (self.curr_ti - self.last_mem_ti >= self.mem_every) or (mask is not None)
        ) and (not end)
        need_segment = (self.curr_ti > 0) and (
            (valid_labels is None) or (len(self.all_labels) != len(valid_labels))
        )
        is_deep_update = (
            (self.deep_update_sync and is_mem_frame)  # synchronized
            or (
                not self.deep_update_sync
                and self.curr_ti - self.last_deep_update_ti >= self.deep_update_every
            )  # no-sync
        ) and (not end)
        is_normal_update = (not self.deep_update_sync or not is_deep_update) and (
            not end
        )

        key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(
            image,
            need_ek=(self.enable_long_term or need_segment),
            need_sk=is_mem_frame,
        )
        multi_scale_features = (f16, f8, f4)

        # segment the current frame if needed
        pred_prob_with_bg = pred_prob_no_bg = pred_logits_with_bg = None
        if need_segment:
            memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
            hidden, pred_logits_with_bg, pred_prob_with_bg = self.network.segment(
                multi_scale_features,
                memory_readout,
                self.memory.get_hidden(),
                h_out=is_normal_update,
                strip_bg=False,
            )
            # remove batch dim
            pred_prob_with_bg = pred_prob_with_bg[0]
            pred_prob_no_bg = pred_prob_with_bg[1:]
            pred_logits_with_bg = pred_logits_with_bg[0]
            if is_normal_update:
                self.memory.set_hidden(hidden)

        # use the input mask if any
        if mask is not None:
            mask, _ = pad_divide_by(mask, 16)

            if pred_prob_no_bg is not None:
                # if we have a predicted mask, we work on it:
                # make pred_prob_no_bg consistent with the input mask
                mask_regions = mask.sum(0) > 0.5
                pred_prob_no_bg[:, mask_regions] = 0
                mask = mask.type_as(pred_prob_no_bg)
                if valid_labels is not None:
                    # shift by 1 because mask/pred_prob_no_bg do not contain
                    # the background channel
                    shift_by_one_non_labels = [
                        i
                        for i in range(pred_prob_no_bg.shape[0])
                        if (i + 1) not in valid_labels
                    ]
                    # non-labelled objects are copied from the predicted mask
                    mask[shift_by_one_non_labels] = pred_prob_no_bg[
                        shift_by_one_non_labels
                    ]
            pred_prob_with_bg = aggregate(mask, dim=0)

            # also create new hidden states
            self.memory.create_hidden_state(len(self.all_labels), key)

        # save as memory if needed
        if is_mem_frame:
            value, hidden = self.network.encode_value(
                image,
                f16,
                self.memory.get_hidden(),
                pred_prob_with_bg[1:].unsqueeze(0),
                is_deep_update=is_deep_update,
            )
            self.memory.add_memory(
                key,
                shrinkage,
                value,
                self.all_labels,
                selection=selection if self.enable_long_term else None,
            )
            self.last_mem_ti = self.curr_ti

            if is_deep_update:
                self.memory.set_hidden(hidden)
                self.last_deep_update_ti = self.curr_ti

        prob = unpad(pred_prob_with_bg, self.pad)
        logits = (
            None if pred_logits_with_bg is None else unpad(pred_logits_with_bg, self.pad)
        )
        return prob, logits