# Mugs / src/model.py
# Uploaded by zhoupans ("Upload 13 files"), commit 3c849be, 20.8 kB
# (hosting-page labels: raw | history | blame | No virus)
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
models and functions for building student and teacher networks for multi-granular losses.
"""
import torch
import torch.nn as nn
import src.vision_transformer as vits
from src.vision_transformer import trunc_normal_
class Instance_Superivsion_Head(nn.Module):
    """
    a class to implement Instance Superivsion Head: a projection MLP, optionally
    followed by a prediction MLP (student only)
    --in_dim: input dimension of projection head
    --hidden_dim: hidden dimension of projection head
    --out_dim: output dimension of projection and prediction heads
    --pred_hidden_dim: hidden dimension of prediction head; a value <= 0 builds
        no prediction head (used for the teacher)
    --nlayers: number of linear layers in the projection head (clamped to >= 1);
        the prediction head has nlayers - 1 layers (0 layers -> identity)
    --proj_bn: whether we use batch normalization in projection head
    --pred_bn: whether we use batch normalization in prediction head
    --norm_before_pred: whether we L2-normalize features before prediction head
    """

    def __init__(
        self,
        in_dim,
        hidden_dim=2048,
        out_dim=256,
        pred_hidden_dim=4096,
        nlayers=3,
        proj_bn=False,
        pred_bn=False,
        norm_before_pred=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.norm_before_pred = norm_before_pred
        self.projector = self._build_mlp(
            nlayers, in_dim, hidden_dim, out_dim, use_bn=proj_bn
        )
        # NOTE: the custom init is applied before the predictor exists, so the
        # predictor keeps PyTorch's default nn.Linear initialization.
        self.apply(self._init_weights)
        self.predictor = None
        if pred_hidden_dim > 0:  # teacher no, student yes
            self.predictor = self._build_mlp(
                nlayers - 1, out_dim, pred_hidden_dim, out_dim, use_bn=pred_bn
            )

    def _init_weights(self, m):
        """
        initialize linear layers: truncated-normal weights (std 0.02), zero bias
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _build_mlp(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False):
        """
        build a bias-free MLP of `num_layers` linear layers; every layer except
        the last is followed by (optional BatchNorm1d and) GELU.
        num_layers == 0 yields an empty nn.Sequential, which returns its input.
        """
        mlp = []
        for layer in range(num_layers):
            dim1 = input_dim if layer == 0 else hidden_dim
            dim2 = output_dim if layer == num_layers - 1 else hidden_dim
            mlp.append(nn.Linear(dim1, dim2, bias=False))
            if layer < num_layers - 1:
                if use_bn:
                    mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.GELU())
        return nn.Sequential(*mlp)

    def forward(self, x, return_target=False):
        """
        return_target=True (teacher): L2-normalized projection of x.
        return_target=False (student): projection -> optional L2 norm ->
        prediction head -> L2 norm.
        Raises RuntimeError if a prediction is requested from a head built
        without a prediction head (pred_hidden_dim <= 0).
        """
        feat = self.projector(x)
        if return_target:
            return nn.functional.normalize(feat, dim=-1, p=2)
        if self.predictor is None:
            raise RuntimeError(
                "this head has no prediction head (pred_hidden_dim <= 0); "
                "call it with return_target=True"
            )
        ## return prediction
        if self.norm_before_pred:
            feat = nn.functional.normalize(feat, dim=-1, p=2)
        pred = self.predictor(feat)
        return nn.functional.normalize(pred, dim=-1, p=2)
