|
"""This files contains training loss implementation. |
|
|
|
Copyright (2024) Bytedance Ltd. and/or its affiliates |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
you may not use this file except in compliance with the License. |
|
You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
See the License for the specific language governing permissions and |
|
limitations under the License. |
|
|
|
Ref: |
|
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py |
|
""" |
|
|
|
from typing import Mapping, Text, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
from torch.cuda.amp import autocast |
|
|
|
from ..diffusion import create_diffusion |
|
from .blocks import SimpleMLPAdaLN |
|
from .perceptual_loss import PerceptualLoss |
|
from .discriminator import NLayerDiscriminator |
|
|
|
|
|
def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    """Compute the hinge loss used to train the discriminator.

    Real logits contribute ``relu(1 - logit)`` and fake logits contribute
    ``relu(1 + logit)``; each side is averaged over all of its elements and
    the two averages are then averaged together.

    Adapted from
    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py#L20

    Args:
        logits_real: Discriminator outputs for real samples.
        logits_fake: Discriminator outputs for generated samples.

    Returns:
        A scalar tensor equal to half the sum of the two mean hinge terms.
    """
    real_penalty = F.relu(1.0 - logits_real).mean()
    fake_penalty = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_penalty + fake_penalty)
|
|
|
|
|
def compute_lecam_loss(
    logits_real_mean: torch.Tensor,
    logits_fake_mean: torch.Tensor,
    ema_logits_real_mean: torch.Tensor,
    ema_logits_fake_mean: torch.Tensor,
) -> torch.Tensor:
    """Compute the LeCam regularization term from mean logits and their EMAs.

    The result is the sum of two squared, rectified gaps:
    ``relu(real - ema_fake) ** 2`` and ``relu(ema_real - fake) ** 2``, each
    reduced with a mean before being added.

    Args:
        logits_real_mean: The average real logits.
        logits_fake_mean: The average fake logits.
        ema_logits_real_mean: The EMA of the average real logits.
        ema_logits_fake_mean: The EMA of the average fake logits.

    Returns:
        The LeCam loss as a scalar tensor.
    """
    # Current real logits are compared against the running fake anchor ...
    real_gap = F.relu(logits_real_mean - ema_logits_fake_mean)
    # ... and current fake logits against the running real anchor.
    fake_gap = F.relu(ema_logits_real_mean - logits_fake_mean)
    return real_gap.pow(2).mean() + fake_gap.pow(2).mean()
