Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
class Layer(nn.Module):
    """Gated 2-D convolution followed by batch normalization.

    Two convolutions with identical hyper-parameters are applied to the same
    input: one produces features, the other produces a gate that is squashed
    with a sigmoid. Their element-wise product is batch-normalized.

    Args:
        dim_in: Number of input channels for both convolutions.
        dim_out: Number of output channels (also the BatchNorm2d width).
        kernel_size: Kernel size forwarded to both ``nn.Conv2d`` modules.
        stride: Stride forwarded to both ``nn.Conv2d`` modules.
        padding: Padding forwarded to both ``nn.Conv2d`` modules.
    """

    def __init__(self, dim_in, dim_out, kernel_size, stride, padding):
        super().__init__()
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=True,
        )
        # Submodules are created in the order conv -> gate -> bn so that
        # parameter registration and random initialisation stay unchanged.
        self.conv = nn.Conv2d(dim_in, dim_out, **conv_kwargs)
        self.gate = nn.Conv2d(dim_in, dim_out, **conv_kwargs)
        self.bn = nn.BatchNorm2d(dim_out)

    def forward(self, x):
        """Return ``bn(conv(x) * sigmoid(gate(x)))``."""
        features = self.conv(x)
        gate = torch.sigmoid(self.gate(x))
        return self.bn(features * gate)