class Local_Group_Superivsion_Head(nn.Module):
    """
    a class to implement Local Group Superivsion Head, which has the same
    architecture as Instance Superivsion Head: a projection MLP, optionally
    followed by a prediction MLP (student only)
    --in_dim: input dimension of projection head
    --hidden_dim: hidden dimension of projection head
    --out_dim: output dimension of projection and prediction heads
    --pred_hidden_dim: hidden dimension of prediction head; a value <= 0 builds
        no prediction head (used for the teacher)
    --nlayers: number of linear layers in the projection head (clamped to >= 1);
        the prediction head has nlayers - 1 layers (0 layers -> identity)
    --proj_bn: whether we use batch normalization in projection head
    --pred_bn: whether we use batch normalization in prediction head
    --norm_before_pred: whether we L2-normalize features before prediction head
    """

    def __init__(
        self,
        in_dim,
        hidden_dim=2048,
        out_dim=256,
        pred_hidden_dim=4096,
        nlayers=3,
        proj_bn=False,
        pred_bn=False,
        norm_before_pred=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.norm_before_pred = norm_before_pred
        self.projector = self._build_mlp(
            nlayers, in_dim, hidden_dim, out_dim, use_bn=proj_bn
        )
        # NOTE: the custom init is applied before the predictor exists, so the
        # predictor keeps PyTorch's default nn.Linear initialization.
        self.apply(self._init_weights)
        self.predictor = None
        if pred_hidden_dim > 0:  # teacher no, student yes
            self.predictor = self._build_mlp(
                nlayers - 1, out_dim, pred_hidden_dim, out_dim, use_bn=pred_bn
            )

    def _init_weights(self, m):
        """
        initialize linear layers: truncated-normal weights (std 0.02), zero bias
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _build_mlp(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False):
        """
        build a bias-free MLP of `num_layers` linear layers; every layer except
        the last is followed by (optional BatchNorm1d and) GELU.
        num_layers == 0 yields an empty nn.Sequential, which returns its input.
        """
        mlp = []
        for layer in range(num_layers):
            dim1 = input_dim if layer == 0 else hidden_dim
            dim2 = output_dim if layer == num_layers - 1 else hidden_dim
            mlp.append(nn.Linear(dim1, dim2, bias=False))
            if layer < num_layers - 1:
                if use_bn:
                    mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.GELU())
        return nn.Sequential(*mlp)

    def forward(self, x, return_target=False):
        """
        return_target=True (teacher): L2-normalized projection of x.
        return_target=False (student): projection -> optional L2 norm ->
        prediction head -> L2 norm.
        Raises RuntimeError if a prediction is requested from a head built
        without a prediction head (pred_hidden_dim <= 0).
        """
        feat = self.projector(x)
        if return_target:
            return nn.functional.normalize(feat, dim=-1, p=2)
        if self.predictor is None:
            raise RuntimeError(
                "this head has no prediction head (pred_hidden_dim <= 0); "
                "call it with return_target=True"
            )
        ## return prediction
        if self.norm_before_pred:
            feat = nn.functional.normalize(feat, dim=-1, p=2)
        pred = self.predictor(feat)
        return nn.functional.normalize(pred, dim=-1, p=2)
class Group_Superivsion_Head(nn.Module):
    """
    a class to implement Group Superivsion Head: an MLP projector whose output is
    L2-normalized and fed to a weight-normalized, bias-free linear layer
    --in_dim: input dimension of projection head
    --out_dim: output dimension of the last (weight-normalized) layer
    --hidden_dim: hidden dimension of projection head
    --bottleneck_dim: output dimension of projection head, i.e. input
        dimension of the last layer
    --nlayers: number of linear layers in projection head (clamped to >= 1)
    --use_bn: whether we use batch normalization in projection head
    --norm_last_layer: whether the weight-norm magnitude (weight_g) of the last
        layer stays frozen at 1
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim=2048,
        bottleneck_dim=256,
        nlayers=3,
        use_bn=False,
        norm_last_layer=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.projector = self._build_mlp(
            nlayers, in_dim, hidden_dim, bottleneck_dim, use_bn=use_bn
        )
        # custom init covers only the projector: last_layer is created afterwards
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _build_mlp(self, num_layers, in_dim, hidden_dim, output_dim, use_bn=False):
        """
        build the projector: a single Linear when num_layers == 1, otherwise
        num_layers Linear layers (with bias) separated by (optional
        BatchNorm1d and) GELU
        """
        if num_layers == 1:
            return nn.Linear(in_dim, output_dim)
        layers = [nn.Linear(in_dim, hidden_dim)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, output_dim))
        return nn.Sequential(*layers)

    def _init_weights(self, m):
        """
        initialize linear layers: truncated-normal weights (std 0.02), zero bias
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        forward the input through the projection head, L2-normalize, then apply
        the last weight-normalized layer
        """
        feat = self.projector(x)
        feat = nn.functional.normalize(feat, dim=-1, p=2)
        return self.last_layer(feat)
