# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
functions for building multi-granular losses.
"""
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from utils import concat_all_gather


class InfoNCELoss(nn.Module):
    """
    Vanilla InfoNCE loss.
    --ncrops: number of crops used in the student network
    --dim: feature dimension of the queue, determined by the output dimension of the student network
    --queue_size: queue size
    --temperature: temperature parameter for the InfoNCE loss
    """

def __init__(self, ncrops, dim=256, queue_size=65536, temperature=0.2):
super().__init__()
self.queue_size = queue_size
self.temperature = temperature
self.register_buffer("queue", torch.randn(dim, queue_size))
self.queue = nn.functional.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
self.CrossEntropyLoss = nn.CrossEntropyLoss()
self.ncrops = ncrops

    @torch.no_grad()
def _dequeue_and_enqueue(self, keys):
"""
queue update
"""
keys = concat_all_gather(keys)
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
# replace the keys at ptr (dequeue and enqueue)
if ptr + batch_size <= self.queue_size:
self.queue[:, ptr : ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.queue_size
else:
keys_t = keys.T
queue_remaining_size = self.queue_size - ptr
self.queue[:, ptr:] = keys_t[:, :queue_remaining_size]
self.queue[:, : batch_size - queue_remaining_size] = keys_t[
:, queue_remaining_size:
]
ptr = batch_size - queue_remaining_size # move pointer
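            # e.g. (hypothetical numbers) queue_size=8, ptr=6, batch_size=4:
            # columns 6,7 receive keys_t[:, :2], columns 0,1 receive
            # keys_t[:, 2:], and the pointer wraps around to 2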
self.queue_ptr[0] = ptr

    def forward(self, student_output, teacher_output, epoch):
        """
        InfoNCE loss between student predictions and teacher targets, with
        the queue providing the negative keys.
        """
        # student predictions cover all crops; teacher targets come from
        # the 2 global crops only
        preds = student_output.chunk(self.ncrops)
        targets = teacher_output.detach().chunk(2)
small_crop_loss, large_crop_loss = 0, 0
small_loss_terms, large_loss_terms = 0, 0
queue_feat = self.queue.clone().detach()
for t_idx, targ in enumerate(targets):
for p_idx, pred in enumerate(preds):
if t_idx == p_idx:
continue
# positive logits: Nx1
l_pos = torch.einsum("nc,nc->n", [pred, targ]).unsqueeze(-1)
# negative logits: NxK
l_neg = torch.einsum("nc,ck->nk", [pred, queue_feat])
# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
# apply temperature
logits /= self.temperature
                # labels: the positive key sits at index 0 for every sample
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(
logits.device
)
loss = self.CrossEntropyLoss(logits, labels)
if p_idx < 2: ## large crop loss, namely loss on 224-sized images
large_crop_loss += loss
large_loss_terms += 1
else: ## small crop loss, namely loss on 96-sized images
small_crop_loss += loss
small_loss_terms += 1
# dequeue and enqueue
self._dequeue_and_enqueue(targ)
large_crop_loss /= large_loss_terms
small_crop_loss /= small_loss_terms
loss = 0.5 * (large_crop_loss + small_crop_loss)
return loss
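

# A minimal usage sketch for InfoNCELoss (hypothetical shapes; assumes the
# features are L2-normalized upstream and that torch.distributed is
# initialized, since the queue update all-gathers keys across GPUs):
#
#   criterion = InfoNCELoss(ncrops=10, dim=256).cuda()
#   student_feat = ...  # (10 * batch_size, 256), from 2 global + 8 local crops
#   teacher_feat = ...  # (2 * batch_size, 256), from the 2 global crops
#   loss = criterion(student_feat, teacher_feat, epoch)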


class ClusteringLoss(nn.Module):
    """
    Clustering loss, which is very similar to the one in DINO.
    --out_dim: center dimension, determined by the output dimension of the student network
    --ncrops: number of crops used in the student network
    --warmup_teacher_temp: initial value of the teacher temperature
    --teacher_temp: final value (after linear warmup) of the teacher temperature
    --warmup_teacher_temp_epochs: number of warmup epochs for the teacher temperature
    --nepochs: total number of training epochs
    --student_temp: temperature parameter for the student output
    --center_momentum: EMA parameter for the center update
    """

def __init__(
self,
out_dim,
ncrops,
warmup_teacher_temp,
teacher_temp,
warmup_teacher_temp_epochs,
nepochs,
student_temp=0.1,
center_momentum=0.9,
):
super().__init__()
self.student_temp = student_temp
self.center_momentum = center_momentum
self.ncrops = ncrops
self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warmup to the teacher temperature because
        # too high a temperature makes the training unstable at the beginning
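        # (e.g., with warmup_teacher_temp=0.04, teacher_temp=0.07, and
        # warmup_teacher_temp_epochs=30, the schedule ramps linearly from
        # 0.04 to 0.07 over the first 30 epochs, then stays at 0.07;
        # these numbers are illustrative, not prescribed by this code)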
self.teacher_temp_schedule = np.concatenate(
(
np.linspace(
warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs
),
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp,
)
)

    def forward(self, student_output, teacher_output, epoch):
"""
Cross-entropy between softmax outputs of the teacher and student networks.
"""
student_out = student_output / self.student_temp
student_out = student_out.chunk(self.ncrops)
# teacher centering and sharpening
temp = self.teacher_temp_schedule[epoch]
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
teacher_out = teacher_out.detach().chunk(2)
loss_large_crop, loss_small_crop = 0.0, 0.0
loss_terms_large_crop, loss_terms_small_crop = 0, 0
for iq, q in enumerate(teacher_out):
for v in range(len(student_out)):
if v == iq:
# we skip cases where student and teacher operate on the same view
continue
loss = torch.sum(
-q * F.log_softmax(student_out[v], dim=-1), dim=-1
).mean()
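                # i.e. the cross-entropy H(q, p) = -sum_c q_c * log p_c between
                # the sharpened teacher distribution q and the student's
                # softmax p, averaged over the batch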
if v < 2:
loss_large_crop += loss
loss_terms_large_crop += 1
else:
loss_small_crop += loss
loss_terms_small_crop += 1
self.update_center(teacher_output)
loss_large_crop /= loss_terms_large_crop
loss_small_crop /= loss_terms_small_crop
total_loss = 0.5 * (loss_large_crop + loss_small_crop)
return total_loss

    @torch.no_grad()
def update_center(self, teacher_output):
"""
Update center used for teacher output.
"""
batch_center = torch.mean(teacher_output, dim=0, keepdim=False)
dist.all_reduce(batch_center)
batch_center = batch_center / dist.get_world_size()
        # EMA update: center <- m * center + (1 - m) * batch_center
self.center = self.center * self.center_momentum + batch_center * (
1 - self.center_momentum
)
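

# A minimal usage sketch for ClusteringLoss (hypothetical numbers; assumes
# 2 global + 8 local crops and that torch.distributed is initialized, since
# the center update all-reduces across GPUs):
#
#   criterion = ClusteringLoss(65536, ncrops=10, warmup_teacher_temp=0.04,
#                              teacher_temp=0.07, warmup_teacher_temp_epochs=30,
#                              nepochs=100).cuda()
#   loss = criterion(student_logits, teacher_logits, epoch)  # logits are placeholders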


def get_multi_granular_loss(args):
    """
    Build the multi-granular losses (instance, local-group, and group
    discrimination) together with their loss weights.
    """
all_losses, all_weights = {}, {}
## build the instance discrimination loss
instance_supervision_loss = InfoNCELoss(
args.local_crops_number + 2,
dim=args.instance_out_dim,
queue_size=args.instance_queue_size,
temperature=args.instance_temp,
).cuda()
all_losses["instance-sup."] = instance_supervision_loss
all_weights["instance-sup."] = args.loss_weights[0]
## build the local group discrimination loss
local_group_supervision = InfoNCELoss(
args.local_crops_number + 2,
dim=args.local_group_out_dim,
queue_size=args.local_group_queue_size,
temperature=args.local_group_temp,
).cuda()
all_losses["local-group-sup."] = local_group_supervision
all_weights["local-group-sup."] = args.loss_weights[1]
## build the group discrimination loss
group_loss = ClusteringLoss(
args.group_out_dim,
args.local_crops_number
+ 2, # total number of crops = 2 global crops + local_crops_number
args.group_warmup_teacher_temp,
args.group_teacher_temp,
args.group_warmup_teacher_temp_epochs,
args.epochs,
student_temp=args.group_student_temp,
center_momentum=0.9,
).cuda()
all_losses["group-sup."] = group_loss
all_weights["group-sup."] = args.loss_weights[2]
return all_losses, all_weights
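

# A hedged sketch of how the returned dicts are typically combined in a
# training loop; `student_outs` / `teacher_outs` are hypothetical names for
# the per-granularity head outputs, not defined in this file:
#
#   all_losses, all_weights = get_multi_granular_loss(args)
#   total_loss = sum(
#       all_weights[name] * loss_fn(student_outs[name], teacher_outs[name], epoch)
#       for name, loss_fn in all_losses.items()
#   )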