|
|
|
|
|
class ReconstructionLoss_Stage1(torch.nn.Module):
    """Cross-entropy reconstruction loss over target codes plus a quantizer term.

    The predicted logits are classified against ``target_codes`` with a
    mean-reduced cross-entropy, and the quantizer loss supplied by the caller
    is added with weight ``config.losses.quantizer_weight``.
    """

    def __init__(self, config):
        """Read the loss weights from ``config.losses``.

        Args:
            config: Configuration object exposing ``losses.quantizer_weight``.
        """
        super().__init__()
        loss_config = config.losses
        self.quantizer_weight = loss_config.quantizer_weight
        # Size of the class axis the logits are reshaped to in the loss.
        self.target_codebook_size = 1024

    def forward(
        self,
        target_codes: torch.Tensor,
        reconstructions: torch.Tensor,
        quantizer_loss: Mapping[Text, torch.Tensor],
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Compute the stage-1 loss; see ``_forward_generator`` for details."""
        return self._forward_generator(target_codes, reconstructions, quantizer_loss)

    def _forward_generator(
        self,
        target_codes: torch.Tensor,
        reconstructions: torch.Tensor,
        quantizer_loss: Mapping[Text, torch.Tensor],
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Compute the total loss and a dictionary of detached components.

        Args:
            target_codes: Integer class targets; viewed as ``(batch, -1)``.
            reconstructions: Predicted logits; viewed as
                ``(batch, target_codebook_size, -1)``.
            quantizer_loss: Mapping that must provide the keys
                ``"quantizer_loss"``, ``"commitment_loss"`` and
                ``"codebook_loss"``.

        Returns:
            ``(total_loss, loss_dict)`` where ``loss_dict`` holds detached
            copies of the total, reconstruction, weighted quantizer,
            commitment and codebook losses.
        """
        reconstructions = reconstructions.contiguous()
        batch_size = reconstructions.shape[0]
        # Class axis goes to dim 1, as cross-entropy expects for sequence input.
        reconstruction_loss = F.cross_entropy(
            reconstructions.view(batch_size, self.target_codebook_size, -1),
            target_codes.view(batch_size, -1),
            reduction="mean",
        )
        # Computed once and reused for both the objective and the log entry.
        weighted_quantizer_loss = (
            self.quantizer_weight * quantizer_loss["quantizer_loss"]
        )
        total_loss = reconstruction_loss + weighted_quantizer_loss

        loss_dict = dict(
            total_loss=total_loss.clone().detach(),
            reconstruction_loss=reconstruction_loss.detach(),
            quantizer_loss=weighted_quantizer_loss.detach(),
            commitment_loss=quantizer_loss["commitment_loss"].detach(),
            codebook_loss=quantizer_loss["codebook_loss"].detach(),
        )

        return total_loss, loss_dict
|
|
|
|
|
class ReconstructionLoss_Stage2(torch.nn.Module):
    """Reconstruction loss with perceptual, quantizer and adversarial terms.

    The module owns a discriminator and a perceptual network. ``forward``
    dispatches on ``mode`` to either the generator objective or the
    discriminator objective.
    """

    def __init__(self, config):
        """Build the sub-networks and read all weights from ``config.losses``.

        Args:
            config: The configuration for the model and everything else. The
                ``losses`` section is read here; ``lecam_ema_decay`` is
                optional and defaults to 0.999.
        """
        super().__init__()
        cfg = config.losses
        self.discriminator = NLayerDiscriminator()

        self.reconstruction_loss = cfg.reconstruction_loss
        self.reconstruction_weight = cfg.reconstruction_weight
        self.quantizer_weight = cfg.quantizer_weight
        # Put into eval mode at construction time.
        self.perceptual_loss = PerceptualLoss(cfg.perceptual_loss).eval()
        self.perceptual_weight = cfg.perceptual_weight
        self.discriminator_iter_start = cfg.discriminator_start

        self.discriminator_factor = cfg.discriminator_factor
        self.discriminator_weight = cfg.discriminator_weight
        self.lecam_regularization_weight = cfg.lecam_regularization_weight
        self.lecam_ema_decay = cfg.get("lecam_ema_decay", 0.999)
        # The running logit means exist only when LeCam regularization is on.
        if self.lecam_regularization_weight > 0.0:
            for buffer_name in ("ema_real_logits_mean", "ema_fake_logits_mean"):
                self.register_buffer(buffer_name, torch.zeros((1)))

        self.config = config

    @autocast(enabled=False)
    def forward(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        extra_result_dict: Mapping[Text, torch.Tensor],
        global_step: int,
        mode: str = "generator",
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Run one loss computation with autocast disabled.

        Args:
            inputs: Ground-truth tensor; cast to float32.
            reconstructions: Model output; cast to float32.
            extra_result_dict: Extra model outputs, used in generator mode only.
            global_step: Current training step, used to gate the GAN terms.
            mode: Either ``"generator"`` or ``"discriminator"``.

        Returns:
            ``(loss, loss_dict)`` from the selected training step.

        Raises:
            ValueError: If ``mode`` is neither of the two supported values.
        """
        inputs = inputs.float()
        reconstructions = reconstructions.float()

        if mode == "generator":
            return self._forward_generator(
                inputs, reconstructions, extra_result_dict, global_step
            )
        if mode == "discriminator":
            return self._forward_discriminator(inputs, reconstructions, global_step)
        raise ValueError(f"Unsupported mode {mode}")

    def should_discriminator_be_trained(self, global_step: int):
        """Return True once ``global_step`` reaches the discriminator start step."""
        return global_step >= self.discriminator_iter_start

    def _forward_generator(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        extra_result_dict: Mapping[Text, torch.Tensor],
        global_step: int,
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Generator training step.

        Combines the weighted pixel loss, perceptual loss, quantizer loss and,
        once the discriminator is active, the adversarial loss.

        Args:
            inputs: Ground-truth tensor.
            reconstructions: Model output to be scored.
            extra_result_dict: Must provide ``"quantizer_loss"``,
                ``"commitment_loss"`` and ``"codebook_loss"``.
            global_step: Current training step.

        Returns:
            ``(total_loss, loss_dict)``. ``loss_dict["d_weight"]`` is a plain
            Python number; the remaining entries are tensors.

        Raises:
            ValueError: If the configured reconstruction loss is not
                ``"l1"`` or ``"l2"``.
        """
        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()

        if self.reconstruction_loss == "l1":
            pixel_criterion = F.l1_loss
        elif self.reconstruction_loss == "l2":
            pixel_criterion = F.mse_loss
        else:
            raise ValueError(
                f"Unsuppored reconstruction_loss {self.reconstruction_loss}"
            )
        reconstruction_loss = pixel_criterion(inputs, reconstructions, reduction="mean")
        reconstruction_loss *= self.reconstruction_weight

        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()

        # The adversarial term stays at zero until the discriminator is active.
        generator_loss = torch.zeros((), device=inputs.device)
        gan_active = self.should_discriminator_be_trained(global_step)
        discriminator_factor = self.discriminator_factor if gan_active else 0
        d_weight = 1.0
        if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
            # Freeze the discriminator so only the generator receives gradients.
            for disc_param in self.discriminator.parameters():
                disc_param.requires_grad = False
            generator_loss = -torch.mean(self.discriminator(reconstructions))
            d_weight *= self.discriminator_weight

        quantizer_loss = extra_result_dict["quantizer_loss"]
        weighted_perceptual = self.perceptual_weight * perceptual_loss
        weighted_quantizer = self.quantizer_weight * quantizer_loss
        weighted_gan = d_weight * discriminator_factor * generator_loss
        total_loss = (
            reconstruction_loss
            + weighted_perceptual
            + weighted_quantizer
            + weighted_gan
        )

        loss_dict = dict(
            total_loss=total_loss.clone().detach(),
            reconstruction_loss=reconstruction_loss.detach(),
            perceptual_loss=weighted_perceptual.detach(),
            quantizer_loss=weighted_quantizer.detach(),
            weighted_gan_loss=weighted_gan.detach(),
            discriminator_factor=torch.tensor(discriminator_factor),
            commitment_loss=extra_result_dict["commitment_loss"].detach(),
            codebook_loss=extra_result_dict["codebook_loss"].detach(),
            d_weight=d_weight,
            gan_loss=generator_loss.detach(),
        )

        return total_loss, loss_dict

    def _forward_discriminator(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        global_step: int,
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Discriminator training step.

        Scores detached real and reconstructed inputs, applies the hinge loss
        scaled by the discriminator factor, and optionally adds the LeCam
        regularizer while updating its running logit means.

        Args:
            inputs: Ground-truth tensor, treated as the real samples.
            reconstructions: Model output, treated as the fake samples.
            global_step: Current training step.

        Returns:
            ``(discriminator_loss, loss_dict)`` with detached loss, mean real
            and fake logits, and the LeCam term.
        """
        gan_active = self.should_discriminator_be_trained(global_step)
        discriminator_factor = self.discriminator_factor if gan_active else 0

        # Undo the freeze applied during the generator step.
        for disc_param in self.discriminator.parameters():
            disc_param.requires_grad = True

        real_images = inputs.detach().requires_grad_(True)
        logits_real = self.discriminator(real_images)
        logits_fake = self.discriminator(reconstructions.detach())

        discriminator_loss = discriminator_factor * hinge_d_loss(
            logits_real=logits_real, logits_fake=logits_fake
        )

        lecam_loss = torch.zeros((), device=inputs.device)
        if self.lecam_regularization_weight > 0.0:
            real_mean = torch.mean(logits_real)
            fake_mean = torch.mean(logits_fake)
            lecam_loss = (
                compute_lecam_loss(
                    real_mean,
                    fake_mean,
                    self.ema_real_logits_mean,
                    self.ema_fake_logits_mean,
                )
                * self.lecam_regularization_weight
            )

            # The running means are refreshed after the loss has used the old ones.
            decay = self.lecam_ema_decay
            self.ema_real_logits_mean = (
                self.ema_real_logits_mean * decay + real_mean.detach() * (1 - decay)
            )
            self.ema_fake_logits_mean = (
                self.ema_fake_logits_mean * decay + fake_mean.detach() * (1 - decay)
            )

        discriminator_loss += lecam_loss

        loss_dict = dict(
            discriminator_loss=discriminator_loss.detach(),
            logits_real=logits_real.detach().mean(),
            logits_fake=logits_fake.detach().mean(),
            lecam_loss=lecam_loss.detach(),
        )
        return discriminator_loss, loss_dict
|
|
|
|
|
class ReconstructionLoss_Single_Stage(ReconstructionLoss_Stage2):
    """Stage-2 loss variant whose generator step supports "vq" and "vae" modes.

    The mode is read from ``config.model.vq_model`` (default ``"vq"``). In
    "vae" mode the quantizer term is replaced by a KL term and the pixel loss
    is divided by ``exp(logvar)``.
    """

    def __init__(self, config):
        """Initialize the parent loss and the mode-specific settings.

        Args:
            config: Configuration object; in "vae" mode ``losses.kl_weight``
                (default 1e-6) and ``losses.logvar_init`` (default 0.0) are
                also read.
        """
        super().__init__(config)
        loss_config = config.losses
        self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")

        if self.quantize_mode == "vae":
            self.kl_weight = loss_config.get("kl_weight", 1e-6)
            logvar_init = loss_config.get("logvar_init", 0.0)
            # Stored as a parameter but excluded from gradient updates.
            self.logvar = nn.Parameter(
                torch.ones(size=()) * logvar_init, requires_grad=False
            )

    def _forward_generator(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        extra_result_dict: Mapping[Text, torch.Tensor],
        global_step: int,
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Generator training step.

        Args:
            inputs: Ground-truth tensor.
            reconstructions: Model output to be scored.
            extra_result_dict: In "vq" mode, a mapping providing
                ``"quantizer_loss"``, ``"commitment_loss"`` and
                ``"codebook_loss"``. In "vae" mode, an object exposing a
                ``kl()`` method.
            global_step: Current training step.

        Returns:
            ``(total_loss, loss_dict)`` with detached per-term entries.

        Raises:
            ValueError: If the configured reconstruction loss is not
                ``"l1"`` or ``"l2"``.
            NotImplementedError: If the quantize mode is neither ``"vq"`` nor
                ``"vae"``.
        """
        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()

        if self.reconstruction_loss == "l1":
            pixel_criterion = F.l1_loss
        elif self.reconstruction_loss == "l2":
            pixel_criterion = F.mse_loss
        else:
            raise ValueError(
                f"Unsuppored reconstruction_loss {self.reconstruction_loss}"
            )
        reconstruction_loss = pixel_criterion(inputs, reconstructions, reduction="mean")
        reconstruction_loss *= self.reconstruction_weight

        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()

        # The adversarial term stays at zero until the discriminator is active.
        generator_loss = torch.zeros((), device=inputs.device)
        gan_active = self.should_discriminator_be_trained(global_step)
        discriminator_factor = self.discriminator_factor if gan_active else 0
        d_weight = 1.0
        if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
            # Freeze the discriminator so only the generator receives gradients.
            for disc_param in self.discriminator.parameters():
                disc_param.requires_grad = False
            generator_loss = -torch.mean(self.discriminator(reconstructions))
            d_weight *= self.discriminator_weight

        weighted_perceptual = self.perceptual_weight * perceptual_loss
        weighted_gan = d_weight * discriminator_factor * generator_loss

        if self.quantize_mode == "vq":
            quantizer_loss = extra_result_dict["quantizer_loss"]
            weighted_quantizer = self.quantizer_weight * quantizer_loss
            total_loss = (
                reconstruction_loss
                + weighted_perceptual
                + weighted_quantizer
                + weighted_gan
            )
            loss_dict = dict(
                total_loss=total_loss.clone().detach(),
                reconstruction_loss=reconstruction_loss.detach(),
                perceptual_loss=weighted_perceptual.detach(),
                quantizer_loss=weighted_quantizer.detach(),
                weighted_gan_loss=weighted_gan.detach(),
                discriminator_factor=torch.tensor(discriminator_factor),
                commitment_loss=extra_result_dict["commitment_loss"].detach(),
                codebook_loss=extra_result_dict["codebook_loss"].detach(),
                d_weight=d_weight,
                gan_loss=generator_loss.detach(),
            )
        elif self.quantize_mode == "vae":
            # The pixel loss is rescaled by the fixed log-variance.
            reconstruction_loss = reconstruction_loss / self.logvar.exp()
            posteriors = extra_result_dict
            kl_values = posteriors.kl()
            # Summed over every element, then divided by the leading dimension.
            kl_loss = torch.sum(kl_values) / kl_values.shape[0]
            weighted_kl = self.kl_weight * kl_loss
            total_loss = (
                reconstruction_loss
                + weighted_perceptual
                + weighted_kl
                + weighted_gan
            )
            loss_dict = dict(
                total_loss=total_loss.clone().detach(),
                reconstruction_loss=reconstruction_loss.detach(),
                perceptual_loss=weighted_perceptual.detach(),
                kl_loss=weighted_kl.detach(),
                weighted_gan_loss=weighted_gan.detach(),
                discriminator_factor=torch.tensor(discriminator_factor),
                d_weight=d_weight,
                gan_loss=generator_loss.detach(),
            )
        else:
            raise NotImplementedError

        return total_loss, loss_dict
|
|
|
|
|
class MLMLoss(torch.nn.Module):
    """Weighted token-level cross-entropy for masked-token prediction.

    Each token's loss is weighted by ``weights`` where it is 1 and by
    ``loss_weight_unmasked_token`` where it is 0, then normalized by the sum
    of those weights.
    """

    def __init__(self, config):
        """Read label smoothing and the unmasked-token weight from ``config.losses``.

        Args:
            config: Configuration object exposing ``losses.label_smoothing``
                and ``losses.loss_weight_unmasked_token``.
        """
        super().__init__()
        self.label_smoothing = config.losses.label_smoothing
        self.loss_weight_unmasked_token = config.losses.loss_weight_unmasked_token
        # Per-token losses are kept so they can be re-weighted in forward().
        self.criterion = torch.nn.CrossEntropyLoss(
            label_smoothing=self.label_smoothing, reduction="none"
        )

    def forward(
        self, inputs: torch.Tensor, targets: torch.Tensor, weights=None
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Compute the weighted loss and the weighted token accuracy.

        Args:
            inputs: Logits laid out as (batch, tokens, classes).
            targets: Class indices, one per token.
            weights: Optional per-token weights with the same layout as
                ``targets``. When omitted, every token is given weight 1,
                which makes the loss a plain mean over all tokens.

        Returns:
            ``(loss, info)`` where ``info`` holds the same ``loss`` tensor and
            ``correct_tokens``, the batch mean of the weighted per-sample
            accuracy.
        """
        # Cross-entropy wants the class axis at dim 1.
        inputs = inputs.permute(0, 2, 1)
        loss = self.criterion(inputs, targets)
        if weights is None:
            # The declared default used to crash on ``None.to(...)``.
            weights = torch.ones_like(loss)
        else:
            weights = weights.to(loss)
        loss_weights = (
            1.0 - weights
        ) * self.loss_weight_unmasked_token + weights
        # The epsilon guards against an all-zero weight sum.
        loss = (loss * loss_weights).sum() / (loss_weights.sum() + 1e-8)

        correct_tokens = ((torch.argmax(inputs, dim=1) == targets) * weights).sum(
            dim=1
        ) / (weights.sum(1) + 1e-8)
        return loss, {"loss": loss, "correct_tokens": correct_tokens.mean()}
|
|
|
|
|
class ARLoss(torch.nn.Module):
    """Next-token cross-entropy loss for autoregressive training.

    The final logit position is discarded and the remaining positions are
    scored against ``labels`` as given (the labels themselves are not shifted).
    """

    def __init__(self, config):
        """Read the vocabulary size from ``config.model.vq_model.codebook_size``."""
        super().__init__()
        self.target_vocab_size = config.model.vq_model.codebook_size
        self.criterion = torch.nn.CrossEntropyLoss(reduction="mean")

    def forward(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Compute the mean cross-entropy and the token accuracy.

        Args:
            logits: Predictions laid out as (batch, positions, vocab); the
                last position is dropped before scoring.
            labels: Target indices, flattened to (batch, -1).

        Returns:
            ``(loss, info)`` where ``info`` holds the same ``loss`` tensor and
            ``correct_tokens``, the batch mean of per-sample accuracy.
        """
        # Drop the last position and move the vocabulary axis to dim 1.
        class_first = logits[..., :-1, :].permute(0, 2, 1).contiguous()
        batch = class_first.shape[0]
        shift_logits = class_first.view(batch, self.target_vocab_size, -1)

        flat_labels = labels.contiguous()
        flat_labels = flat_labels.view(flat_labels.shape[0], -1)
        flat_labels = flat_labels.to(shift_logits.device)

        loss = self.criterion(shift_logits, flat_labels)

        hits = torch.argmax(shift_logits, dim=1) == flat_labels
        correct_tokens = hits.sum(dim=1) / flat_labels.size(1)
        return loss, {"loss": loss, "correct_tokens": correct_tokens.mean()}
|
|
|
|
|
class DiffLoss(nn.Module):
    """Diffusion Loss.

    Wraps a conditional denoising network together with two diffusion
    schedules: one for computing training losses and one for sampling.
    """

    def __init__(self, config):
        """Build the denoising network and both diffusion processes.

        Args:
            config: Configuration object. Reads ``model.vq_model.token_size``,
                ``model.maskgen.decoder_embed_dim``, ``losses.diffloss_w``,
                ``losses.diffloss_d`` and the optional
                ``losses.num_sampling_steps`` (default ``"100"``).
        """
        super().__init__()
        self.in_channels = config.model.vq_model.token_size

        self.net = SimpleMLPAdaLN(
            in_channels=self.in_channels,
            model_channels=config.losses.diffloss_w,
            out_channels=self.in_channels * 2,
            z_channels=config.model.maskgen.decoder_embed_dim,
            num_res_blocks=config.losses.diffloss_d,
            # NOTE(review): this is a single dotted key passed to ``get``. If
            # ``config`` does not resolve dotted paths, the default ``False``
            # is always used -- confirm against the config type.
            grad_checkpointing=config.get("training.grad_checkpointing", False),
        )

        self.train_diffusion = create_diffusion(
            timestep_respacing="", noise_schedule="cosine"
        )
        self.gen_diffusion = create_diffusion(
            timestep_respacing=config.losses.get("num_sampling_steps", "100"),
            noise_schedule="cosine",
        )

    def forward(self, target, z, mask=None):
        """Compute the diffusion training loss for ``target`` conditioned on ``z``.

        Args:
            target: Tensor to be denoised; one timestep is drawn per entry of
                its leading dimension.
            z: Conditioning tensor, passed to the network as ``c``.
            mask: Optional weights multiplied into the per-sample loss, which
                is then normalized by ``mask.sum()``.

        Returns:
            ``(loss, loss_dict)`` where ``loss`` is a scalar and
            ``loss_dict["diff_loss"]`` is its detached copy.
        """
        t = torch.randint(
            0,
            self.train_diffusion.num_timesteps,
            (target.shape[0],),
            device=target.device,
        )
        model_kwargs = dict(c=z)
        loss_terms = self.train_diffusion.training_losses(
            self.net, target, t, model_kwargs
        )
        loss = loss_terms["loss"]
        if mask is not None:
            loss = (loss * mask).sum() / mask.sum()

        loss_dict = dict(
            diff_loss=loss.clone().mean().detach(),
        )

        return loss.mean(), loss_dict

    def sample(self, z, temperature=1.0, cfg=1.0):
        """Draw samples from the sampling diffusion process conditioned on ``z``.

        Args:
            z: Conditioning tensor. When ``cfg != 1.0`` its leading dimension
                is halved to size the noise, and the same noise is used for
                both halves.
            temperature: Forwarded to the sampling loop.
            cfg: Guidance scale; any value other than 1.0 selects the
                network's ``forward_with_cfg`` entry point.

        Returns:
            The tensor returned by the sampling loop.
        """
        # Noise is drawn as before and then moved to wherever ``z`` lives,
        # instead of being forced onto the current CUDA device.
        device = z.device
        if cfg != 1.0:
            half_noise = torch.randn(z.shape[0] // 2, self.in_channels).to(device)
            noise = torch.cat([half_noise, half_noise], dim=0)
            model_kwargs = dict(c=z, cfg_scale=cfg)
            sample_fn = self.net.forward_with_cfg
        else:
            noise = torch.randn(z.shape[0], self.in_channels).to(device)
            model_kwargs = dict(c=z)
            sample_fn = self.net.forward

        sampled_token_latent = self.gen_diffusion.p_sample_loop(
            sample_fn,
            noise.shape,
            noise,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            progress=False,
            temperature=temperature,
        )

        return sampled_token_latent
|
|