class Block_mem(nn.Module):
    """
    a class to implement a memory block (FIFO queue) for local group supervision

    Three (K, dim) buffers are kept: queue_q is searched for neighbours, and
    queue_k / queue_v hold the returned key/value pairs. _dequeue_and_enqueue
    writes the same features into all three.
    --dim: feature vector dimension in the memory
    --K: memory size
    --top_n: number of neighbors in local group supervision
    """

    def __init__(self, dim, K=2048, top_n=10):
        super().__init__()
        self.dim = dim
        self.K = K
        self.top_n = top_n
        # create the queue (random contents until overwritten by real features)
        self.register_buffer("queue_q", torch.randn(K, dim))
        self.register_buffer("queue_k", torch.randn(K, dim))
        self.register_buffer("queue_v", torch.randn(K, dim))
        # index of the next slot to be written
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @staticmethod
    def _gather_if_distributed(tensor):
        """
        all-gather `tensor` across ranks when a process group is initialized;
        return it unchanged otherwise (single-process use)
        """
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return concat_all_gather(tensor)
        return tensor

    @torch.no_grad()
    def _dequeue_and_enqueue(self, query, weak_aug_flags):
        """
        update memory queue with the (gathered) query features

        If weak_aug_flags is given, only rows whose flag is non-zero are
        enqueued and their count is returned; when no row is flagged nothing is
        written. Without flags every row is enqueued and 0 is returned.
        If more than K rows arrive, only the last K are kept (earlier ones
        would be overwritten within this same update anyway).
        """
        len_weak = 0
        query = self._gather_if_distributed(query)
        if weak_aug_flags is not None:
            # move flags next to the features instead of assuming CUDA
            weak_aug_flags = self._gather_if_distributed(
                weak_aug_flags.to(query.device)
            )
            idx_weak = torch.nonzero(weak_aug_flags)
            len_weak = len(idx_weak)
            if len_weak == 0:
                return len_weak
            query = query[idx_weak.squeeze(-1)]

        all_size = query.shape[0]
        if all_size > self.K:
            query = query[-self.K :]
            all_size = self.K

        ptr = int(self.queue_ptr)
        end = ptr + all_size
        queues = (self.queue_q, self.queue_k, self.queue_v)
        if end <= self.K:
            for queue in queues:
                queue[ptr:end, :] = query
        else:
            # wrap around: fill the tail of the ring, then continue from slot 0
            split = self.K - ptr
            for queue in queues:
                queue[ptr:, :] = query[:split]
                queue[: end - self.K, :] = query[split:]
        # advance the pointer exactly once by the number of rows written
        self.queue_ptr[0] = end % self.K
        return len_weak

    @torch.no_grad()
    def _get_similarity_index(self, x):
        """
        compute the index of the top-n neighbors (key-value pair) in memory,
        ranked by cosine similarity between x and queue_q
        """
        x = nn.functional.normalize(x, dim=-1)
        queue_q = nn.functional.normalize(self.queue_q, dim=-1)
        cosine = x @ queue_q.T
        _, index = torch.topk(cosine, self.top_n, dim=-1)
        return index

    @torch.no_grad()
    def _get_similarity_samples(self, query, index=None):
        """
        return the top-n neighbors (key-value pair) in memory, each shaped
        (batch, top_n, dim)
        """
        if index is None:
            index = self._get_similarity_index(query)
        get_k = self.queue_k[index.view(-1)]
        get_v = self.queue_v[index.view(-1)]
        B, tn = index.shape
        get_k = get_k.view(B, tn, self.dim)
        get_v = get_v.view(B, tn, self.dim)
        return get_k, get_v

    def forward(self, query):
        """
        forward to find the top-n neighbors (key-value pair) in memory
        """
        get_k, get_v = self._get_similarity_samples(query)
        return get_k, get_v
