# Synced to HuggingFace Spaces by github-actions[bot] (commit 123489f).
"""
This file defines XMem, the highest level nn.Module interface
During training, it is used by trainer.py
During evaluation, it is used by inference_core.py
It further depends on modules.py which gives more detailed implementations of sub-modules
"""
import torch
import torch.nn as nn
from model.aggregate import aggregate
from model.modules import *
from model.memory_util import *
class XMem(nn.Module):
    """Top-level XMem network.

    Owns the key encoder, value encoder, key projection and decoder. Each
    step is reached through ``forward(mode, ...)`` so a single module object
    can serve every stage of training and inference.
    """

    def __init__(self, config, model_path=None, map_location=None):
        """
        Args:
            config: dict-like configuration. ``key_dim`` / ``value_dim`` /
                ``hidden_dim`` are read from it (or defaulted) and written back
                in place. ``single_object`` (default False) selects
                single-object mode.
            model_path: optional checkpoint path (used in evaluation only).
                When given, the dimensions are inferred from the stored
                weights, and the weights are loaded into the model.
            map_location: forwarded to ``torch.load``. For example, it converts
                models saved on CUDA so they load on the CPU.
        """
        super().__init__()
        model_weights = self.init_hyperparameters(config, model_path, map_location)

        self.single_object = config.get("single_object", False)
        print(f"Single object mode: {self.single_object}")

        self.key_encoder = KeyEncoder()
        self.value_encoder = ValueEncoder(
            self.value_dim, self.hidden_dim, self.single_object
        )

        # Projection from f16 feature space to key/value space
        self.key_proj = KeyProjection(1024, self.key_dim)

        self.decoder = Decoder(self.value_dim, self.hidden_dim)

        if model_weights is not None:
            self.load_weights(model_weights, init_as_zero_if_needed=True)

    def encode_key(self, frame, need_sk=True, need_ek=True):
        """Encode frame(s) into a key plus multi-scale image features.

        Args:
            frame: a ``B*C*H*W`` batch of frames, or a ``B*T*C*H*W`` batch of
                clips. For clips, time is folded into the batch so the 2D CNN
                can process them.
            need_sk, need_ek: forwarded unchanged to ``self.key_proj``.
                Presumably they request the shrinkage and selection terms
                (TODO confirm in modules.py). Either term may come back as
                ``None``.

        Returns:
            ``(key, shrinkage, selection, f16, f8, f4)``. For 5D input,
            ``key`` / ``shrinkage`` / ``selection`` are reshaped to
            ``B*C*T*H*W`` and ``f16`` / ``f8`` / ``f4`` to ``B*T*C*H*W``.
            For 4D input, everything is returned as produced.

        Raises:
            NotImplementedError: if ``frame`` is neither 4D nor 5D.
        """
        ndim = len(frame.shape)
        if ndim == 5:
            need_reshape = True
            b, t = frame.shape[:2]
            # Flatten B and T so the frames can go through a 2D CNN.
            frame = frame.flatten(start_dim=0, end_dim=1)
        elif ndim == 4:
            need_reshape = False
        else:
            raise NotImplementedError(
                "encode_key expects a 4D (B*C*H*W) or 5D (B*T*C*H*W) frame "
                f"tensor, got shape {tuple(frame.shape)}"
            )

        f16, f8, f4 = self.key_encoder(frame)
        key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek)

        if need_reshape:

            def to_bcthw(x):
                # (B*T)*C*H*W -> B*C*T*H*W
                return x.view(b, t, *x.shape[-3:]).transpose(1, 2).contiguous()

            key = to_bcthw(key)
            if shrinkage is not None:
                shrinkage = to_bcthw(shrinkage)
            if selection is not None:
                selection = to_bcthw(selection)

            # (B*T)*C*H*W -> B*T*C*H*W
            f16 = f16.view(b, t, *f16.shape[-3:])
            f8 = f8.view(b, t, *f8.shape[-3:])
            f4 = f4.view(b, t, *f4.shape[-3:])

        return key, shrinkage, selection, f16, f8, f4

    def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True):
        """Encode object masks into values via ``self.value_encoder``.

        Args:
            frame, image_feat_f16, h16, is_deep_update: forwarded unchanged to
                the value encoder.
            masks: tensor whose dim 1 indexes objects.

        Returns:
            ``(g16, h16)`` as returned by the value encoder.
        """
        num_objects = masks.shape[1]
        if num_objects != 1:
            # For each object i, "others" is the sum of every other object's
            # mask, kept as a channel at position i along dim 1.
            others = torch.cat(
                [
                    torch.sum(
                        masks[:, [j for j in range(num_objects) if i != j]],
                        dim=1,
                        keepdim=True,
                    )
                    for i in range(num_objects)
                ],
                1,
            )
        else:
            # A lone object has no "others".
            others = torch.zeros_like(masks)

        g16, h16 = self.value_encoder(
            frame, image_feat_f16, h16, masks, others, is_deep_update
        )
        return g16, h16

    # Used in training only.
    # This step is replaced by MemoryManager in test time
    def read_memory(
        self, query_key, query_selection, memory_key, memory_shrinkage, memory_value
    ):
        """Read from memory using the affinity between memory and query keys.

        query_key       : B * CK * H * W
        query_selection : B * CK * H * W
        memory_key      : B * CK * T * H * W
        memory_shrinkage: B * 1 * T * H * W
        memory_value    : B * num_objects * CV * T * H * W

        Returns:
            The readout, viewed as ``B * num_objects * value_dim * H' * W'``.
            ``H'`` and ``W'`` are the last two dims produced by ``readout``.
        """
        batch_size, num_objects = memory_value.shape[:2]
        # Merge the object and channel dims so readout sees one value tensor.
        memory_value = memory_value.flatten(start_dim=1, end_dim=2)

        affinity = get_affinity(
            memory_key, memory_shrinkage, query_key, query_selection
        )
        memory = readout(affinity, memory_value)
        memory = memory.view(
            batch_size, num_objects, self.value_dim, *memory.shape[-2:]
        )
        return memory

    def segment(
        self,
        multi_scale_features,
        memory_readout,
        hidden_state,
        selector=None,
        h_out=True,
        strip_bg=True,
    ):
        """Decode the memory readout into per-object logits and probabilities.

        Args:
            multi_scale_features: sequence unpacked as the decoder's leading
                positional arguments.
            memory_readout: memory readout passed to the decoder.
            hidden_state: hidden state passed to the decoder (and updated by it).
            selector: optional tensor multiplied into the sigmoid probabilities
                before aggregation.
            h_out: forwarded to the decoder.
            strip_bg: if True, drop channel 0 (the background) from the
                returned ``prob``.

        Returns:
            ``(hidden_state, logits, prob)``. ``logits`` and ``prob`` come from
            ``aggregate``, which is applied along dim 1.
        """
        hidden_state, logits = self.decoder(
            *multi_scale_features, hidden_state, memory_readout, h_out=h_out
        )
        prob = torch.sigmoid(logits)
        if selector is not None:
            prob = prob * selector

        logits, prob = aggregate(prob, dim=1, return_logits=True)
        if strip_bg:
            # Strip away the background
            prob = prob[:, 1:]

        return hidden_state, logits, prob

    def forward(self, mode, *args, **kwargs):
        """Dispatch to one of the sub-steps by name.

        ``mode`` is one of ``"encode_key"``, ``"encode_value"``,
        ``"read_memory"`` or ``"segment"``. The remaining arguments are
        forwarded to the matching method.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        if mode == "encode_key":
            return self.encode_key(*args, **kwargs)
        elif mode == "encode_value":
            return self.encode_value(*args, **kwargs)
        elif mode == "read_memory":
            return self.read_memory(*args, **kwargs)
        elif mode == "segment":
            return self.segment(*args, **kwargs)
        else:
            raise NotImplementedError(f"Unknown XMem forward mode: {mode!r}")

    def init_hyperparameters(self, config, model_path=None, map_location=None):
        """Set ``key_dim``, ``value_dim``, ``hidden_dim`` and ``disable_hidden``.

        If ``model_path`` is provided, the dimensions are read from the shapes
        of the checkpoint's weights. Otherwise they come from ``config``, or
        fall back to the defaults 64 / 512 / 64. In both cases the chosen
        dimensions are written back into ``config`` in place.

        Returns:
            The loaded checkpoint dict when ``model_path`` is given, else ``None``.
        """
        if model_path is not None:
            # Load the model and infer key/value/hidden dimensions from weight
            # shapes. Config is updated with the loaded parameters below.
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints.
            model_weights = torch.load(model_path, map_location=map_location)
            self.key_dim = model_weights["key_proj.key_proj.weight"].shape[0]
            self.value_dim = model_weights[
                "value_encoder.fuser.block2.conv2.weight"
            ].shape[0]
            self.disable_hidden = (
                "decoder.hidden_update.transform.weight" not in model_weights
            )
            if self.disable_hidden:
                self.hidden_dim = 0
            else:
                # C^h is the transform's output channel count divided by 3
                # (presumably three gate maps per hidden channel -- TODO
                # confirm in modules.py).
                self.hidden_dim = (
                    model_weights["decoder.hidden_update.transform.weight"].shape[0]
                    // 3
                )
            print(
                f"Hyperparameters read from the model weights: "
                f"C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}"
            )
        else:
            model_weights = None
            # Load dimensions from config or use the defaults.
            if "key_dim" not in config:
                self.key_dim = 64
                print(f"key_dim not found in config. Set to default {self.key_dim}")
            else:
                self.key_dim = config["key_dim"]

            if "value_dim" not in config:
                self.value_dim = 512
                print(f"value_dim not found in config. Set to default {self.value_dim}")
            else:
                self.value_dim = config["value_dim"]

            if "hidden_dim" not in config:
                self.hidden_dim = 64
                print(
                    f"hidden_dim not found in config. Set to default {self.hidden_dim}"
                )
            else:
                self.hidden_dim = config["hidden_dim"]

            self.disable_hidden = self.hidden_dim <= 0

        config["key_dim"] = self.key_dim
        config["value_dim"] = self.value_dim
        config["hidden_dim"] = self.hidden_dim

        return model_weights

    def load_weights(self, src_dict, init_as_zero_if_needed=False):
        """Load ``src_dict`` into this model, upgrading single-object checkpoints.

        Single-object (SO) weights lack one input channel in
        ``value_encoder.conv1.weight`` compared with multi-object (MO)
        weights, which include the "other objects" mask. When this model's
        conv1 expects exactly one more input channel than the checkpoint
        provides, an extra channel is appended. It is zero-initialized if
        ``init_as_zero_if_needed`` is True, otherwise orthogonally initialized.
        A checkpoint whose shape already matches the model (e.g. SO weights
        into an SO model) is loaded untouched.

        ``src_dict`` is modified in place. It is not copied, so any
        ``_metadata`` it carries still reaches ``load_state_dict``.
        """
        key = "value_encoder.conv1.weight"
        weight = src_dict.get(key)
        target = self.state_dict().get(key)
        if (
            weight is not None
            and target is not None
            and weight.shape[1] + 1 == target.shape[1]
        ):
            print("Converting weights from single object to multiple objects.")
            # One new input channel with the same out-channels and kernel size
            # as the checkpoint weight (torch.cat on dim 1 requires this).
            pads = torch.zeros(
                (weight.shape[0], 1, *weight.shape[2:]), device=weight.device
            )
            if not init_as_zero_if_needed:
                print("Randomly initialized padding.")
                nn.init.orthogonal_(pads)
            else:
                print("Zero-initialized padding.")
            # Initialize in float32, then match the checkpoint's dtype.
            src_dict[key] = torch.cat([weight, pads.to(weight.dtype)], 1)

        self.load_state_dict(src_dict)