File size: 2,218 Bytes
509db6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from generative.losses import PatchAdversarialLoss

# Shared loss criteria used by generator_loss / discriminator_loss below.
# Reconstruction criterion: mean absolute (L1) error between two images.
intensity_loss = torch.nn.L1Loss()
# Adversarial criterion using the least-squares formulation.
adv_loss = PatchAdversarialLoss(criterion="least_squares")

# Relative weights applied to the individual loss terms.
adv_weight, perceptual_weight = 0.5, 1.0

# kl_weight is an important hyper-parameter:
#   - if it is too large, the decoder cannot reconstruct good results from the latent space;
#   - if it is too small, the latent space is not regularized enough for the diffusion model.
kl_weight = 1.0e-6


def compute_kl_loss(z_mu, z_sigma):
    """Return the batch-averaged KL divergence between N(z_mu, z_sigma^2) and N(0, I).

    The closed-form per-element term ``0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)`` is
    summed over every non-batch dimension (dims 1..N-1), giving one value per sample.
    It is then averaged over dim 0.

    Args:
        z_mu: Mean of the latent distribution; dim 0 is treated as the batch dimension.
        z_sigma: Standard deviation of the latent distribution, same shape as ``z_mu``.
            NOTE(review): presumably strictly positive, because ``log(sigma^2)`` is
            ``-inf`` at zero. Confirm against the encoder that produces it.

    Returns:
        A scalar tensor holding the mean (over the batch) of the per-sample KL values.
    """
    # Reduce over all non-batch dimensions so each sample contributes one KL value.
    non_batch_dims = list(range(1, z_sigma.dim()))
    sigma_sq = z_sigma.pow(2)
    kl_per_sample = 0.5 * torch.sum(
        z_mu.pow(2) + sigma_sq - torch.log(sigma_sq) - 1,
        dim=non_batch_dims,
    )
    # Batch mean (equivalent to sum / batch_size).
    return kl_per_sample.mean()


def generator_loss(gen_images, real_images, z_mu, z_sigma, disc_net, loss_perceptual):
    """Return the total generator-side loss, combining four weighted terms.

    The terms are:
      * an L1 reconstruction term (``intensity_loss``), unweighted;
      * a KL regularization term from ``compute_kl_loss``, scaled by ``kl_weight``;
      * a perceptual term from ``loss_perceptual``, scaled by ``perceptual_weight``;
      * an adversarial term that asks ``disc_net`` to label ``gen_images`` as real,
        scaled by ``adv_weight``.

    Args:
        gen_images: Generated images (presumably the autoencoder reconstruction).
        real_images: Reference images compared against ``gen_images``.
        z_mu: Latent mean, forwarded to ``compute_kl_loss``.
        z_sigma: Latent standard deviation, forwarded to ``compute_kl_loss``.
        disc_net: Discriminator; the last element of its output is used as the logits.
        loss_perceptual: Callable perceptual loss taking ``(generated, real)``.

    Returns:
        The scalar total loss tensor.
    """
    recons_loss = intensity_loss(gen_images, real_images)
    kl_loss = compute_kl_loss(z_mu, z_sigma)
    # The perceptual loss is evaluated on float32 copies of both inputs.
    p_loss = loss_perceptual(gen_images.float(), real_images.float())
    loss_g = recons_loss + kl_weight * kl_loss + perceptual_weight * p_loss

    # Adversarial term: the generator is rewarded when its outputs are classified as real.
    # This local is named adv_g_loss so it does not shadow this function's own name.
    logits_fake = disc_net(gen_images)[-1]
    adv_g_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
    return loss_g + adv_weight * adv_g_loss


def discriminator_loss(gen_images, real_images, disc_net):
    """Return the weighted adversarial loss used to train the discriminator.

    ``disc_net`` is scored on labelling ``gen_images`` as fake and ``real_images`` as
    real. The two terms are averaged and then scaled by ``adv_weight``. Both inputs are
    detached, so this loss does not backpropagate into whatever produced them.

    Args:
        gen_images: Generated images to be classified as fake.
        real_images: Reference images to be classified as real.
        disc_net: Discriminator; the last element of its output is used as the logits.

    Returns:
        The scalar discriminator loss tensor.
    """
    # Detach both inputs so gradients from this loss stop at the discriminator's input.
    logits_fake = disc_net(gen_images.contiguous().detach())[-1]
    loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
    logits_real = disc_net(real_images.contiguous().detach())[-1]
    loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
    # Average the fake/real terms. This local is named adv_d_loss so it does not
    # shadow this function's own name.
    adv_d_loss = (loss_d_fake + loss_d_real) * 0.5
    return adv_weight * adv_d_loss