""" This file defines XMem, the highest level nn.Module interface During training, it is used by trainer.py During evaluation, it is used by inference_core.py It further depends on modules.py which gives more detailed implementations of sub-modules """ import torch import torch.nn as nn from model.aggregate import aggregate from model.modules import * from model.memory_util import * class XMem(nn.Module): def __init__(self, config, model_path=None, map_location=None): """ model_path/map_location are used in evaluation only map_location is for converting models saved in cuda to cpu """ super().__init__() model_weights = self.init_hyperparameters(config, model_path, map_location) self.single_object = config.get("single_object", False) print(f"Single object mode: {self.single_object}") self.key_encoder = KeyEncoder() self.value_encoder = ValueEncoder( self.value_dim, self.hidden_dim, self.single_object ) # Projection from f16 feature space to key/value space self.key_proj = KeyProjection(1024, self.key_dim) self.decoder = Decoder(self.value_dim, self.hidden_dim) if model_weights is not None: self.load_weights(model_weights, init_as_zero_if_needed=True) def encode_key(self, frame, need_sk=True, need_ek=True): # Determine input shape if len(frame.shape) == 5: # shape is b*t*c*h*w need_reshape = True b, t = frame.shape[:2] # flatten so that we can feed them into a 2D CNN frame = frame.flatten(start_dim=0, end_dim=1) elif len(frame.shape) == 4: # shape is b*c*h*w need_reshape = False else: raise NotImplementedError f16, f8, f4 = self.key_encoder(frame) key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek) if need_reshape: # B*C*T*H*W key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous() if shrinkage is not None: shrinkage = ( shrinkage.view(b, t, *shrinkage.shape[-3:]) .transpose(1, 2) .contiguous() ) if selection is not None: selection = ( selection.view(b, t, *selection.shape[-3:]) .transpose(1, 2) .contiguous() ) # B*T*C*H*W f16 = f16.view(b, t, *f16.shape[-3:]) f8 = f8.view(b, t, *f8.shape[-3:]) f4 = f4.view(b, t, *f4.shape[-3:]) return key, shrinkage, selection, f16, f8, f4 def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True): num_objects = masks.shape[1] if num_objects != 1: others = torch.cat( [ torch.sum( masks[:, [j for j in range(num_objects) if i != j]], dim=1, keepdim=True, ) for i in range(num_objects) ], 1, ) else: others = torch.zeros_like(masks) g16, h16 = self.value_encoder( frame, image_feat_f16, h16, masks, others, is_deep_update ) return g16, h16 # Used in training only. 
    # Used in training only.
    # This step is replaced by MemoryManager at test time.
    def read_memory(
        self, query_key, query_selection, memory_key, memory_shrinkage, memory_value
    ):
        """
        query_key       : B * CK * H * W
        query_selection : B * CK * H * W
        memory_key      : B * CK * T * H * W
        memory_shrinkage: B * 1  * T * H * W
        memory_value    : B * num_objects * CV * T * H * W
        """
        batch_size, num_objects = memory_value.shape[:2]
        memory_value = memory_value.flatten(start_dim=1, end_dim=2)

        affinity = get_affinity(
            memory_key, memory_shrinkage, query_key, query_selection
        )
        memory = readout(affinity, memory_value)
        memory = memory.view(
            batch_size, num_objects, self.value_dim, *memory.shape[-2:]
        )

        return memory

    def segment(
        self,
        multi_scale_features,
        memory_readout,
        hidden_state,
        selector=None,
        h_out=True,
        strip_bg=True,
    ):
        hidden_state, logits = self.decoder(
            *multi_scale_features, hidden_state, memory_readout, h_out=h_out
        )
        prob = torch.sigmoid(logits)
        if selector is not None:
            prob = prob * selector

        # aggregate() inserts a background channel at index 0 and renormalizes
        logits, prob = aggregate(prob, dim=1, return_logits=True)
        if strip_bg:
            # Strip away the background
            prob = prob[:, 1:]

        return hidden_state, logits, prob

    def forward(self, mode, *args, **kwargs):
        if mode == "encode_key":
            return self.encode_key(*args, **kwargs)
        elif mode == "encode_value":
            return self.encode_value(*args, **kwargs)
        elif mode == "read_memory":
            return self.read_memory(*args, **kwargs)
        elif mode == "segment":
            return self.segment(*args, **kwargs)
        else:
            raise NotImplementedError
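    # Hedged usage sketch of the string-dispatched forward() above, roughly as
    # it is driven from trainer.py (tensor names here are illustrative
    # assumptions, not the exact trainer code):
    #   key, shrinkage, selection, f16, f8, f4 = net('encode_key', frames)
    #   value, hidden = net('encode_value', frame, f16_t, hidden, masks)
    #   mem_readout = net('read_memory', q_key, q_sel, mem_key, mem_shrink, mem_value)
    #   hidden, logits, prob = net('segment', (f16_t, f8_t, f4_t), mem_readout, hidden)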
    def init_hyperparameters(self, config, model_path=None, map_location=None):
        """
        Init three hyperparameters: key_dim, value_dim, and hidden_dim.
        If model_path is provided, these are read from the model weights and
        written back to the config in-place.
        Otherwise they are read from the config, falling back to defaults.
        """
        if model_path is not None:
            # load the model and key/value/hidden dimensions with some hacks
            # config is updated with the loaded parameters
            model_weights = torch.load(model_path, map_location=map_location)
            self.key_dim = model_weights["key_proj.key_proj.weight"].shape[0]
            self.value_dim = model_weights[
                "value_encoder.fuser.block2.conv2.weight"
            ].shape[0]
            self.disable_hidden = (
                "decoder.hidden_update.transform.weight" not in model_weights
            )
            if self.disable_hidden:
                self.hidden_dim = 0
            else:
                self.hidden_dim = (
                    model_weights["decoder.hidden_update.transform.weight"].shape[0]
                    // 3
                )
            print(
                f"Hyperparameters read from the model weights: "
                f"C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}"
            )
        else:
            model_weights = None
            # load dimensions from config or default
            if "key_dim" not in config:
                self.key_dim = 64
                print(f"key_dim not found in config. Set to default {self.key_dim}")
            else:
                self.key_dim = config["key_dim"]

            if "value_dim" not in config:
                self.value_dim = 512
                print(f"value_dim not found in config. Set to default {self.value_dim}")
            else:
                self.value_dim = config["value_dim"]

            if "hidden_dim" not in config:
                self.hidden_dim = 64
                print(f"hidden_dim not found in config. Set to default {self.hidden_dim}")
            else:
                self.hidden_dim = config["hidden_dim"]

            self.disable_hidden = self.hidden_dim <= 0

        config["key_dim"] = self.key_dim
        config["value_dim"] = self.value_dim
        config["hidden_dim"] = self.hidden_dim

        return model_weights

    def load_weights(self, src_dict, init_as_zero_if_needed=False):
        # Maps SO weight (without other_mask) to MO weight (with other_mask):
        # the 4-channel (RGB + mask) first conv is padded with a fifth
        # channel for the "other objects" mask
        for k in list(src_dict.keys()):
            if k == "value_encoder.conv1.weight":
                if src_dict[k].shape[1] == 4:
                    print("Converting weights from single object to multiple objects.")
                    pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
                    if not init_as_zero_if_needed:
                        print("Randomly initialized padding.")
                        nn.init.orthogonal_(pads)
                    else:
                        print("Zero-initialized padding.")
                    src_dict[k] = torch.cat([src_dict[k], pads], 1)

        self.load_state_dict(src_dict)
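# A minimal smoke-test sketch (illustrative only: it assumes the sub-modules
# from modules.py are importable, building the encoders may download
# pretrained backbone weights, and the 384x384 resolution is an arbitrary
# choice):
if __name__ == "__main__":
    net = XMem({})  # empty config: defaults are chosen and written back
    frames = torch.zeros(1, 3, 3, 384, 384)  # b*t*c*h*w
    with torch.no_grad():
        key, shrinkage, selection, f16, f8, f4 = net("encode_key", frames)
    # With the default C^k=64 and a stride-16 key encoder, this is expected
    # to be torch.Size([1, 64, 3, 24, 24]), i.e. B*C*T*H*W.
    print(key.shape)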