jordand's picture
Upload 21 files
60cc71a verified
raw
history blame
3.29 kB
import torch
import torch.nn as nn
import numpy as np
class Layer(nn.Module):
    """Gated 2-D convolution followed by batch normalization.

    Computes ``bn(conv(x) * sigmoid(gate(x)))``. ``conv`` and ``gate`` are two
    ``nn.Conv2d`` layers built with identical hyper-parameters, so the sigmoid
    of the gate branch scales the conv branch elementwise.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels (also the BatchNorm2d width).
        kernel_size: Forwarded unchanged to both convolutions.
        stride: Forwarded unchanged to both convolutions.
        padding: Forwarded unchanged to both convolutions.
    """

    def __init__(self, dim_in, dim_out, kernel_size, stride, padding):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
        # Build conv, then gate, then bn: creation order fixes how parameter
        # initialisation consumes the global RNG.
        self.conv = nn.Conv2d(dim_in, dim_out, **conv_options)
        self.gate = nn.Conv2d(dim_in, dim_out, **conv_options)
        self.bn = nn.BatchNorm2d(dim_out)

    def forward(self, x):
        """Apply the gated convolution, then normalize.

        Args:
            x: 4-D ``(batch, dim_in, H, W)`` tensor, as required by
                ``nn.BatchNorm2d`` on the result.

        Returns:
            Tensor with ``dim_out`` channels.
        """
        signal = self.conv(x)
        weight = torch.sigmoid(self.gate(x))
        return self.bn(signal * weight)
class Encoder(nn.Module):
    """Stack of gated convolutions plus a linear projection for messages.

    ``forward`` runs a single-channel input through the convolution stack.
    ``transform_message`` projects a message with ``self.linear`` and pads
    (or trims) dim 2 to ``n_fft // 2 + 1`` rows.

    Args:
        out_dim: Channel count produced by the last Layer.
        n_layers: Requested stack depth. Values below 2 still yield two
            Layer blocks (first and last are always built).
        message_dim: Input size of ``self.linear``.
            NOTE(review): the default of 0 gives an empty projection; callers
            presumably always pass it — confirm.
        message_band_size: Output size of ``self.linear``. Required.
        n_fft: Used only to compute the padded height ``n_fft // 2 + 1``
            (presumably the one-sided STFT bin count — confirm). Required.
    """

    def __init__(self, out_dim=32, n_layers=3, message_dim=0, message_band_size=None, n_fft=None):
        super().__init__()
        assert message_band_size is not None
        assert n_fft is not None
        self.message_band_size = message_band_size
        # Channel widths along the stack: 1 -> 32 -> ... -> 32 -> out_dim.
        # [32] * (n_layers - 2) is empty for n_layers < 2, so at least two
        # Layer blocks are always created.
        widths = [1, 32] + [32] * (n_layers - 2) + [out_dim]
        self.main = nn.Sequential(*(
            Layer(dim_in=c_in, dim_out=c_out, kernel_size=3, stride=1, padding=1)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ))
        self.linear = nn.Linear(message_dim, message_band_size)
        self.n_fft = n_fft

    def forward(self, x):
        """Run ``x`` (1 input channel) through the convolution stack.

        Every Layer uses kernel 3, stride 1, padding 1, so the spatial size is
        preserved; the channel count becomes ``out_dim``.
        """
        return self.main(x)

    def transform_message(self, msg):
        """Project ``msg`` along dim 2 and resize dim 2 to ``n_fft // 2 + 1``.

        Args:
            msg: Tensor of rank >= 4 whose dim 2 has size ``message_dim``.
                Dims 2 and 3 are swapped so ``self.linear`` acts on it.

        Returns:
            Tensor whose dim 2 has ``n_fft // 2 + 1`` rows. The first
            ``message_band_size`` rows hold the projection. Extra rows are
            zero-padded at the end; a negative pad amount trims instead
            (``torch.nn.functional.pad`` semantics).
        """
        target_rows = self.n_fft // 2 + 1
        projected = self.linear(msg.transpose(2, 3)).transpose(2, 3)
        shortfall = target_rows - self.message_band_size
        if shortfall == 0:
            return projected
        # Pad tuple is (last-dim left, last-dim right, dim-2 top, dim-2 bottom).
        return torch.nn.functional.pad(projected, (0, 0, 0, shortfall))
