"""Hinet-based audio watermarking model: embeds a bit message into a waveform
and recovers it by running the same invertible network in reverse."""
import torch
import torch.nn as nn

from models.hinet import Hinet

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class Model(nn.Module):
    def __init__(self, num_point, num_bit, n_fft, hop_length, use_recover_layer, num_layers):
        super().__init__()
        self.hinet = Hinet(num_layers=num_layers)
        # Map the num_bit message to num_point samples and back again
        self.watermark_fc = nn.Linear(num_bit, num_point)
        self.watermark_fc_back = nn.Linear(num_point, num_bit)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.dropout1 = nn.Dropout()
        self.identity = nn.Identity()
        self.recover_layer = SameSizeConv2d(2, 2)
        self.use_recover_layer = use_recover_layer

    def stft(self, data):
        window = torch.hann_window(self.n_fft).to(data.device)
        # torch.stft with return_complex=False is deprecated, so take the complex
        # output and view it as real to keep the (batch, freq_bins, time_frames, 2)
        # layout the rest of the model expects.
        spec = torch.stft(data, n_fft=self.n_fft, hop_length=self.hop_length, window=window,
                          return_complex=True)
        return torch.view_as_real(spec)

    def istft(self, signal_wmd_fft):
        window = torch.hann_window(self.n_fft).to(signal_wmd_fft.device)
        # Since PyTorch 2.0, torch.istft only accepts a complex input, so convert
        # the real (..., 2) tensor back before inverting.
        return torch.istft(torch.view_as_complex(signal_wmd_fft.contiguous()),
                           n_fft=self.n_fft, hop_length=self.hop_length, window=window)
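
    # Shape contract (inferred from the original inline shape comments): with
    # n_fft=1000 and hop_length=400, a (batch, 16000) waveform maps to a
    # (batch, 501, 41, 2) spectrogram and back, so encode/decode below can move
    # freely between the time and frequency domains.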

    def encode(self, signal, message, need_fft=False):
        # 1. STFT of the carrier signal -> (batch, freq_bins, time_frames, 2)
        signal_fft = self.stft(signal)
        # 2. Expand the message to the signal length, then STFT it as well
        message_expand = self.watermark_fc(message)
        message_fft = self.stft(message_expand)
        # 3. Embed: run the invertible network forward
        signal_wmd_fft, msg_remain = self.enc_dec(signal_fft, message_fft, rev=False)
        # Back to the time domain -> (batch, num_point)
        signal_wmd = self.istft(signal_wmd_fft)
        if need_fft:
            return signal_wmd, signal_fft, message_fft
        return signal_wmd

    def decode(self, signal):
        signal_fft = self.stft(signal)
        if self.use_recover_layer:
            signal_fft = self.recover_layer(signal_fft)
        watermark_fft = signal_fft
        # Extract: run the invertible network in reverse
        _, message_restored_fft = self.enc_dec(signal_fft, watermark_fft, rev=True)
        message_restored_expanded = self.istft(message_restored_fft)
        message_restored_float = self.watermark_fc_back(message_restored_expanded).clamp(-1, 1)
        return message_restored_float

    def enc_dec(self, signal, watermark, rev):
        # Hinet expects channels-first (batch, 2, time_frames, freq_bins) tensors,
        # so permute on the way in and back out on both paths.
        signal = signal.permute(0, 3, 2, 1)
        watermark = watermark.permute(0, 3, 2, 1)
        signal2, watermark2 = self.hinet(signal, watermark, rev)
        return signal2.permute(0, 3, 2, 1), watermark2.permute(0, 3, 2, 1)


class SameSizeConv2d(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 convolution: mixes the two real/imag channels, keeps the spatial size
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # (batch, freq_bins, time_frames, 2) -> channels-first for Conv2d
        x1 = x.permute(0, 3, 1, 2)
        # (batch, 2, freq_bins, time_frames)
        x2 = self.conv(x1)
        # back to (batch, freq_bins, time_frames, 2)
        x3 = x2.permute(0, 2, 3, 1)
        return x3
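

# ---------------------------------------------------------------------------
# Minimal smoke test -- a sketch, not part of the original module. The
# hyperparameters are assumptions inferred from the shape comments above
# (n_fft=1000 -> 501 frequency bins; hop_length=400 over 16000 samples -> 41
# frames); num_bit=32 and num_layers=8 are illustrative choices only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Model(num_point=16000, num_bit=32, n_fft=1000, hop_length=400,
                  use_recover_layer=False, num_layers=8)
    model.eval()

    signal = torch.randn(1, 16000)                           # one 1-second clip at 16 kHz
    message = torch.randint(0, 2, (1, 32)).float() * 2 - 1   # payload bits mapped to +/-1

    with torch.no_grad():
        watermarked = model.encode(signal, message)          # embed the payload
        recovered = model.decode(watermarked)                # extract it again

    # With untrained Hinet weights the agreement is near chance; training should
    # push it toward 100%.
    bit_acc = ((recovered >= 0) == (message >= 0)).float().mean()
    print(f"watermarked: {tuple(watermarked.shape)}, bit agreement: {bit_acc.item():.2%}")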