class vit_mem(nn.Module):
    """
    Float32 front-end to a Block_mem used as the local-group-supervision memory.

    Both entry points cast incoming features to float before delegating to the
    wrapped block (stored as `self.block`).
    --dim: feature vector dimension in the memory
    --K: memory size
    --top_n: number of neighbors in local group supervision
    """

    def __init__(self, dim, K=2048, top_n=10):
        super().__init__()
        self.block = Block_mem(dim, K, top_n)

    def _dequeue_and_enqueue(self, query, weak_aug_flags):
        """
        push `query` (cast to float) into the memory queue; returns whatever
        Block_mem._dequeue_and_enqueue returns
        """
        return self.block._dequeue_and_enqueue(query.float(), weak_aug_flags)

    def forward(self, query):
        """
        look up the top-n (key, value) neighbours of `query` (cast to float)
        """
        return self.block(query.float())
class Mugs_Wrapper(nn.Module):
    """
    Student/teacher wrapper for Mugs: runs a backbone over multi-crop inputs and
    feeds the resulting tokens to the (optional) supervision heads.
    --backbone: the backbone of student/teacher, e.g. ViT-small; its `fc` and
        `head` attributes are replaced with nn.Identity
    --instance_head: head (projection/prediction) for instance supervision, or None
    --local_group_head: head (projection/prediction) for local group supervision, or None
    --group_head: projection head for group supervision, or None
    """

    def __init__(self, backbone, instance_head, local_group_head, group_head):
        super().__init__()
        backbone.fc = nn.Identity()
        backbone.head = nn.Identity()
        self.backbone = backbone
        self.instance_head = instance_head
        self.local_group_head = local_group_head
        self.group_head = group_head

    def forward(self, x, return_target=False, local_group_memory_inputs=None):
        """
        Consecutive crops sharing the same last-dimension size are batched into a
        single backbone call. Returns a 4-tuple:
        (instance_feat, local_group_feat, group_feat, mean_patch_tokens), where a
        head output is None when that head is None, and mean_patch_tokens holds
        the detached mean of the patch tokens (all tokens but index 0) of crops
        whose last dimension is 224 (an empty tensor if there are none).
        """
        crops = x if isinstance(x, list) else [x]
        device = crops[0].device

        # end offsets of each run of equally-sized crops
        sizes = torch.tensor([crop.shape[-1] for crop in crops])
        _, run_lengths = torch.unique_consecutive(sizes, return_counts=True)
        run_ends = torch.cumsum(run_lengths, 0)

        class_tokens = torch.empty(0).to(device)
        mean_patch_tokens = torch.empty(0).to(device)
        memory_class_tokens = torch.empty(0).to(device)

        begin = 0
        for end in run_ends:
            batch = torch.cat(crops[begin:end])
            begin = end
            # e.g. token_feat [16, 197, 384] and memory feat [16, 384] for the teacher
            token_feat, memory_feat = self.backbone(
                batch,
                return_all=True,
                local_group_memory_inputs=local_group_memory_inputs,
            )
            # token 0 is the class token
            class_tokens = torch.cat((class_tokens, token_feat[:, 0]))
            if self.local_group_head is not None:
                memory_class_tokens = torch.cat((memory_class_tokens, memory_feat))
            if batch.shape[-1] == 224:
                patch_mean = token_feat[:, 1:].mean(dim=1)
                mean_patch_tokens = torch.cat((mean_patch_tokens, patch_mean))

        instance_feat = None
        if self.instance_head is not None:
            instance_feat = self.instance_head(class_tokens, return_target)

        local_group_feat = None
        if self.local_group_head is not None:
            local_group_feat = self.local_group_head(memory_class_tokens, return_target)

        group_feat = None
        if self.group_head is not None:
            group_feat = self.group_head(class_tokens)

        return instance_feat, local_group_feat, group_feat, mean_patch_tokens.detach()
