import itertools
import os

import torch
import numpy as np
from torch_geometric.data import Data, Dataset
from torch_geometric.data.collate import collate

import constants
from constants import EdgeTypes


def get_node_labels(s_tensor, ones_idxs):
    # Build a tensor which has node labels in place of each activation in the
    # structure tensor
    labels = torch.zeros_like(s_tensor, dtype=torch.long,
                              device=s_tensor.device)
    n_nodes = len(ones_idxs[0])
    labels[ones_idxs] = torch.arange(n_nodes, device=s_tensor.device)
    return labels


def get_track_edges(s_tensor, ones_idxs=None, node_labels=None):
    track_edges = []

    if ones_idxs is None:
        # Indices where the binary structure tensor is active
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)
    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    # For each track, add direct and inverse edges between consecutive nodes
    for track in range(s_tensor.size(0)):
        # List of active timesteps in the current track
        tss = list(ones_idxs[1][ones_idxs[0] == track])
        edge_type = EdgeTypes.TRACK.value + track
        edges = [
            # Edge tuple: (u, v, type, ts_distance). Zip is used to obtain
            # consecutive active timesteps. Edges in different tracks have
            # different types.
            (node_labels[track, t1],
             node_labels[track, t2], edge_type, t2 - t1)
            for t1, t2 in zip(tss[:-1], tss[1:])
        ]
        inverse_edges = [(u, v, t, d) for (v, u, t, d) in edges]
        track_edges.extend(edges + inverse_edges)

    return torch.tensor(track_edges, dtype=torch.long)


def get_onset_edges(s_tensor, ones_idxs=None, node_labels=None):
    onset_edges = []
    edge_type = EdgeTypes.ONSET.value

    if ones_idxs is None:
        # Indices where the binary structure tensor is active
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)
    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    # Add direct and inverse edges between nodes played in the same timestep
    for ts in range(s_tensor.size(1)):
        # List of active tracks in the current timestep
        tracks = list(ones_idxs[0][ones_idxs[1] == ts])
        # Obtain all possible pairwise combinations of active tracks
        combinations = list(itertools.combinations(tracks, 2))
        edges = [
            # Edge tuple: (u, v, type, ts_distance(=0)).
            (node_labels[track1, ts], node_labels[track2, ts], edge_type, 0)
            for track1, track2 in combinations
        ]
        inverse_edges = [(u, v, t, d) for (v, u, t, d) in edges]
        onset_edges.extend(edges + inverse_edges)

    return torch.tensor(onset_edges, dtype=torch.long)


def get_next_edges(s_tensor, ones_idxs=None, node_labels=None):
    next_edges = []
    edge_type = EdgeTypes.NEXT.value

    if ones_idxs is None:
        # Indices where the binary structure tensor is active
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)
    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    # List of active timesteps
    tss = torch.nonzero(torch.any(s_tensor.bool(), dim=0)).squeeze()
    if tss.dim() == 0:
        return torch.tensor([], dtype=torch.long)

    for i in range(tss.size(0) - 1):
        # Get consecutive active timesteps
        t1, t2 = tss[i], tss[i + 1]
        # Get all the active tracks in the two timesteps
        t1_tracks = ones_idxs[0][ones_idxs[1] == t1]
        t2_tracks = ones_idxs[0][ones_idxs[1] == t2]
        # Combine the source and destination tracks, removing combinations
        # with the same source and destination track (since these represent
        # track edges).
        tracks_product = list(itertools.product(t1_tracks, t2_tracks))
        tracks_product = [(track1, track2)
                          for (track1, track2) in tracks_product
                          if track1 != track2]
        # Edge tuple: (u, v, type, ts_distance).
        edges = [(node_labels[track1, t1], node_labels[track2, t2],
                  edge_type, t2 - t1)
                 for track1, track2 in tracks_product]
        next_edges.extend(edges)

    return torch.tensor(next_edges, dtype=torch.long)


def get_track_features(s_tensor):
    # Indices where the binary structure tensor is active
    ones_idxs = torch.nonzero(s_tensor)
    n_nodes = len(ones_idxs)
    tracks = ones_idxs[:, 0]
    n_tracks = s_tensor.size(0)
    # The n_nodes x n_tracks feature tensor contains a one-hot track
    # representation for each node
    features = torch.zeros((n_nodes, n_tracks))
    features[torch.arange(n_nodes), tracks] = 1
    return features


def graph_from_tensor(s_tensor):
    bars = []

    # Iterate over bars and construct a graph for each bar
    for i in range(s_tensor.size(0)):

        bar = s_tensor[i]

        # If the bar contains no activations, add a fake one to avoid having
        # to deal with empty graphs
        if not torch.any(bar):
            bar[0, 0] = 1

        # Get edges from boolean activations
        track_edges = get_track_edges(bar)
        onset_edges = get_onset_edges(bar)
        next_edges = get_next_edges(bar)
        edges = [track_edges, onset_edges, next_edges]

        # Concatenate edge tensors (N x 4), if any
        is_edgeless = (len(track_edges) == 0 and
                       len(onset_edges) == 0 and
                       len(next_edges) == 0)
        if not is_edgeless:
            edge_list = torch.cat([x for x in edges
                                   if torch.numel(x) > 0])

        # Adapt tensors to torch_geometric's Data.
        # If there are no edges, add a fake self-edge.
        # edge_list[:, :2] contains source and destination node labels,
        # edge_list[:, 2:] contains edge types and timestep distances.
        edge_index = (edge_list[:, :2].t().contiguous() if not is_edgeless else
                      torch.LongTensor([[0], [0]]))
        attrs = (edge_list[:, 2:] if not is_edgeless else
                 torch.Tensor([[0, 0]]))

        # Add one-hot timestep distance to edge attributes
        edge_attrs = torch.zeros(attrs.size(0), s_tensor.shape[-1] + 1)
        edge_attrs[:, 0] = attrs[:, 0]
        edge_attrs[torch.arange(edge_attrs.size(0)),
                   attrs.long()[:, 1] + 1] = 1

        node_features = get_track_features(bar)
        # Nodes belonging to the first track are flagged as drum nodes
        is_drum = node_features[:, 0].bool()
        num_nodes = torch.sum(bar, dtype=torch.long)

        bars.append(Data(edge_index=edge_index, edge_attrs=edge_attrs,
                         num_nodes=num_nodes, node_features=node_features,
                         is_drum=is_drum).to(s_tensor.device))

    # Merge the graphs corresponding to different bars into a single big graph
    graph, _, _ = collate(
        Data,
        data_list=bars,
        increment=True,
        add_batch=True
    )

    # Change the name of the bar assignment vector (otherwise, the
    # DataLoader's collate would overwrite graph.batch)
    graph.bars = graph.batch

    return graph
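

def _demo_graph_from_tensor():
    # Minimal usage sketch: build a toy binary structure tensor of shape
    # (n_bars, n_tracks, n_timesteps) and turn it into a graph with
    # graph_from_tensor. The shapes and the activation pattern below are
    # arbitrary, chosen only for illustration.
    n_bars, n_tracks, n_timesteps = 2, 4, 8
    s_tensor = torch.zeros((n_bars, n_tracks, n_timesteps), dtype=torch.bool)

    # Activate a few (track, timestep) cells in the first bar; the second bar
    # is left empty, so graph_from_tensor will insert a fake activation in it
    s_tensor[0, 0, 0] = True
    s_tensor[0, 0, 4] = True
    s_tensor[0, 1, 4] = True

    graph = graph_from_tensor(s_tensor)

    # One node per activation; graph.bars assigns each node to its bar
    print(graph)
    print(graph.bars)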


class PolyphemusDataset(Dataset):

    def __init__(self, dir, n_bars=2):
        self.dir = dir
        self.files = list(os.scandir(self.dir))
        self.len = len(self.files)
        self.n_bars = n_bars

    def __len__(self):
        return self.len

    def __getitem__(self, idx):

        # Load tensors
        sample_path = os.path.join(self.dir, self.files[idx].name)
        data = np.load(sample_path)
        c_tensor = torch.tensor(data["c_tensor"], dtype=torch.long)
        s_tensor = torch.tensor(data["s_tensor"], dtype=torch.bool)

        # From (n_tracks x n_timesteps x ...)
        # to (n_bars x n_tracks x n_timesteps x ...)
        c_tensor = c_tensor.reshape(c_tensor.shape[0], self.n_bars, -1,
                                    c_tensor.shape[2], c_tensor.shape[3])
        c_tensor = c_tensor.permute(1, 0, 2, 3, 4)
        s_tensor = s_tensor.reshape(s_tensor.shape[0], self.n_bars, -1)
        s_tensor = s_tensor.permute(1, 0, 2)

        # From integer token indices to one-hot (pitches)
        pitches = c_tensor[..., 0]
        onehot_p = torch.zeros(
            (pitches.shape[0] * pitches.shape[1] *
             pitches.shape[2] * pitches.shape[3],
             constants.N_PITCH_TOKENS),
            dtype=torch.float32
        )
        onehot_p[torch.arange(0, onehot_p.shape[0]), pitches.reshape(-1)] = 1.
        onehot_p = onehot_p.reshape(pitches.shape[0], pitches.shape[1],
                                    pitches.shape[2], pitches.shape[3],
                                    constants.N_PITCH_TOKENS)

        # From integer token indices to one-hot (durations)
        durs = c_tensor[..., 1]
        onehot_d = torch.zeros(
            (durs.shape[0] * durs.shape[1] *
             durs.shape[2] * durs.shape[3],
             constants.N_DUR_TOKENS),
            dtype=torch.float32
        )
        onehot_d[torch.arange(0, onehot_d.shape[0]), durs.reshape(-1)] = 1.
        onehot_d = onehot_d.reshape(durs.shape[0], durs.shape[1],
                                    durs.shape[2], durs.shape[3],
                                    constants.N_DUR_TOKENS)

        # Concatenate pitches and durations
        c_tensor = torch.cat((onehot_p, onehot_d), dim=-1)

        # Build the graph from the structure tensor
        graph = graph_from_tensor(s_tensor)

        # Filter silences in order to get a sparse representation
        c_tensor = c_tensor.reshape(-1, c_tensor.shape[-2], c_tensor.shape[-1])
        c_tensor = c_tensor[s_tensor.reshape(-1).bool()]

        graph.c_tensor = c_tensor
        graph.s_tensor = s_tensor.float()

        return graph
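

if __name__ == "__main__":
    # Hypothetical usage sketch. The toy demo above is self-contained; the
    # dataset example below assumes a directory of preprocessed .npz files
    # (each containing "c_tensor" and "s_tensor" arrays) whose path is only a
    # placeholder, so it is left commented out.
    _demo_graph_from_tensor()

    # from torch_geometric.loader import DataLoader
    # dataset = PolyphemusDataset("path/to/preprocessed_dir", n_bars=2)
    # loader = DataLoader(dataset, batch_size=8, shuffle=True)
    # for batch in loader:
    #     print(batch)
    #     break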