import torch
import torch.nn as nn
from lpips import LPIPS
from kornia import color
# from taming.modules.losses.vqperceptual import *

class ImageSecretLoss(nn.Module):
    """Joint loss for training a model that hides a secret bit-string in an image.

    Combines a per-sample image term (RGB L1 or channel-weighted YUV MSE, plus
    LPIPS perceptual loss and an optional KL term) with a BCE loss on the
    predicted secret bits. The image term is scaled by a learned log-variance
    and by a weight that ramps up during training once `activate_ramp` is called.
    """

    def __init__(self, recon_type='rgb', recon_weight=1., perceptual_weight=1.0,
                 secret_weight=10., kl_weight=1e-6, logvar_init=0.0,
                 ramp=100000, max_image_weight_ratio=2.) -> None:
        super().__init__()
        assert recon_type in ['rgb', 'yuv'], f"Unknown recon type {recon_type}"
        self.recon_type = recon_type
        if recon_type == 'yuv':
            # Channel weights for the YUV loss: chroma (U, V) errors are
            # penalized 100x more than luma (Y). Shape [3, 1].
            self.register_buffer('yuv_scales', torch.tensor([1, 100, 100]).unsqueeze(1).float())
        self.recon_weight = recon_weight
        self.perceptual_weight = perceptual_weight
        self.secret_weight = secret_weight
        self.kl_weight = kl_weight

        self.ramp = ramp
        # Ceiling for the ramped image weight, chosen so that at full ramp
        # image_weight / secret_weight == max_image_weight_ratio.
        self.max_image_weight = max_image_weight_ratio * secret_weight - 1
        self.register_buffer('ramp_on', torch.tensor(False))
        self.register_buffer('step0', torch.tensor(1e9))  # ramp start step; stays large until activated

        self.perceptual_loss = LPIPS().eval()
        # Learned log-variance used to scale the image loss (uncertainty weighting).
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
    
    def activate_ramp(self, global_step):
        if not self.ramp_on:  # do not activate the ramp twice
            self.step0 = torch.tensor(global_step)
            self.ramp_on = ~self.ramp_on
            print(f'[TRAINING] Activated ramp for image loss at step {global_step}')

    def compute_recon_loss(self, inputs, reconstructions):
        if self.recon_type == 'rgb':
            # Per-sample L1 distance in RGB space.
            rec_loss = torch.abs(inputs - reconstructions).mean(dim=[1, 2, 3])
        elif self.recon_type == 'yuv':
            # Map [-1, 1] images to [0, 1] before the RGB -> YUV conversion,
            # then weight the per-channel errors (chroma penalized more than luma).
            reconstructions_yuv = color.rgb_to_yuv((reconstructions + 1) / 2)
            inputs_yuv = color.rgb_to_yuv((inputs + 1) / 2)
            yuv_loss = torch.mean((reconstructions_yuv - inputs_yuv) ** 2, dim=[2, 3])  # [N, 3]
            rec_loss = torch.mm(yuv_loss, self.yuv_scales).squeeze(1)  # [N]
        else:
            raise ValueError(f"Unknown recon type {self.recon_type}")
        return rec_loss
    
    def forward(self, inputs, reconstructions, posteriors, secret_gt, secret_pred, global_step):
        """`inputs`/`reconstructions`: [N, 3, H, W] images in [-1, 1];
        `secret_pred`: raw logits for the secret bits, same shape as `secret_gt`."""
        loss_dict = {}
        rec_loss = self.compute_recon_loss(inputs.contiguous(), reconstructions.contiguous())

        loss = rec_loss*self.recon_weight

        if self.perceptual_weight > 0:
            # LPIPS expects inputs in [-1, 1] and returns one value per sample.
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()).mean(dim=[1, 2, 3])
            loss += self.perceptual_weight * p_loss
            loss_dict['p_loss'] = p_loss.mean()

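        # Uncertainty weighting: dividing by exp(logvar) and adding logvar lets
        # the model learn the overall scale of the image loss (the same scheme
        # as taming-transformers' perceptual losses, cf. the import above).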
        loss = loss / torch.exp(self.logvar) + self.logvar
        if self.kl_weight > 0:
            kl_loss = posteriors.kl()
            loss += kl_loss*self.kl_weight
            loss_dict['kl_loss'] = kl_loss.mean()

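        # Image-loss weight: 1 until activate_ramp() is called, then ramps
        # linearly to 1 + max_image_weight over `ramp` steps.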
        image_weight = 1 + min(self.max_image_weight,
                               max(0., self.max_image_weight * (global_step - self.step0.item()) / self.ramp))

        secret_loss = self.bce(secret_pred, secret_gt).mean(dim=1)  # per-sample BCE over bits
        # Weighted average of the image and secret terms.
        loss = (loss * image_weight + secret_loss * self.secret_weight) / (image_weight + self.secret_weight)

        # loss dict update
        # Bit accuracy: a logit > 0 decodes to bit 1.
        bit_acc = ((secret_pred.detach() > 0).float() == secret_gt).float().mean()
        loss_dict['bit_acc'] = bit_acc
        loss_dict['loss'] = loss.mean()
        loss_dict['img_lw'] = image_weight/self.secret_weight
        loss_dict['rec_loss'] = rec_loss.mean()
        loss_dict['secret_loss'] = secret_loss.mean()

        return loss.mean(), loss_dict
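

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustration only, not part of the training
    # code). The first call to LPIPS() downloads its pretrained weights.
    # `posteriors` only needs a .kl() method returning a per-sample tensor;
    # the hypothetical dummy below stands in for the VAE posterior used in
    # real training.
    class _DummyPosterior:
        def __init__(self, n):
            self.n = n

        def kl(self):
            return torch.zeros(self.n)

    n, bits = 2, 100
    inputs = torch.rand(n, 3, 64, 64) * 2 - 1                 # images in [-1, 1]
    reconstructions = (inputs + 0.05 * torch.randn_like(inputs)).clamp(-1, 1)
    secret_gt = torch.randint(0, 2, (n, bits)).float()        # random secret bits
    secret_pred = torch.randn(n, bits)                        # raw decoder logits

    criterion = ImageSecretLoss(recon_type='yuv')
    criterion.activate_ramp(global_step=0)
    loss, loss_dict = criterion(inputs, reconstructions, _DummyPosterior(n),
                                secret_gt, secret_pred, global_step=0)
    print({k: round(float(v), 4) for k, v in loss_dict.items()})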