Spaces:
Runtime error
Runtime error
File size: 9,731 Bytes
d896bd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
import itertools
import os
import torch
import numpy as np
from torch_geometric.data import Dataset
from torch_geometric.data import Data
from torch_geometric.data.collate import collate
import constants
from constants import EdgeTypes
def get_node_labels(s_tensor, ones_idxs):
    """Assign a unique sequential id to every active cell of *s_tensor*.

    *ones_idxs* is the tuple of index tensors produced by
    ``torch.nonzero(s_tensor, as_tuple=True)``. The returned tensor has the
    same shape as *s_tensor* and holds labels 0..n_nodes-1 at the active
    positions; inactive positions keep the value 0.
    """
    device = s_tensor.device
    node_ids = torch.arange(len(ones_idxs[0]), device=device)
    labels = torch.zeros_like(s_tensor, dtype=torch.long, device=device)
    labels[ones_idxs] = node_ids
    return labels
def get_track_edges(s_tensor, ones_idxs=None, node_labels=None):
    """Build bidirectional edges between consecutive notes of each track.

    Returns a long tensor of (u, v, edge_type, ts_distance) rows, where u/v
    are node labels. Each track uses its own edge type
    (``EdgeTypes.TRACK.value + track``); *ones_idxs* and *node_labels* are
    recomputed from *s_tensor* when not supplied.
    """
    if ones_idxs is None:
        # Active positions of the binary structure tensor
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)
    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    collected = []
    for track in range(s_tensor.size(0)):
        # Ordered timesteps at which this track is active
        active_ts = list(ones_idxs[1][ones_idxs[0] == track])
        e_type = EdgeTypes.TRACK.value + track
        # Pair each active timestep with its successor
        forward = [
            (node_labels[track, t1], node_labels[track, t2],
             e_type, t2 - t1)
            for t1, t2 in zip(active_ts, active_ts[1:])
        ]
        # Mirror every edge to make the relation symmetric
        backward = [(v, u, t, d) for (u, v, t, d) in forward]
        collected += forward + backward
    return torch.tensor(collected, dtype=torch.long)
def get_onset_edges(s_tensor, ones_idxs=None, node_labels=None):
    """Build bidirectional edges between notes played at the same timestep.

    Returns a long tensor of (u, v, edge_type, 0) rows connecting every pair
    of tracks that are simultaneously active; *ones_idxs* and *node_labels*
    are recomputed from *s_tensor* when not supplied.
    """
    if ones_idxs is None:
        # Active positions of the binary structure tensor
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)
    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    e_type = EdgeTypes.ONSET.value
    collected = []
    for ts in range(s_tensor.size(1)):
        # Tracks that are active at this timestep
        active_tracks = list(ones_idxs[0][ones_idxs[1] == ts])
        # All unordered pairs of concurrently active tracks; the timestep
        # distance is always 0 for onset edges
        forward = [
            (node_labels[tr1, ts], node_labels[tr2, ts], e_type, 0)
            for tr1, tr2 in itertools.combinations(active_tracks, 2)
        ]
        # Mirror every edge to make the relation symmetric
        backward = [(v, u, t, d) for (u, v, t, d) in forward]
        collected += forward + backward
    return torch.tensor(collected, dtype=torch.long)
def get_next_edges(s_tensor, ones_idxs=None, node_labels=None):
    """Build edges from notes in one active timestep to the next one.

    Returns a long tensor of (u, v, edge_type, ts_distance) rows connecting
    every track active at timestep t1 to every *different* track active at
    the following active timestep t2 (same-track pairs are already covered
    by track edges). *ones_idxs* and *node_labels* are recomputed from
    *s_tensor* when not supplied.
    """
    if ones_idxs is None:
        # Active positions of the binary structure tensor
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)
    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    e_type = EdgeTypes.NEXT.value
    # Timesteps containing at least one activation
    active_ts = torch.nonzero(torch.any(s_tensor.bool(), dim=0)).squeeze()
    if active_ts.dim() == 0:
        # A single active timestep squeezes to a scalar: no successor exists
        return torch.tensor([], dtype=torch.long)

    collected = []
    for t1, t2 in zip(active_ts[:-1], active_ts[1:]):
        # Tracks active at the source and destination timesteps
        src_tracks = ones_idxs[0][ones_idxs[1] == t1]
        dst_tracks = ones_idxs[0][ones_idxs[1] == t2]
        # Cross-product of tracks, skipping same-track combinations
        collected += [
            (node_labels[tr1, t1], node_labels[tr2, t2], e_type, t2 - t1)
            for tr1, tr2 in itertools.product(src_tracks, dst_tracks)
            if tr1 != tr2
        ]
    return torch.tensor(collected, dtype=torch.long)