def get_model(args):
    """
    build a student and a teacher for mugs, including backbone, instance /
    local-group / group heads, and memory buffers

    The backbone is `vits.__dict__[args.arch]`; only the student receives
    `drop_path_rate`. A head pair is built only when the matching entry of
    args.loss_weights is > 0 (index 0: instance, 1: local group, 2: group);
    otherwise None is passed to Mugs_Wrapper. Student instance/local-group heads
    get a prediction head (pred_hidden_dim=4096), teacher ones do not (0).

    Returns (student, teacher, student_mem, teacher_mem).
    Raises ValueError if args.arch is not defined in src.vision_transformer.
    """
    if args.arch not in vits.__dict__:
        raise ValueError(f"Unknown architecture: {args.arch}")

    # NOTE: modules are built in a fixed order (backbones, memories, heads)
    # because their initialization draws random numbers; keep the order so
    # seeded runs stay reproducible.
    ## backbone
    arch = vits.__dict__[args.arch]
    student = arch(
        patch_size=args.patch_size,
        num_relation_blocks=1,
        drop_path_rate=args.drop_path_rate,  # stochastic depth
    )
    teacher = arch(patch_size=args.patch_size, num_relation_blocks=1)
    embed_dim = student.embed_dim

    ## memory buffer for local-group loss
    student_mem = vit_mem(
        embed_dim, K=args.local_group_queue_size, top_n=args.local_group_knn_top_n
    )
    teacher_mem = vit_mem(
        embed_dim, K=args.local_group_queue_size, top_n=args.local_group_knn_top_n
    )

    student_instance_head = teacher_instance_head = None
    student_local_group_head = teacher_local_group_head = None
    student_group_head = teacher_group_head = None

    # instance head
    if args.loss_weights[0] > 0:
        instance_kwargs = dict(
            in_dim=embed_dim,
            hidden_dim=2048,
            out_dim=args.instance_out_dim,
            nlayers=3,
            proj_bn=args.use_bn_in_head,
            pred_bn=False,
            norm_before_pred=args.norm_before_pred,
        )
        student_instance_head = Instance_Superivsion_Head(
            pred_hidden_dim=4096, **instance_kwargs
        )
        teacher_instance_head = Instance_Superivsion_Head(
            pred_hidden_dim=0, **instance_kwargs
        )

    # local group head
    if args.loss_weights[1] > 0:
        local_group_kwargs = dict(
            in_dim=embed_dim,
            hidden_dim=2048,
            out_dim=args.local_group_out_dim,
            nlayers=3,
            proj_bn=args.use_bn_in_head,
            pred_bn=False,
            norm_before_pred=args.norm_before_pred,
        )
        student_local_group_head = Local_Group_Superivsion_Head(
            pred_hidden_dim=4096, **local_group_kwargs
        )
        teacher_local_group_head = Local_Group_Superivsion_Head(
            pred_hidden_dim=0, **local_group_kwargs
        )

    # group head (identical configuration for student and teacher)
    if args.loss_weights[2] > 0:
        group_kwargs = dict(
            in_dim=embed_dim,
            out_dim=args.group_out_dim,
            hidden_dim=2048,
            bottleneck_dim=args.group_bottleneck_dim,
            nlayers=3,
            use_bn=args.use_bn_in_head,
            norm_last_layer=args.norm_last_layer,
        )
        student_group_head = Group_Superivsion_Head(**group_kwargs)
        teacher_group_head = Group_Superivsion_Head(**group_kwargs)

    # multi-crop wrapper handles forward with inputs of different resolutions
    student = Mugs_Wrapper(
        student, student_instance_head, student_local_group_head, student_group_head
    )
    teacher = Mugs_Wrapper(
        teacher, teacher_instance_head, teacher_local_group_head, teacher_group_head
    )
    return student, teacher, student_mem, teacher_mem
# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensor and concatenates the
    per-rank results along dim 0.

    When torch.distributed is unavailable or no process group is initialized
    (single-process use), a copy of `tensor` is returned instead of raising, so
    the result never aliases the input in either case.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return tensor.clone()
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)