"""
This file defines XMem, the highest level nn.Module interface
During training, it is used by trainer.py
During evaluation, it is used by inference_core.py
It further depends on modules.py which gives more detailed implementations of sub-modules
"""
import torch
import torch.nn as nn
from model.aggregate import aggregate
from model.modules import *
from model.memory_util import *

class XMem(nn.Module):
    def __init__(self, config, model_path=None, map_location=None):
        """
        model_path/map_location are used in evaluation only.
        map_location is for converting models saved with CUDA tensors to CPU.
        """
        super().__init__()
        model_weights = self.init_hyperparameters(config, model_path, map_location)

        self.single_object = config.get('single_object', False)
        print(f'Single object mode: {self.single_object}')

        self.key_encoder = KeyEncoder()
        self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object)

        # Projection from f16 feature space to key/value space
        self.key_proj = KeyProjection(1024, self.key_dim)

        self.decoder = Decoder(self.value_dim, self.hidden_dim)

        if model_weights is not None:
            self.load_weights(model_weights, init_as_zero_if_needed=True)
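    # Construction sketch (illustrative; the config keys mirror init_hyperparameters below,
    # and the checkpoint path is hypothetical):
    #   net = XMem({'key_dim': 64, 'value_dim': 512, 'hidden_dim': 64})        # training, dims from config
    #   net = XMem(config, model_path='saves/XMem.pth', map_location='cpu')    # evaluation, dims from weights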

    def encode_key(self, frame, need_sk=True, need_ek=True):
        # Determine input shape
        if len(frame.shape) == 5:
            # shape is b*t*c*h*w
            need_reshape = True
            b, t = frame.shape[:2]
            # flatten so that we can feed them into a 2D CNN
            frame = frame.flatten(start_dim=0, end_dim=1)
        elif len(frame.shape) == 4:
            # shape is b*c*h*w
            need_reshape = False
        else:
            raise NotImplementedError

        f16, f8, f4 = self.key_encoder(frame)
        key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek)

        if need_reshape:
            # B*C*T*H*W
            key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous()
            if shrinkage is not None:
                shrinkage = shrinkage.view(b, t, *shrinkage.shape[-3:]).transpose(1, 2).contiguous()
            if selection is not None:
                selection = selection.view(b, t, *selection.shape[-3:]).transpose(1, 2).contiguous()

            # B*T*C*H*W
            f16 = f16.view(b, t, *f16.shape[-3:])
            f8 = f8.view(b, t, *f8.shape[-3:])
            f4 = f4.view(b, t, *f4.shape[-3:])

        return key, shrinkage, selection, f16, f8, f4
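    # Shape sketch for encode_key (illustrative, assuming the default ResNet-50 key encoder
    # and a 5-D input of shape B*T*3*H*W):
    #   key       -> B * C^k * T * H/16 * W/16
    #   shrinkage -> B * 1   * T * H/16 * W/16  (None if need_sk is False)
    #   selection -> B * C^k * T * H/16 * W/16  (None if need_ek is False)
    #   f16, f8, f4 -> B*T*1024*H/16*W/16, B*T*512*H/8*W/8, B*T*256*H/4*W/4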

    def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True):
        num_objects = masks.shape[1]
        if num_objects != 1:
            others = torch.cat([
                torch.sum(
                    masks[:, [j for j in range(num_objects) if i != j]]
                , dim=1, keepdim=True)
            for i in range(num_objects)], 1)
        else:
            others = torch.zeros_like(masks)

        g16, h16 = self.value_encoder(frame, image_feat_f16, h16, masks, others, is_deep_update)

        return g16, h16
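    # 'others' gives each object the summed masks of every *other* object, which the value
    # encoder consumes as an extra input channel. E.g. with two objects (illustrative case),
    # others[:, 0] == masks[:, 1] and others[:, 1] == masks[:, 0]; with a single object it is all zeros.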

    # Used in training only.
    # This step is replaced by MemoryManager at test time.
    def read_memory(self, query_key, query_selection, memory_key,
                    memory_shrinkage, memory_value):
        """
        query_key       : B * CK * H * W
        query_selection : B * CK * H * W
        memory_key      : B * CK * T * H * W
        memory_shrinkage: B * 1  * T * H * W
        memory_value    : B * num_objects * CV * T * H * W
        """
        batch_size, num_objects = memory_value.shape[:2]
        memory_value = memory_value.flatten(start_dim=1, end_dim=2)

        affinity = get_affinity(memory_key, memory_shrinkage, query_key, query_selection)
        memory = readout(affinity, memory_value)
        memory = memory.view(batch_size, num_objects, self.value_dim, *memory.shape[-2:])

        return memory
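    # Readout sketch (illustrative): get_affinity produces a softmax-normalised similarity between
    # every query position and every memory position, and readout uses it as weights for a sum over
    # the memory values, so the result is a B * num_objects * C^v * H * W tensor at query resolution.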

    def segment(self, multi_scale_features, memory_readout,
                hidden_state, selector=None, h_out=True, strip_bg=True):
        hidden_state, logits = self.decoder(*multi_scale_features, hidden_state, memory_readout, h_out=h_out)
        prob = torch.sigmoid(logits)
        if selector is not None:
            prob = prob * selector

        logits, prob = aggregate(prob, dim=1, return_logits=True)
        if strip_bg:
            # Strip away the background
            prob = prob[:, 1:]

        return hidden_state, logits, prob
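    # aggregate() soft-aggregates the per-object probabilities with an implicit background class,
    # so the returned logits have num_objects + 1 channels along dim 1; with strip_bg=True only the
    # num_objects foreground channels of prob are returned.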

    def forward(self, mode, *args, **kwargs):
        if mode == 'encode_key':
            return self.encode_key(*args, **kwargs)
        elif mode == 'encode_value':
            return self.encode_value(*args, **kwargs)
        elif mode == 'read_memory':
            return self.read_memory(*args, **kwargs)
        elif mode == 'segment':
            return self.segment(*args, **kwargs)
        else:
            raise NotImplementedError
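    # Illustrative call pattern through the mode dispatcher (variable names are hypothetical);
    # routing every stage through forward() lets wrappers such as DistributedDataParallel cover
    # all four operations with a single entry point:
    #   key, shrinkage, selection, f16, f8, f4 = model('encode_key', frames)
    #   v16, hidden = model('encode_value', frame, f16_t, hidden, masks)
    #   readout = model('read_memory', key_q, sel_q, key_mem, shrink_mem, value_mem)
    #   hidden, logits, prob = model('segment', (f16_t, f8_t, f4_t), readout, hidden)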

    def init_hyperparameters(self, config, model_path=None, map_location=None):
        """
        Init three hyperparameters: key_dim, value_dim, and hidden_dim.
        If model_path is provided, we load these from the model weights;
        the loaded values are then written back into the config in-place.
        Otherwise, we read them from the config or fall back to the defaults.
        """
        if model_path is not None:
            # load the model and key/value/hidden dimensions with some hacks
            # config is updated with the loaded parameters
            model_weights = torch.load(model_path, map_location=map_location)
            self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0]
            self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0]
            self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights
            if self.disable_hidden:
                self.hidden_dim = 0
            else:
                self.hidden_dim = model_weights['decoder.hidden_update.transform.weight'].shape[0]//3
            print(f'Hyperparameters read from the model weights: '
                  f'C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}')
        else:
            model_weights = None
            # load dimensions from config or default
            if 'key_dim' not in config:
                self.key_dim = 64
                print(f'key_dim not found in config. Set to default {self.key_dim}')
            else:
                self.key_dim = config['key_dim']

            if 'value_dim' not in config:
                self.value_dim = 512
                print(f'value_dim not found in config. Set to default {self.value_dim}')
            else:
                self.value_dim = config['value_dim']

            if 'hidden_dim' not in config:
                self.hidden_dim = 64
                print(f'hidden_dim not found in config. Set to default {self.hidden_dim}')
            else:
                self.hidden_dim = config['hidden_dim']

            self.disable_hidden = (self.hidden_dim <= 0)

        config['key_dim'] = self.key_dim
        config['value_dim'] = self.value_dim
        config['hidden_dim'] = self.hidden_dim

        return model_weights
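    # Precedence sketch (illustrative): a checkpoint always wins (dims are read from tensor shapes),
    # then the config entry, then the defaults C^k=64, C^v=512, C^h=64. E.g. XMem({'hidden_dim': 0})
    # disables the hidden state (disable_hidden=True) while keeping the default key/value dims.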

    def load_weights(self, src_dict, init_as_zero_if_needed=False):
        # Map single-object (SO) weights, which lack the other_mask input channel,
        # to multi-object (MO) weights, which expect it.
        for k in list(src_dict.keys()):
            if k == 'value_encoder.conv1.weight':
                if src_dict[k].shape[1] == 4:
                    print('Converting weights from single object to multiple objects.')
                    pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
                    if not init_as_zero_if_needed:
                        print('Randomly initialized padding.')
                        nn.init.orthogonal_(pads)
                    else:
                        print('Zero-initialized padding.')
                    src_dict[k] = torch.cat([src_dict[k], pads], 1)

        self.load_state_dict(src_dict)
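
# Minimal construction smoke test (an illustrative sketch; the empty config simply relies on the
# defaults above, which init_hyperparameters writes back into the dict in place).
if __name__ == '__main__':
    cfg = {}
    net = XMem(cfg)
    print(cfg)  # expected: {'key_dim': 64, 'value_dim': 512, 'hidden_dim': 64}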