def get_track_features(s_tensor):
    """Return one-hot track membership features for every node.

    Nodes are the active cells of *s_tensor* in row-major (``torch.nonzero``)
    order; the result is an (n_nodes x n_tracks) float tensor whose row i is
    the one-hot encoding of node i's track.
    """
    active = torch.nonzero(s_tensor)
    n_nodes = active.size(0)
    n_tracks = s_tensor.size(0)
    features = torch.zeros((n_nodes, n_tracks))
    # Column 0 of `active` holds each node's track index
    features[torch.arange(n_nodes), active[:, 0]] = 1
    return features
def graph_from_tensor(s_tensor):
    """Build one torch_geometric graph from a multi-bar structure tensor.

    Constructs a graph per bar (track, onset and next edges over the active
    cells) and collates them into a single ``Data`` object whose per-node
    bar assignment is stored in ``graph.bars``.

    NOTE(review): an empty bar is patched in place (``bar[0, 0] = 1``),
    which mutates the caller's *s_tensor* — `__getitem__` later filters
    c_tensor with the same (mutated) tensor, so the node count stays
    consistent. Keep this side effect in mind when reusing the input.

    Args:
        s_tensor: binary activation tensor; indexing below implies shape
            (n_bars x n_tracks x n_timesteps) — TODO confirm with callers.

    Returns:
        A collated torch_geometric ``Data`` graph with attributes
        ``edge_index``, ``edge_attrs``, ``node_features``, ``is_drum`` and
        ``bars``.
    """
    bars = []
    # Iterate over bars and construct a graph for each bar
    for i in range(s_tensor.size(0)):
        bar = s_tensor[i]
        # If the bar contains no activations, add a fake one to avoid having
        # to deal with empty graphs
        if not torch.any(bar):
            bar[0, 0] = 1
        # Get edges from boolean activations
        track_edges = get_track_edges(bar)
        onset_edges = get_onset_edges(bar)
        next_edges = get_next_edges(bar)
        edges = [track_edges, onset_edges, next_edges]
        # Concatenate edge tensors (N x 4) (if any)
        is_edgeless = (len(track_edges) == 0 and
                       len(onset_edges) == 0 and
                       len(next_edges) == 0)
        if not is_edgeless:
            # Empty edge tensors are 1-D and must be skipped before cat
            edge_list = torch.cat([x for x in edges
                                   if torch.numel(x) > 0])
        # Adapt tensor to torch_geometric's Data
        # If no edges, add fake self-edge
        # edge_list[:, :2] contains source and destination node labels
        # edge_list[:, 2:] contains edge types and timestep distances
        edge_index = (edge_list[:, :2].t().contiguous() if not is_edgeless else
                      torch.LongTensor([[0], [0]]))
        attrs = (edge_list[:, 2:] if not is_edgeless else
                 torch.Tensor([[0, 0]]))
        # Add one hot timestep distance to edge attributes: column 0 keeps
        # the edge type, columns 1..n_timesteps one-hot encode the distance
        # (shifted by +1 so distance 0 maps to column 1)
        edge_attrs = torch.zeros(attrs.size(0), s_tensor.shape[-1] + 1)
        edge_attrs[:, 0] = attrs[:, 0]
        edge_attrs[torch.arange(edge_attrs.size(0)),
                   attrs.long()[:, 1] + 1] = 1
        node_features = get_track_features(bar)
        # Presumably track 0 is the drum track — confirm against data layout
        is_drum = node_features[:, 0].bool()
        num_nodes = torch.sum(bar, dtype=torch.long)
        bars.append(Data(edge_index=edge_index, edge_attrs=edge_attrs,
                         num_nodes=num_nodes, node_features=node_features,
                         is_drum=is_drum).to(s_tensor.device))
    # Merge the graphs corresponding to different bars into a single big graph
    graph, _, _ = collate(
        Data,
        data_list=bars,
        increment=True,
        add_batch=True
    )
    # Change bars assignment vector name (otherwise, Dataloader's collate
    # would overwrite graphs.batch)
    graph.bars = graph.batch
    return graph