class CarrierDecoder(nn.Module):
    """Gated-convolution stack producing a single-channel map, with optional
    band masking and RMS normalization.

    Args:
        config: Object read for two boolean attributes:
            ``ensure_negative_message`` (when true, ``torch.abs`` is applied
            to the stack output — NOTE(review): this makes values
            non-negative despite the flag's name) and ``no_normalization``
            (when true, the RMS scaling in ``forward`` is skipped).
        conv_dim: Input channel count of the first Layer.
        n_layers: Requested stack depth. Values below 2 still yield two
            Layer blocks (first and last are always built).
        message_band_size: Rows of dim 2 kept in the output; rows from this
            index onward are zeroed.
    """

    def __init__(self, config, conv_dim, n_layers=4, message_band_size=1024):
        super(CarrierDecoder, self).__init__()
        self.config = config
        self.message_band_size = message_band_size
        layers = [Layer(dim_in=conv_dim, dim_out=96, kernel_size=3, stride=1, padding=1)]
        for _ in range(n_layers - 2):
            layers.append(Layer(dim_in=96, dim_out=96, kernel_size=3, stride=1, padding=1))
        # 1x1 projection down to a single output channel.
        layers.append(Layer(dim_in=96, dim_out=1, kernel_size=1, stride=1, padding=0))
        self.main = nn.Sequential(*layers)

    def forward(self, x, message_sdr):
        """Decode ``x`` and optionally rescale it to a target RMS along dim 2.

        Args:
            x: Input with ``conv_dim`` channels (4-D, as required by the
                BatchNorm2d inside each Layer).
            message_sdr: Level in dB (presumably a signal-to-distortion ratio —
                confirm). After normalization, the RMS of each column along
                dim 2 is ``10 ** (-message_sdr / 20)``.

        Returns:
            Single-channel tensor with rows ``message_band_size:`` of dim 2
            set to zero.
        """
        h = self.main(x)
        if self.config.ensure_negative_message:
            h = torch.abs(h)
        h[:, :, self.message_band_size:, :] = 0
        if not self.config.no_normalization:
            # The RMS is taken over all of dim 2, including the zeroed rows.
            mean_square = torch.mean(h ** 2, dim=2, keepdim=True)
            # Floor the mean-square at the smallest normal float. An all-zero
            # column would otherwise give 0/0 = NaN here, and NaN gradients
            # through the square root. Whenever the mean-square is already
            # >= tiny, the result is unchanged.
            mean_square = torch.clamp(mean_square, min=torch.finfo(mean_square.dtype).tiny)
            h = h / mean_square ** 0.5 / (10 ** (message_sdr / 20))
        return h
class MsgDecoder(nn.Module):
    """Gated-convolution stack that reads a message back from dim 2.

    Args:
        message_dim: Channel count of the last Layer (one channel per
            message component). NOTE(review): the default of 0 gives a
            zero-width last layer; callers presumably always pass it — confirm.
        message_band_size: Rows of dim 2 kept from the input; also the input
            size of ``self.linear``. Required.
        channel_dim: Width of the hidden Layers.
        num_layers: Requested stack depth. Values below 2 still yield two
            Layer blocks (first and last are always built).
    """

    def __init__(self, message_dim=0, message_band_size=None, channel_dim=128, num_layers=10):
        super().__init__()
        assert message_band_size is not None
        self.message_band_size = message_band_size
        # Channel widths: 1 -> channel_dim -> ... -> channel_dim -> message_dim.
        widths = [1, channel_dim] + [channel_dim] * (num_layers - 2) + [message_dim]
        stack = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            # Dropout(0) is a no-op, but each one occupies a Sequential slot.
            # Removing them would renumber the parameter keys under "main".
            stack.extend((
                nn.Dropout(0),
                Layer(dim_in=c_in, dim_out=c_out, kernel_size=3, stride=1, padding=1),
            ))
        self.main = nn.Sequential(*stack)
        self.linear = nn.Linear(self.message_band_size, 1)

    def forward(self, x):
        """Decode a message from the first ``message_band_size`` rows of dim 2.

        Args:
            x: 4-D ``(batch, 1, rows, T)`` tensor with
                ``rows >= message_band_size``.

        Returns:
            ``(batch, 1, message_dim, T)`` tensor. ``self.linear`` collapses
            the kept rows of dim 2 to a single value.
        """
        features = self.main(x[:, :, :self.message_band_size])
        # Move dim 2 last so the linear layer reduces it: (..., T, band) -> (..., T, 1).
        reduced = self.linear(features.transpose(2, 3))
        return reduced.squeeze(3).unsqueeze(1)