import pdb

import torch
import torch.nn as nn
from models.hinet import Hinet
# from utils.attacks import attack_layer, mp3_attack_v2, butterworth_attack
import numpy as np
import random

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

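
# Shape bookkeeping for Model.stft below (illustrative sketch, not part of the original model).
# With torch.stft's default center=True, the input is padded by n_fft // 2 on each side, so a
# signal of num_point samples yields n_fft // 2 + 1 frequency bins and
# 1 + (num_point + 2 * (n_fft // 2) - n_fft) // hop_length time frames.
# The example values n_fft=1000, hop_length=400, num_point=16000 are assumptions, chosen only
# because they reproduce the [1, 501, 41, 2] shape noted in Model.stft.
def expected_stft_shape(batch, num_point, n_fft, hop_length):
    """Hypothetical helper: shape of Model.stft's output, [batch, freq_bins, time_frames, 2]."""
    freq_bins = n_fft // 2 + 1
    time_frames = 1 + (num_point + 2 * (n_fft // 2) - n_fft) // hop_length
    return (batch, freq_bins, time_frames, 2)


# Example: expected_stft_shape(1, 16000, 1000, 400) -> (1, 501, 41, 2)
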

class Model(nn.Module):
    def __init__(self, num_point, num_bit, n_fft, hop_length, use_recover_layer, num_layers):
        super(Model, self).__init__()
        self.hinet = Hinet(num_layers=num_layers)
        self.watermark_fc = torch.nn.Linear(num_bit, num_point)
        self.watermark_fc_back = torch.nn.Linear(num_point, num_bit)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.dropout1 = torch.nn.Dropout()
        self.identity = torch.nn.Identity()
        self.recover_layer = SameSizeConv2d(2, 2)
        self.use_recover_layer = use_recover_layer

    def stft(self, data):
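        window = torch.hann_window(self.n_fft).to(data.device)
        # return_complex=False is deprecated; take the complex STFT and expose it as
        # [batch, n_fft // 2 + 1, time_frames, 2] (real, imag), e.g. [1, 501, 41, 2].
        tmp = torch.view_as_real(
            torch.stft(data, n_fft=self.n_fft, hop_length=self.hop_length, window=window, return_complex=True))
        return tmp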

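    def istft(self, signal_wmd_fft):
        window = torch.hann_window(self.n_fft).to(signal_wmd_fft.device)

        # Since PyTorch 2.0, torch.istft no longer accepts real inputs of shape [..., 2]; the input
        # must be complex, as returned by stft(..., return_complex=True). Convert the
        # [batch, freq_bins, time_frames, 2] layout back to a complex tensor first.
        # .contiguous() is required because enc_dec returns permuted (non-contiguous) tensors.
        signal_complex = torch.view_as_complex(signal_wmd_fft.contiguous())
        return torch.istft(signal_complex, n_fft=self.n_fft, hop_length=self.hop_length, window=window,
                           return_complex=False)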

    def encode(self, signal, message, need_fft=False):
        # 1. Apply the STFT to the host signal
        signal_fft = self.stft(signal)
        # import pdb
        # pdb.set_trace()
        # (batch,freq_bins,time_frames,2)

        # 2. Expand the message bits to num_point samples, then apply the STFT
        message_expand = self.watermark_fc(message)
        message_fft = self.stft(message_expand)

        # 3. Encode: forward pass (rev=False) of Hinet on the two spectra
        signal_wmd_fft, msg_remain = self.enc_dec(signal_fft, message_fft, rev=False)
        # (batch,freq_bins,time_frames,2)
        signal_wmd = self.istft(signal_wmd_fft)
        if need_fft:
            return signal_wmd, signal_fft, message_fft

        return signal_wmd

    def decode(self, signal):
        signal_fft = self.stft(signal)
        if self.use_recover_layer:
            signal_fft = self.recover_layer(signal_fft)
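        # The second Hinet output from encoding (msg_remain) is not kept, so in the reverse pass the
        # received spectrum is reused as a placeholder (random noise, commented out, is an alternative).
        watermark_fft = signal_fft
        # watermark_fft = torch.randn_like(signal_fft)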
        _, message_restored_fft = self.enc_dec(signal_fft, watermark_fft, rev=True)
        message_restored_expanded = self.istft(message_restored_fft)
        message_restored_float = self.watermark_fc_back(message_restored_expanded).clamp(-1, 1)
        return message_restored_float

    def enc_dec(self, signal, watermark, rev):
        signal = signal.permute(0, 3, 2, 1)
        # [batch, 2, time_frames, freq_bins], e.g. [4, 2, 41, 501]

        watermark = watermark.permute(0, 3, 2, 1)

        # pdb.set_trace()
        signal2, watermark2 = self.hinet(signal, watermark, rev)
        return signal2.permute(0, 3, 2, 1), watermark2.permute(0, 3, 2, 1)
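
# Round-trip sketch for Model.stft / Model.istft (illustrative, not part of the original model).
# It shows the [..., 2] real/imaginary layout used throughout Model and the conversion torch.istft
# needs since PyTorch 2.0: torch.view_as_complex on a contiguous tensor. The defaults
# n_fft=1000 and hop_length=400 are assumed values, not taken from the project configuration.
import torch


def stft_istft_roundtrip(x, n_fft=1000, hop_length=400):
    """Hypothetical helper: x is [batch, num_point]; returns (real-layout spectrum, reconstruction)."""
    window = torch.hann_window(n_fft, device=x.device)
    # Complex STFT -> [batch, freq_bins, time_frames, 2] real layout, matching Model.stft.
    spec = torch.view_as_real(
        torch.stft(x, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True))
    # Back to complex before inverting; .contiguous() guarantees the strides view_as_complex needs.
    recon = torch.istft(torch.view_as_complex(spec.contiguous()), n_fft=n_fft, hop_length=hop_length,
                        window=window, length=x.shape[-1])
    return spec, recon


# Example:
#   x = torch.randn(2, 16000); spec, recon = stft_istft_roundtrip(x)
#   spec has shape [2, 501, 41, 2], and recon should match x up to float32 rounding.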

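# Bit-accuracy sketch for Model.decode (illustrative, not from the original project).
# decode returns floats clamped to [-1, 1]; this assumes messages are encoded as -1/+1 bits and
# recovers each bit from the sign of the decoded value. bit_accuracy is a hypothetical name.
# Torch tensors should be moved to the CPU (tensor.detach().cpu()) before calling it.
import numpy as np


def bit_accuracy(decoded, message):
    """Fraction of positions where decoded and message agree in sign (0 counts as positive)."""
    decoded = np.asarray(decoded)
    message = np.asarray(message)
    return float(np.mean((decoded >= 0) == (message >= 0)))


# Example: bit_accuracy([0.3, -0.9, 0.1, -0.2], [1, -1, -1, -1]) -> 0.75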

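# Toy additive coupling (illustrative; this is NOT the Hinet architecture, whose internals are not
# shown in this file). It sketches why enc_dec can run one network with rev=False to embed and
# rev=True to extract: each coupling step is exactly invertible when both outputs are available.
# f and g stand in for arbitrary (even non-invertible) sub-networks; tanh and sin are assumptions.
import numpy as np


def coupling(a, b, rev=False):
    """Hypothetical two-branch coupling: (a, b) -> (a + f(b), b + g(a + f(b))), and its inverse."""
    f = np.tanh
    g = np.sin
    if not rev:
        a_out = a + f(b)
        b_out = b + g(a_out)
        return a_out, b_out
    # Undo the steps in reverse order, subtracting what was added.
    b_in = b - g(a)
    a_in = a - f(b_in)
    return a_in, b_in

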
class SameSizeConv2d(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SameSizeConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # [batch, 501, 41, 2]
        x1 = x.permute(0, 3, 1, 2)
        # [batch, 2, 501, 41]
        x2 = self.conv(x1)
        # [batch, 2, 501, 41]
        x3 = x2.permute(0, 2, 3, 1)
        # [batch, 501, 41, 2]
        return x3
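

# Equivalence sketch for SameSizeConv2d (illustrative, not part of the original model).
# A 1x1 Conv2d mixes channels independently at every (freq, time) position, so the
# permute -> conv -> permute in SameSizeConv2d.forward computes the same thing as an
# nn.Linear(2, 2) applied to the last (real/imag) axis. conv_as_linear is a hypothetical helper.
import torch
import torch.nn as nn


def conv_as_linear(conv):
    """Build an nn.Linear with the same weights as a 1x1 nn.Conv2d (weight shape [out, in, 1, 1])."""
    linear = nn.Linear(conv.in_channels, conv.out_channels)
    with torch.no_grad():
        linear.weight.copy_(conv.weight.view(conv.out_channels, conv.in_channels))
        linear.bias.copy_(conv.bias)
    return linear


# Example:
#   layer = SameSizeConv2d(2, 2); x = torch.randn(4, 501, 41, 2)
#   layer(x) and conv_as_linear(layer.conv)(x) should agree up to float32 rounding.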