class PolyphemusDataset(Dataset):
    """Dataset yielding one torch_geometric graph per .npz sample file.

    Each file in *dir* is expected to contain a ``c_tensor`` (content:
    pitch/duration token indices) and an ``s_tensor`` (binary structure
    activations) — assumed layout (n_tracks x n_timesteps x ...); TODO
    confirm against the preprocessing pipeline.
    """

    def __init__(self, dir, n_bars=2):
        self.dir = dir
        # NOTE(review): os.scandir yields entries in OS-dependent order;
        # sort the list if a reproducible idx -> file mapping is needed.
        self.files = list(os.scandir(self.dir))
        self.len = len(self.files)
        self.n_bars = n_bars

    def __len__(self):
        return self.len

    @staticmethod
    def _to_onehot(values, n_tokens):
        """One-hot encode a tensor of token indices.

        Args:
            values: long tensor of token indices in [0, n_tokens).
            n_tokens: size of the one-hot dimension.

        Returns:
            float32 tensor of shape ``values.shape + (n_tokens,)``.
        """
        flat = values.reshape(-1)
        onehot = torch.zeros((flat.shape[0], n_tokens), dtype=torch.float32)
        onehot[torch.arange(flat.shape[0]), flat] = 1.
        return onehot.reshape(*values.shape, n_tokens)

    def __getitem__(self, idx):
        # Load tensors
        sample_path = os.path.join(self.dir, self.files[idx].name)
        data = np.load(sample_path)
        c_tensor = torch.tensor(data["c_tensor"], dtype=torch.long)
        s_tensor = torch.tensor(data["s_tensor"], dtype=torch.bool)

        # From (n_tracks x n_timesteps x ...)
        # to (n_bars x n_tracks x n_timesteps x ...)
        c_tensor = c_tensor.reshape(c_tensor.shape[0], self.n_bars, -1,
                                    c_tensor.shape[2], c_tensor.shape[3])
        c_tensor = c_tensor.permute(1, 0, 2, 3, 4)
        s_tensor = s_tensor.reshape(s_tensor.shape[0], self.n_bars, -1)
        s_tensor = s_tensor.permute(1, 0, 2)

        # Decimal token indices -> one-hot, for pitches (channel 0) and
        # durations (channel 1), concatenated along the last dimension
        onehot_p = self._to_onehot(c_tensor[..., 0],
                                   constants.N_PITCH_TOKENS)
        onehot_d = self._to_onehot(c_tensor[..., 1],
                                   constants.N_DUR_TOKENS)
        c_tensor = torch.cat((onehot_p, onehot_d), dim=-1)

        # Build graph structure from structure tensor. NOTE: this patches
        # empty bars in place, so s_tensor may gain activations here — the
        # sparse filtering below depends on the patched tensor.
        graph = graph_from_tensor(s_tensor)

        # Filter silences in order to get a sparse representation
        c_tensor = c_tensor.reshape(-1, c_tensor.shape[-2],
                                    c_tensor.shape[-1])
        c_tensor = c_tensor[s_tensor.reshape(-1).bool()]

        graph.c_tensor = c_tensor
        graph.s_tensor = s_tensor.float()
        return graph
|