class Encoder(nn.Module):
    """Stack of gated conv ``Layer`` blocks plus a linear message projection.

    Args:
        out_dim: Channel count produced by the last ``Layer``.
        n_layers: Nominal depth. The first and last ``Layer`` are always
            created; ``n_layers - 2`` extra 32->32 layers sit between them.
        message_dim: Input feature size of the message projection.
        message_band_size: Output feature size of the message projection.
            Must not be ``None``.
        n_fft: Used only to derive the target height ``n_fft // 2 + 1`` in
            ``transform_message`` (presumably the one-sided STFT bin count;
            confirm against the caller). Must not be ``None``.
    """

    def __init__(self, out_dim=32, n_layers=3, message_dim=0, message_band_size=None, n_fft=None):
        super().__init__()
        assert message_band_size is not None
        assert n_fft is not None
        self.message_band_size = message_band_size

        hidden = 32
        # (in_channels, out_channels) for each Layer, in construction order.
        channel_plan = [(1, hidden)]
        channel_plan.extend((hidden, hidden) for _ in range(n_layers - 2))
        channel_plan.append((hidden, out_dim))

        self.main = nn.Sequential(*[
            Layer(dim_in=c_in, dim_out=c_out, kernel_size=3, stride=1, padding=1)
            for c_in, c_out in channel_plan
        ])
        self.linear = nn.Linear(message_dim, message_band_size)
        self.n_fft = n_fft

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the result."""
        return self.main(x)

    def transform_message(self, msg):
        """Project the message along axis 2 and pad it to ``n_fft // 2 + 1``.

        Axes 2 and 3 are swapped so that ``self.linear`` acts on what was
        axis 2 of ``msg``, then swapped back. ``msg`` therefore needs at
        least four axes, with axis 2 of size ``message_dim``.
        """
        projected = self.linear(msg.transpose(2, 3)).transpose(2, 3)
        n_bins = self.n_fft // 2 + 1
        if self.message_band_size == n_bins:
            return projected
        # Pad only the trailing side of the second-to-last axis.
        missing = n_bins - self.message_band_size
        return torch.nn.functional.pad(projected, (0, 0, 0, missing))
class CarrierDecoder(nn.Module):
    """Gated conv stack that maps ``conv_dim`` channels to one channel.

    Args:
        config: Object read in ``forward`` for the boolean attributes
            ``ensure_negative_message`` and ``no_normalization``.
        conv_dim: Input channel count of the first ``Layer``.
        n_layers: Nominal depth. The first and last ``Layer`` are always
            created; ``n_layers - 2`` extra 96->96 layers sit between them.
        message_band_size: Index along axis 2 from which the output is
            zeroed in ``forward``.
    """

    def __init__(self, config, conv_dim, n_layers=4, message_band_size=1024):
        super().__init__()
        self.config = config
        self.message_band_size = message_band_size

        width = 96
        stack = [Layer(dim_in=conv_dim, dim_out=width, kernel_size=3, stride=1, padding=1)]
        stack.extend(
            Layer(dim_in=width, dim_out=width, kernel_size=3, stride=1, padding=1)
            for _ in range(n_layers - 2)
        )
        # The final layer is a 1x1 gated conv that collapses to one channel.
        stack.append(Layer(dim_in=width, dim_out=1, kernel_size=1, stride=1, padding=0))
        self.main = nn.Sequential(*stack)

    def forward(self, x, message_sdr):
        """Decode ``x`` and optionally rescale the result.

        Steps:
          1. Run ``x`` through the conv stack.
          2. If ``config.ensure_negative_message`` is truthy, take the
             absolute value (note: this makes the output non-negative,
             despite the flag's name).
          3. Zero, in place, every index of axis 2 at or beyond
             ``message_band_size``.
          4. Unless ``config.no_normalization`` is truthy, divide by the
             root-mean-square over axis 2 (zeroed rows included) and by
             ``10 ** (message_sdr / 20)``.

        NOTE(review): if the RMS along axis 2 is zero, the division yields
        non-finite values; behaviour is left as in the original.
        ``message_sdr`` is presumably a ratio in dB -- confirm with caller.
        """
        cfg = self.config
        h = self.main(x)
        if cfg.ensure_negative_message:
            h = torch.abs(h)
        h[:, :, self.message_band_size:, :] = 0
        if cfg.no_normalization:
            return h
        rms = torch.mean(h ** 2, dim=2, keepdim=True) ** 0.5
        gain = 10 ** (message_sdr / 20)
        return h / rms / gain
class MsgDecoder(nn.Module):
    """Gated conv stack followed by a linear reduction over axis 2.

    Args:
        message_dim: Output channel count of the last ``Layer``.
        message_band_size: Number of leading indices of axis 2 that are
            kept from the input, and the input feature size of the final
            linear layer. Must not be ``None``.
        channel_dim: Channel width of the intermediate layers.
        num_layers: Nominal depth. The first and last ``Layer`` are always
            created; ``num_layers - 2`` extra layers sit between them.
    """

    def __init__(self, message_dim=0, message_band_size=None, channel_dim=128, num_layers=10):
        super().__init__()
        assert message_band_size is not None
        self.message_band_size = message_band_size

        # (in_channels, out_channels) for each Layer, in construction order.
        channel_plan = [(1, channel_dim)]
        channel_plan += [(channel_dim, channel_dim) for _ in range(num_layers - 2)]
        channel_plan.append((channel_dim, message_dim))

        stages = []
        for c_in, c_out in channel_plan:
            # Dropout(0) is a no-op, but it is kept so the indices inside
            # nn.Sequential (and hence state_dict keys) stay the same.
            stages.append(nn.Dropout(0))
            stages.append(Layer(dim_in=c_in, dim_out=c_out, kernel_size=3, stride=1, padding=1))
        self.main = nn.Sequential(*stages)
        self.linear = nn.Linear(self.message_band_size, 1)

    def forward(self, x):
        """Decode the first ``message_band_size`` rows of axis 2 of ``x``.

        The sliced input goes through the conv stack. Axes 2 and 3 are then
        swapped so ``self.linear`` reduces what was axis 2 to size 1. That
        size-1 axis is dropped and a new axis is inserted at position 1.
        """
        band = x[:, :, :self.message_band_size]
        h = self.main(band)
        reduced = self.linear(h.transpose(2, 3))
        return reduced.squeeze(3).unsqueeze(1)