Spaces:
Sleeping
Sleeping
File size: 5,648 Bytes
123489f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from inference.memory_manager import MemoryManager
from model.network import XMem
from model.aggregate import aggregate
from tracker.util.tensor_util import pad_divide_by, unpad
class InferenceCore:
    """Drive a network over a video one frame at a time with a memory bank.

    For every frame passed to :meth:`step` the object decides whether to
    segment the frame from memory, whether to write the frame into memory,
    and whether the hidden state gets a normal or a deep update.
    """

    def __init__(self, network: XMem, config):
        """Store the network/config, read the schedule and start with empty memory.

        Args:
            network: model providing ``encode_key``, ``segment`` and
                ``encode_value``.
            config: mapping with at least ``"mem_every"``,
                ``"deep_update_every"`` and ``"enable_long_term"``; it is also
                handed to ``MemoryManager``.
        """
        self.config = config
        self.network = network
        self._load_schedule(config)

        self.clear_memory()
        self.all_labels = None

    def _load_schedule(self, config):
        """Copy the frame-scheduling options out of *config* onto ``self``."""
        self.mem_every = config["mem_every"]
        self.deep_update_every = config["deep_update_every"]
        self.enable_long_term = config["enable_long_term"]

        # if deep_update_every < 0, synchronize deep update with memory frame
        self.deep_update_sync = self.deep_update_every < 0

    def clear_memory(self):
        """Reset the frame counters and replace the memory with a fresh one.

        ``all_labels`` is not touched here. The new ``MemoryManager`` is built
        from ``self.config``, i.e. the config given to ``__init__``.
        """
        self.curr_ti = -1
        self.last_mem_ti = 0
        if not self.deep_update_sync:
            # Chosen so that the very first frame already satisfies
            # curr_ti - last_deep_update_ti >= deep_update_every.
            self.last_deep_update_ti = -self.deep_update_every
        self.memory = MemoryManager(config=self.config)

    def update_config(self, config):
        """Re-read the scheduling options and forward *config* to the memory.

        ``self.config`` itself is not replaced, so a later ``clear_memory``
        still builds its ``MemoryManager`` from the original config.
        """
        self._load_schedule(config)

        # clear_memory() only creates last_deep_update_ti in non-sync mode.
        # When the new config switches sync -> non-sync before any deep update
        # has happened, step() would otherwise fail with AttributeError; seed
        # the counter the same way clear_memory() does.
        if not self.deep_update_sync and not hasattr(self, "last_deep_update_ti"):
            self.last_deep_update_ti = -self.deep_update_every

        self.memory.update_config(config)

    def set_all_labels(self, all_labels):
        """Record the full collection of object labels being tracked."""
        # self.all_labels = [l.item() for l in all_labels]
        self.all_labels = all_labels

    def step(self, image, mask=None, valid_labels=None, end=False):
        """Process one frame and return its probabilities and logits.

        Args:
            image: frame tensor, 3*H*W.
            mask: num_objects*H*W tensor for this frame, or None. A given mask
                forces the frame to be a memory frame (unless *end* is set).
            valid_labels: optional collection of 1-based object ids that
                *mask* actually annotates. Channels whose ``index + 1`` is not
                listed are filled from the prediction when one exists.
            end: when True the frame is not written to memory and the
                normal/deep hidden-state updates are skipped.

        Returns:
            A 2-tuple ``(prob, logits)``. ``prob`` is ``pred_prob_with_bg``
            after ``unpad``; when *mask* is given it comes from
            ``aggregate(mask, dim=0)`` rather than directly from the network.
            ``logits`` is the unpadded network logits, or None when the frame
            was not segmented.
        """
        # image: 3*H*W
        # mask: num_objects*H*W or None
        self.curr_ti += 1
        # presumably pads H and W up to a multiple of 16 -- confirm against
        # pad_divide_by; the returned pad spec is reused by unpad() below.
        image, self.pad = pad_divide_by(image, 16)
        image = image.unsqueeze(0)  # add the batch dimension

        # Write to memory every `mem_every` frames or whenever a mask is given.
        is_mem_frame = (
            (self.curr_ti - self.last_mem_ti >= self.mem_every) or (mask is not None)
        ) and (not end)
        # The first frame is never segmented; later frames are segmented unless
        # the caller supplies a mask covering every known label.
        need_segment = (self.curr_ti > 0) and (
            (valid_labels is None) or (len(self.all_labels) != len(valid_labels))
        )
        is_deep_update = (
            (self.deep_update_sync and is_mem_frame)  # synchronized
            or (
                not self.deep_update_sync
                and self.curr_ti - self.last_deep_update_ti >= self.deep_update_every
            )  # no-sync
        ) and (not end)
        # In sync mode a deep update replaces the normal one; in non-sync mode
        # the normal update always runs.
        is_normal_update = (not self.deep_update_sync or not is_deep_update) and (
            not end
        )

        key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(
            image,
            need_ek=(self.enable_long_term or need_segment),
            need_sk=is_mem_frame,
        )
        multi_scale_features = (f16, f8, f4)

        # segment the current frame if needed
        if need_segment:
            memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
            hidden, pred_logits_with_bg, pred_prob_with_bg = self.network.segment(
                multi_scale_features,
                memory_readout,
                self.memory.get_hidden(),
                h_out=is_normal_update,
                strip_bg=False,
            )
            # remove batch dim
            pred_prob_with_bg = pred_prob_with_bg[0]
            pred_prob_no_bg = pred_prob_with_bg[1:]
            pred_logits_with_bg = pred_logits_with_bg[0]
            if is_normal_update:
                self.memory.set_hidden(hidden)
        else:
            pred_prob_no_bg = pred_prob_with_bg = pred_logits_with_bg = None

        # use the input mask if any
        if mask is not None:
            mask, _ = pad_divide_by(mask, 16)

            if pred_prob_no_bg is not None:
                # if we have a predicted mask, we work on it
                # make pred_prob_no_bg consistent with the input mask
                mask_regions = mask.sum(0) > 0.5
                pred_prob_no_bg[:, mask_regions] = 0
                # shift by 1 because mask/pred_prob_no_bg do not contain background
                mask = mask.type_as(pred_prob_no_bg)
                if valid_labels is not None:
                    shift_by_one_non_labels = [
                        i
                        for i in range(pred_prob_no_bg.shape[0])
                        if (i + 1) not in valid_labels
                    ]
                    # non-labelled objects are copied from the predicted mask
                    mask[shift_by_one_non_labels] = pred_prob_no_bg[
                        shift_by_one_non_labels
                    ]
            pred_prob_with_bg = aggregate(mask, dim=0)

            # also create new hidden states
            self.memory.create_hidden_state(len(self.all_labels), key)

        # save as memory if needed
        if is_mem_frame:
            value, hidden = self.network.encode_value(
                image,
                f16,
                self.memory.get_hidden(),
                pred_prob_with_bg[1:].unsqueeze(0),
                is_deep_update=is_deep_update,
            )
            self.memory.add_memory(
                key,
                shrinkage,
                value,
                self.all_labels,
                selection=selection if self.enable_long_term else None,
            )
            self.last_mem_ti = self.curr_ti

            if is_deep_update:
                self.memory.set_hidden(hidden)
                self.last_deep_update_ti = self.curr_ti

        # NOTE(review): if the frame was neither segmented nor given a mask
        # (e.g. the first frame without a mask), pred_prob_with_bg is None here
        # and is passed to unpad() as-is -- confirm callers never do this.
        if pred_logits_with_bg is None:
            return unpad(pred_prob_with_bg, self.pad), None
        return unpad(pred_prob_with_bg, self.pad), unpad(
            pred_logits_with_bg, self.pad
        )
|