from typing import Union, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch_sparse import SparseTensor, masked_select_nnz
from torch_geometric.typing import OptTensor, Adj
from torch_geometric.nn.inits import reset
from torch_geometric.nn.norm import BatchNorm
from torch_geometric.nn.glob import GlobalAttention
from torch_geometric.data import Batch
from torch_geometric.nn.conv import RGCNConv

import constants
from data import graph_from_tensor
@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (Tensor, Tensor) -> Tensor
    pass


@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (SparseTensor, Tensor) -> SparseTensor
    pass


def masked_edge_index(edge_index, edge_mask):
    # Select the edges belonging to one relation, for both dense and sparse
    # adjacency representations.
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]
    else:
        return masked_select_nnz(edge_index, edge_mask, layout='coo')


def masked_edge_attrs(edge_attrs, edge_mask):
    # Select the attribute rows of the masked edges.
    return edge_attrs[edge_mask, :]
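
# Minimal usage sketch (not part of the original module): masking a dense
# (2 x E) edge_index and its per-edge attribute rows down to the edges of a
# single relation. All values here are illustrative.
def _demo_edge_masking():
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    edge_attrs = torch.eye(3)
    mask = torch.tensor([True, False, True])
    assert masked_edge_index(edge_index, mask).size(1) == 2
    assert masked_edge_attrs(edge_attrs, mask).size(0) == 2
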
class GCL(RGCNConv):
    # Relational graph convolution layer in which each message is modulated
    # by an edge network (nn) applied to the edge attributes.

    def __init__(self, in_channels, out_channels, num_relations, nn,
                 dropout=0.1, **kwargs):
        super().__init__(in_channels=in_channels, out_channels=out_channels,
                         num_relations=num_relations, **kwargs)
        self.nn = nn
        self.dropout = dropout
        self.reset_edge_nn()

    def reset_edge_nn(self):
        reset(self.nn)

    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
                edge_index: Adj, edge_type: OptTensor = None,
                edge_attr: OptTensor = None):
        # Convert input features to a pair of node features or node indices.
        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x
        if x_l is None:
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))

        if isinstance(edge_index, SparseTensor):
            edge_type = edge_index.storage.value()
        assert edge_type is not None

        # propagate_type: (x: Tensor, edge_attr: OptTensor)
        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        if self.num_bases is not None:  # Basis-decomposition
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels)

        if self.num_blocks is not None:  # Block-diagonal-decomposition
            if x_l.dtype == torch.long:
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')
            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                attr = masked_edge_attrs(edge_attr, edge_type == i)
                h = self.propagate(tmp, x=x_l, size=size, edge_attr=attr)
                h = h.view(-1, weight.size(1), weight.size(2))
                h = torch.einsum('abc,bcd->abd', h, weight[i])
                out += h.contiguous().view(-1, self.out_channels)
        else:  # No regularization/Basis-decomposition
            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                attr = masked_edge_attrs(edge_attr, edge_type == i)
                if x_l.dtype == torch.long:
                    out += self.propagate(tmp, x=weight[i, x_l], size=size,
                                          edge_attr=attr)
                else:
                    h = self.propagate(tmp, x=x_l, size=size, edge_attr=attr)
                    out = out + (h @ weight[i])

        root = self.root
        if root is not None:
            out += root[x_r] if x_r.dtype == torch.long else x_r @ root

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        # Use the edge network to compute a modulation vector from the edge
        # attributes (one-hot timestep distances between nodes).
        weights = self.nn(edge_attr)
        weights = weights[..., :self.in_channels_l]
        weights = weights.view(-1, self.in_channels_l)
        out = x_j * weights
        out = F.relu(out)
        out = F.dropout(out, p=self.dropout, training=self.training)
        return out
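
# Minimal usage sketch (not part of the original module): a single GCL over a
# toy graph with 3 relations and 32 possible one-hot distances. All sizes are
# illustrative and assume no basis/block decomposition.
def _demo_gcl():
    edge_nn = nn.Linear(32, 16)
    conv = GCL(in_channels=16, out_channels=16, num_relations=3, nn=edge_nn)
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
    edge_type = torch.tensor([0, 1, 2])
    edge_attr = F.one_hot(torch.tensor([3, 1, 0]), num_classes=32).float()
    out = conv(x, edge_index, edge_type, edge_attr)
    assert out.shape == (4, 16)
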
class MLP(nn.Module):

    def __init__(self, input_dim=256, hidden_dim=256, output_dim=256,
                 num_layers=2, activation=True, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        if num_layers == 1:
            self.layers.append(nn.Linear(input_dim, output_dim))
        else:
            # Input layer (1) + intermediate layers (n-2) + output layer (1)
            self.layers.append(nn.Linear(input_dim, hidden_dim))
            for _ in range(num_layers - 2):
                self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.layers.append(nn.Linear(hidden_dim, output_dim))
        self.activation = activation
        self.p = dropout

    def forward(self, x):
        for layer in self.layers:
            x = F.dropout(x, p=self.p, training=self.training)
            x = layer(x)
            if self.activation:
                x = F.relu(x)
        return x
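
# Minimal usage sketch (not part of the original module): a 2-layer MLP
# mapping 256-d vectors to 128-d vectors. Dimensions are illustrative.
def _demo_mlp():
    mlp = MLP(input_dim=256, hidden_dim=256, output_dim=128, num_layers=2)
    x = torch.randn(4, 256)
    assert mlp(x).shape == (4, 128)
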
class GCN(nn.Module):

    def __init__(self, input_dim=256, hidden_dim=256, n_layers=3,
                 num_relations=3, num_dists=32, batch_norm=False, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        # The edge network is shared across all convolution layers
        edge_nn = nn.Linear(num_dists, input_dim)
        self.batch_norm = batch_norm
        self.layers.append(GCL(input_dim, hidden_dim, num_relations, edge_nn))
        if self.batch_norm:
            self.norm_layers.append(BatchNorm(hidden_dim))
        for _ in range(n_layers - 1):
            self.layers.append(GCL(hidden_dim, hidden_dim,
                                   num_relations, edge_nn))
            if self.batch_norm:
                self.norm_layers.append(BatchNorm(hidden_dim))
        self.p = dropout

    def forward(self, data):
        x, edge_index, edge_attrs = data.x, data.edge_index, data.edge_attrs
        # The first attribute column holds the relation type, the remaining
        # columns the one-hot timestep distances
        edge_type = edge_attrs[:, 0]
        edge_attr = edge_attrs[:, 1:]
        for i in range(len(self.layers)):
            # Residual connections require input_dim == hidden_dim
            residual = x
            x = F.dropout(x, p=self.p, training=self.training)
            x = self.layers[i](x, edge_index, edge_type, edge_attr)
            if self.batch_norm:
                x = self.norm_layers[i](x)
            x = F.relu(x)
            x = residual + x
        return x
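
# Minimal usage sketch (not part of the original module): running the GCN on
# a toy torch_geometric Data object. As in GCN.forward, the first column of
# edge_attrs encodes the relation type and the remaining columns the one-hot
# distance. All sizes are illustrative.
def _demo_gcn():
    from torch_geometric.data import Data
    num_dists = 32
    gcn = GCN(input_dim=64, hidden_dim=64, n_layers=2,
              num_relations=3, num_dists=num_dists)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
    edge_type = torch.randint(0, 3, (4, 1)).float()
    dists = F.one_hot(torch.randint(0, num_dists, (4,)), num_dists).float()
    data = Data(x=torch.randn(5, 64), edge_index=edge_index,
                edge_attrs=torch.cat((edge_type, dists), dim=1))
    assert gcn(data).shape == (5, 64)
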
class CNNEncoder(nn.Module):

    def __init__(self, output_dim=256, dense_dim=256, batch_norm=False,
                 dropout=0.1):
        super().__init__()
        # Convolutional layers
        if batch_norm:
            self.conv = nn.Sequential(
                # From (4 x 32) to (8 x 4 x 32)
                nn.Conv2d(1, 8, 3, padding=1),
                nn.BatchNorm2d(8),
                nn.ReLU(True),
                # From (8 x 4 x 32) to (8 x 4 x 8)
                nn.MaxPool2d((1, 4), stride=(1, 4)),
                # From (8 x 4 x 8) to (16 x 4 x 8)
                nn.Conv2d(8, 16, 3, padding=1),
                nn.BatchNorm2d(16),
                nn.ReLU(True)
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(1, 8, 3, padding=1),
                nn.ReLU(True),
                nn.MaxPool2d((1, 4), stride=(1, 4)),
                nn.Conv2d(8, 16, 3, padding=1),
                nn.ReLU(True)
            )
        self.flatten = nn.Flatten(start_dim=1)
        # Linear layers
        self.lin = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(16 * 4 * 8, dense_dim),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Linear(dense_dim, output_dim)
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.conv(x)
        x = self.flatten(x)
        x = self.lin(x)
        return x
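
# Minimal usage sketch (not part of the original module): encoding a batch of
# bar matrices shaped (n_tracks=4, n_timesteps=32), the shape the
# convolutions above are hardcoded for.
def _demo_cnn_encoder():
    enc = CNNEncoder(output_dim=128, dense_dim=128)
    x = torch.randn(2, 4, 32)
    assert enc(x).shape == (2, 128)
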
class CNNDecoder(nn.Module):

    def __init__(self, input_dim=256, dense_dim=256, batch_norm=False,
                 dropout=0.1):
        super().__init__()
        # Linear decompressors
        self.lin = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_dim, dense_dim),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Linear(dense_dim, 16 * 4 * 8),
            nn.ReLU(True)
        )
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(16, 4, 8))
        # Upsample and convolutional layers
        if batch_norm:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=(1, 4), mode='nearest'),
                nn.Conv2d(16, 8, 3, padding=1),
                nn.BatchNorm2d(8),
                nn.ReLU(True),
                nn.Conv2d(8, 1, 3, padding=1)
            )
        else:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=(1, 4), mode='nearest'),
                nn.Conv2d(16, 8, 3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(8, 1, 3, padding=1)
            )

    def forward(self, x):
        x = self.lin(x)
        x = self.unflatten(x)
        x = self.conv(x)
        # Drop the singleton channel dimension to mirror the encoder input
        x = x.squeeze(1)
        return x
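
# Minimal usage sketch (not part of the original module): the decoder mirrors
# CNNEncoder, mapping a latent vector back to a (4 x 32) bar matrix.
def _demo_cnn_decoder():
    dec = CNNDecoder(input_dim=128, dense_dim=128)
    z = torch.randn(2, 128)
    assert dec(z).shape == (2, 4, 32)
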
class ContentEncoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)
        self.dropout_layer = nn.Dropout(p=self.dropout)
        # Pitch and duration embedding layers (separate layers for drums
        # and non-drums)
        self.non_drums_pitch_emb = nn.Linear(constants.N_PITCH_TOKENS,
                                             self.d//2)
        self.drums_pitch_emb = nn.Linear(constants.N_PITCH_TOKENS, self.d//2)
        self.dur_emb = nn.Linear(constants.N_DUR_TOKENS, self.d//2)
        # Batch norm layers
        self.bn_non_drums = nn.BatchNorm1d(num_features=self.d//2)
        self.bn_drums = nn.BatchNorm1d(num_features=self.d//2)
        self.bn_dur = nn.BatchNorm1d(num_features=self.d//2)
        self.chord_encoder = nn.Linear(
            self.d * (constants.MAX_SIMU_TOKENS-1), self.d)
        self.graph_encoder = GCN(
            dropout=self.dropout,
            input_dim=self.d,
            hidden_dim=self.d,
            n_layers=self.gnn_n_layers,
            num_relations=constants.N_EDGE_TYPES,
            batch_norm=self.batch_norm
        )
        # Soft attention node-aggregation layer
        gate_nn = nn.Sequential(
            MLP(input_dim=self.d, output_dim=1, num_layers=1,
                activation=False, dropout=self.dropout),
            nn.BatchNorm1d(1)
        )
        self.graph_attention = GlobalAttention(gate_nn)
        self.bars_encoder = nn.Linear(self.n_bars * self.d, self.d)

    def forward(self, graph):
        c_tensor = graph.c_tensor
        # Discard the SOS token
        c_tensor = c_tensor[:, 1:, :]

        # Get drums and non-drums tensors
        drums = c_tensor[graph.is_drum]
        non_drums = c_tensor[torch.logical_not(graph.is_drum)]

        # Compute drums embeddings
        sz = drums.size()
        drums_pitch = self.drums_pitch_emb(
            drums[..., :constants.N_PITCH_TOKENS])
        drums_pitch = self.bn_drums(drums_pitch.view(-1, self.d//2))
        drums_pitch = drums_pitch.view(sz[0], sz[1], self.d//2)
        drums_dur = self.dur_emb(drums[..., constants.N_PITCH_TOKENS:])
        drums_dur = self.bn_dur(drums_dur.view(-1, self.d//2))
        drums_dur = drums_dur.view(sz[0], sz[1], self.d//2)
        drums = torch.cat((drums_pitch, drums_dur), dim=-1)
        # n_nodes_drums x (MAX_SIMU_TOKENS-1) x d

        # Compute non-drums embeddings
        sz = non_drums.size()
        non_drums_pitch = self.non_drums_pitch_emb(
            non_drums[..., :constants.N_PITCH_TOKENS]
        )
        non_drums_pitch = self.bn_non_drums(
            non_drums_pitch.view(-1, self.d//2))
        non_drums_pitch = non_drums_pitch.view(sz[0], sz[1], self.d//2)
        non_drums_dur = self.dur_emb(non_drums[..., constants.N_PITCH_TOKENS:])
        non_drums_dur = self.bn_dur(non_drums_dur.view(-1, self.d//2))
        non_drums_dur = non_drums_dur.view(sz[0], sz[1], self.d//2)
        non_drums = torch.cat((non_drums_pitch, non_drums_dur), dim=-1)
        # n_nodes_non_drums x (MAX_SIMU_TOKENS-1) x d

        # Compute chord embeddings (drums and non-drums)
        drums = self.chord_encoder(
            drums.view(-1, self.d * (constants.MAX_SIMU_TOKENS-1))
        )
        non_drums = self.chord_encoder(
            non_drums.view(-1, self.d * (constants.MAX_SIMU_TOKENS-1))
        )
        drums = F.relu(drums)
        non_drums = F.relu(non_drums)
        drums = self.dropout_layer(drums)
        non_drums = self.dropout_layer(non_drums)
        # n_nodes x d

        # Merge drums and non-drums
        out = torch.zeros((c_tensor.size(0), self.d), device=self.device,
                          dtype=drums.dtype)
        out[graph.is_drum] = drums
        out[torch.logical_not(graph.is_drum)] = non_drums
        # n_nodes x d

        # Set initial graph node states to the intermediate chord
        # representations and pass them through the GCN. distinct_bars gives
        # every bar of every sequence in the batch a unique id.
        graph.x = out
        graph.distinct_bars = graph.bars + self.n_bars*graph.batch
        out = self.graph_encoder(graph)
        # n_nodes x d

        # Aggregate final node states into bar encodings with soft attention
        with torch.cuda.amp.autocast(enabled=False):
            out = self.graph_attention(out, batch=graph.distinct_bars)
        # (bs*n_bars) x d
        out = out.view(-1, self.n_bars * self.d)
        # bs x (n_bars*d)
        z_c = self.bars_encoder(out)
        # bs x d
        return z_c
class StructureEncoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)
        self.cnn_encoder = CNNEncoder(
            dense_dim=self.d,
            output_dim=self.d,
            dropout=self.dropout,
            batch_norm=self.batch_norm
        )
        self.bars_encoder = nn.Linear(self.n_bars * self.d, self.d)

    def forward(self, graph):
        s_tensor = graph.s_tensor
        out = self.cnn_encoder(s_tensor.view(-1, constants.N_TRACKS,
                                             self.resolution * 4))
        # (bs*n_bars) x d
        out = out.view(-1, self.n_bars * self.d)
        # bs x (n_bars*d)
        z_s = self.bars_encoder(out)
        # bs x d
        return z_s
class Encoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)
        self.s_encoder = StructureEncoder(**kwargs)
        self.c_encoder = ContentEncoder(**kwargs)
        self.dropout_layer = nn.Dropout(p=self.dropout)
        # Linear layer that merges content and structure representations
        self.linear_merge = nn.Linear(2*self.d, self.d)
        self.bn_linear_merge = nn.BatchNorm1d(num_features=self.d)
        self.linear_mu = nn.Linear(self.d, self.d)
        self.linear_log_var = nn.Linear(self.d, self.d)

    def forward(self, graph):
        z_s = self.s_encoder(graph)
        z_c = self.c_encoder(graph)

        # Merge content and structure representations
        z_g = torch.cat((z_c, z_s), dim=1)
        z_g = self.dropout_layer(z_g)
        z_g = self.linear_merge(z_g)
        z_g = self.bn_linear_merge(z_g)
        z_g = F.relu(z_g)

        # Compute mu and log(std^2)
        z_g = self.dropout_layer(z_g)
        mu = self.linear_mu(z_g)
        log_var = self.linear_log_var(z_g)

        return mu, log_var
class StructureDecoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)
        self.bars_decoder = nn.Linear(self.d, self.d * self.n_bars)
        self.cnn_decoder = CNNDecoder(
            input_dim=self.d,
            dense_dim=self.d,
            dropout=self.dropout,
            batch_norm=self.batch_norm
        )

    def forward(self, z_s):
        # z_s: bs x d
        out = self.bars_decoder(z_s)  # bs x (n_bars*d)
        out = self.cnn_decoder(out.reshape(-1, self.d))
        out = out.view(z_s.size(0), self.n_bars, constants.N_TRACKS, -1)
        return out
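
# Minimal usage sketch (not part of the original module): decoding two latent
# structure vectors into per-bar, per-track activation logits. d and n_bars
# are illustrative; the track count comes from constants, and the timestep
# size from the CNN architecture above.
def _demo_structure_decoder():
    dec = StructureDecoder(d=128, n_bars=2, dropout=0.1, batch_norm=False)
    z_s = torch.randn(2, 128)
    out = dec(z_s)
    assert out.shape[:3] == (2, 2, constants.N_TRACKS)
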
class ContentDecoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)
        self.bars_decoder = nn.Linear(self.d, self.d * self.n_bars)
        self.graph_decoder = GCN(
            dropout=self.dropout,
            input_dim=self.d,
            hidden_dim=self.d,
            n_layers=self.gnn_n_layers,
            num_relations=constants.N_EDGE_TYPES,
            batch_norm=self.batch_norm
        )
        self.chord_decoder = nn.Linear(
            self.d, self.d*(constants.MAX_SIMU_TOKENS-1))
        # Pitch and duration (un)embedding linear layers
        self.drums_pitch_emb = nn.Linear(self.d//2, constants.N_PITCH_TOKENS)
        self.non_drums_pitch_emb = nn.Linear(
            self.d//2, constants.N_PITCH_TOKENS)
        self.dur_emb = nn.Linear(self.d//2, constants.N_DUR_TOKENS)
        self.dropout_layer = nn.Dropout(p=self.dropout)

    def forward(self, z_c, s):
        out = self.bars_decoder(z_c)  # bs x (n_bars*d)

        # Initialize each node's features with the z_bar of the bar it
        # belongs to and propagate with the GNN
        s.distinct_bars = s.bars + self.n_bars*s.batch
        _, counts = torch.unique(s.distinct_bars, return_counts=True)
        out = out.view(-1, self.d)
        out = torch.repeat_interleave(out, counts, dim=0)  # n_nodes x d
        s.x = out
        out = self.graph_decoder(s)  # n_nodes x d

        out = self.chord_decoder(out)  # n_nodes x ((MAX_SIMU_TOKENS-1)*d)
        out = out.view(-1, constants.MAX_SIMU_TOKENS-1, self.d)

        drums = out[s.is_drum]  # n_nodes_drums x (MAX_SIMU_TOKENS-1) x d
        non_drums = out[torch.logical_not(s.is_drum)]
        # n_nodes_non_drums x (MAX_SIMU_TOKENS-1) x d

        # Obtain final pitch and dur logits (softmax will be applied
        # outside forward)
        non_drums = self.dropout_layer(non_drums)
        drums = self.dropout_layer(drums)
        drums_pitch = self.drums_pitch_emb(drums[..., :self.d//2])
        drums_dur = self.dur_emb(drums[..., self.d//2:])
        drums = torch.cat((drums_pitch, drums_dur), dim=-1)
        # n_nodes_drums x (MAX_SIMU_TOKENS-1) x d_token

        non_drums_pitch = self.non_drums_pitch_emb(non_drums[..., :self.d//2])
        non_drums_dur = self.dur_emb(non_drums[..., self.d//2:])
        non_drums = torch.cat((non_drums_pitch, non_drums_dur), dim=-1)
        # n_nodes_non_drums x (MAX_SIMU_TOKENS-1) x d_token

        # Merge drums and non-drums in the final output tensor
        d_token = constants.D_TOKEN_PAIR
        out = torch.zeros((s.num_nodes, constants.MAX_SIMU_TOKENS-1, d_token),
                          device=self.device, dtype=drums.dtype)
        out[s.is_drum] = drums
        out[torch.logical_not(s.is_drum)] = non_drums

        return out
class Decoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)
        self.lin_decoder = nn.Linear(self.d, 2 * self.d)
        self.batch_norm = nn.BatchNorm1d(num_features=2*self.d)
        self.dropout = nn.Dropout(p=self.dropout)
        self.s_decoder = StructureDecoder(**kwargs)
        self.c_decoder = ContentDecoder(**kwargs)
        self.sigmoid_thresh = 0.5

    def _structure_from_binary(self, s_tensor):
        # Create a graph structure for each sequence in the batch
        s = []
        for i in range(s_tensor.size(0)):
            s.append(graph_from_tensor(s_tensor[i]))
        # Create a batch of graphs from the single graphs
        s = Batch.from_data_list(s, exclude_keys=['batch'])
        s = s.to(next(self.parameters()).device)
        return s

    def _binary_from_logits(self, s_logits):
        # A hard threshold instead of sampling gives more pleasant results
        s_tensor = torch.sigmoid(s_logits)
        s_tensor[s_tensor >= self.sigmoid_thresh] = 1
        s_tensor[s_tensor < self.sigmoid_thresh] = 0
        s_tensor = s_tensor.bool()
        # Avoid empty bars by creating a fake activation in position [0, 0]
        # of each empty (n_tracks x n_timesteps) bar matrix
        empty_mask = ~s_tensor.any(dim=-1).any(dim=-1)
        idxs = torch.nonzero(empty_mask, as_tuple=True)
        s_tensor[idxs + (0, 0)] = True
        return s_tensor

    def _structure_from_logits(self, s_logits):
        # Compute the binary structure tensor from the logits and build the
        # torch geometric structure from the binary tensor
        s_tensor = self._binary_from_logits(s_logits)
        s = self._structure_from_binary(s_tensor)
        return s

    def forward(self, z, s=None):
        # Obtain z_s and z_c from z
        z = self.lin_decoder(z)
        z = self.batch_norm(z)
        z = F.relu(z)
        z = self.dropout(z)  # bs x (2*d)
        z_s, z_c = z[:, :self.d], z[:, self.d:]

        # Obtain the tensor containing structure logits
        s_logits = self.s_decoder(z_s)

        if s is None:
            # Build the torch geometric graph structure from the structure
            # logits. This step involves non-differentiable operations, so
            # no gradients flow through it.
            s = self._structure_from_logits(s_logits.detach())

        # Obtain the tensor containing content logits
        c_logits = self.c_decoder(z_c, s)

        return s_logits, c_logits
class VAE(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.encoder = Encoder(**kwargs)
        self.decoder = Decoder(**kwargs)

    def forward(self, graph):
        # Encoder pass
        mu, log_var = self.encoder(graph)

        # Reparameterization trick: z = mu + std * eps,
        # with std = exp(log_var / 2) and eps ~ N(0, I)
        z = torch.exp(0.5 * log_var)
        z = z * torch.randn_like(z)
        z = z + mu

        # Decoder pass
        out = self.decoder(z, graph)

        return out, mu, log_var
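
# Minimal sampling sketch (not part of the original module): draw a latent
# code from the standard normal prior and decode it. With s=None, the decoder
# infers the graph structure from its own structure logits. `model` is
# assumed to be a trained VAE on CPU whose latent size is d.
def _demo_vae_sampling(model, d=256, batch_size=1):
    model.eval()
    with torch.no_grad():
        z = torch.randn(batch_size, d)
        s_logits, c_logits = model.decoder(z)
    return s_logits, c_logits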