PicturesOfMIDI / pom /chords.py
drscotthawley's picture
fixing up py files for run
3bf82fd
raw
history blame
No virus
20.7 kB
#! /usr/bin/env python3
import re
import sys
import torch.nn as nn
import torch
from PIL import Image
import numpy as np
from . import rect_to_square, square_to_rect
CHORD_BORDER = 8 # chord border size in pixels (chord colors are read from pixel row CHORD_BORDER//2)
# Chord vocabulary: my distillation of all output from polyffusion's chord finder for the POP909 dataset transposed by +/-12 semitones.
# NOTE(review): 'E' is listed before 'Eb', so a note's index in this list is not its pitch class (Eb is a
# semitone below E). Mod-12 arithmetic on these indices (e.g. in abs_seq_to_rel_seq) inherits that ordering.
# The indices also determine each chord's color via chord_num_to_color, so reordering this list would
# presumably break previously rendered images and trained models -- confirm before changing.
NOTE_NAMES = ['C','C#','D','E','Eb','F','F#','G', 'Ab', 'A', 'Bb', 'B'] # these are from polyffusion's chord finder. yes, mixing # & b is weird
#NOTE_NAMES2 = ['A','Ab','B','Bb','C','C#','D','E','Eb','F','F#','G'] # how they are in all_chords.txt file
# Chord qualities; a chord's type index is its position in this list.
CHORD_TYPES = ['aug', 'dim', 'dim7', 'hdim7',
'maj', 'maj(11)', 'maj13', 'maj/3', 'maj/5', 'maj6', 'maj6(9)', 'maj7', 'maj7/3', 'maj7/5', 'maj7/7', 'maj(9)', 'maj9', 'maj9(11)',
'min', 'min(11)', 'min11', 'min13', 'min/5', 'min6', 'min6(9)', 'min7', 'min7/5', 'min7/b7', 'min(9)', 'min9', 'min/b3', 'minmaj7',
'sus2', 'sus4', 'sus4(b7)', 'sus4(b7,9)', '7', '7/3', '7/5', '7(#9)', '7/b7', '9', '11', '13'] # 44 chord types
# Every (note, chord type) combination in note-major order, so that
# chord index = note_ind * len(CHORD_TYPES) + chord_type_ind (e.g. 0 -> "C:aug").
CHORD_IND_PAIRS = [(note, chord) for note in NOTE_NAMES for chord in CHORD_TYPES]
POSSIBLE_CHORDS = [f"{note}:{chord}" for (note, chord) in CHORD_IND_PAIRS]
#POSSIBLE_CHORDS = [f"{note}:{chord}" for note in NOTE_NAMES for chord in CHORD_TYPES]
POSSIBLE_CHORDS += ['N'] # N for no chord; always the last entry, index 12*44 = 528
assert len(POSSIBLE_CHORDS) == 12*44+1, f"There should be {12*44+1} possible chords, but there are {len(POSSIBLE_CHORDS)}. Check the NOTE_NAMES and CHORD_TYPES lists."
def to_base_9(n):
    """Convert a non-negative integer to its base-9 digits, most significant first.

    The result is left-padded with zeros to at least 3 digits, e.g. 0 -> [0, 0, 0],
    1 -> [0, 0, 1], 81 -> [1, 0, 0]. Values >= 9**3 produce more than 3 digits.

    Raises:
        ValueError: if n is negative (repeated floor division by 9 never reaches 0 for n < 0).
    """
    if n < 0:
        raise ValueError(f"to_base_9 expects a non-negative integer, got {n}")
    digits = []  # least-significant digit first
    while n:
        digits.append(n % 9)
        n //= 9
    digits += [0] * (3 - len(digits))  # pad with leading zeros; no-op once there are >= 3 digits
    return digits[::-1]
def chord_num_to_color(cn, scale=30):
    """Map chord index `cn` to an RGB color "embedding".

    cn + 1 is written as three base-9 digits (R, G, B, most significant first), and each
    digit is multiplied by `scale`. Chord 0 therefore becomes (0, 0, 30), and each channel
    moves in steps of `scale` up to 8 * scale (240 with the default). Because of the +1
    offset, (0, 0, 0) is never produced for cn >= 0.
    """
    base9_digits = to_base_9(cn + 1)
    return tuple(map(lambda digit: digit * scale, base9_digits))
def color_to_chord_num(color, scale=30, warnings_on=False):
    """Decode an RGB color made by chord_num_to_color back into a chord index.

    Each channel is quantized to a base-9 digit with `// scale`, with R as the most
    significant digit. The digits are combined, and the +1 offset applied by
    chord_num_to_color is removed.

    Channels are converted to Python ints before any arithmetic. Pixels read from a PIL
    image arrive as np.uint8. Under NumPy's NEP 50 promotion rules (NumPy >= 2),
    `np.uint8 * 81` and the running sum stay uint8 and wrap modulo 256 (e.g. 6 * 81 -> 230).
    That would corrupt every chord whose red digit is >= 4 (index >= 323), and would make
    black decode to 255 instead of -1.

    Args:
        color: sequence of 3 channel values (R, G, B).
        scale: quantization step used when encoding (default 30).
        warnings_on: print a message when the color decodes below 0.

    Returns:
        int chord index. Pure black decodes to -1, which is wrapped around to
        len(POSSIBLE_CHORDS) - 1, the last vocabulary entry ('N').
    """
    out = sum(int(x) // scale * 9**i for i, x in enumerate(color[::-1])) - 1
    if out < 0:
        if warnings_on: print(f"color_to_chord_num: Warning: out should be equal to or greater than 0: color = {color}, out = {out}. Wrapping around to {len(POSSIBLE_CHORDS)+out}")
        out = len(POSSIBLE_CHORDS) + out
    return out
def simplify_chord(chord_name):
    """Reduce a chord name to a simpler form by applying these stripping rules in sequence:

    1. drop everything from the first "(" on, e.g. "A:maj(11)" -> "A:maj"
    2. drop everything from the first "/" on (bass note), e.g. "A:7/3", "A:7/5", "A:7/b7" -> "A:7"
    3. drop everything from "sus" on (suspension markings), e.g. "A:sus4" -> "A:"
       (note that this leaves an empty chord type)

    Not implemented: collapsing high-numbered extensions such as "G:min11" / "G:min13" -> "G:min".
    """
    for rule in (r'\(.*', r'\/.*', r'sus.*'):  # rules 1, 2, 3
        chord_name = re.sub(rule, '', chord_name)
    return chord_name
def get_unique_indices(data):
    """Return the index of the last element of each run of consecutive equal values.

    Args:
        data: a sequence (list, string, 1-D array, ...) of any data type.
            Example: data = [0, 1, 4, 1, 5, 5, 5, 6, 10, 6, 6, 5]

    Returns:
        A list of indices, one per run, ending with len(data) - 1.
            Example: result = [0, 1, 2, 3, 6, 7, 8, 10, 11]
        Empty input returns [].
    """
    if len(data) == 0:  # len() rather than truthiness so numpy arrays work too
        return []
    return [i for i, (val, next_val) in enumerate(zip(data, data[1:])) if val != next_val] + [len(data) - 1]
def get_nonrepeated_values(data, indices=None):
    """Collapse runs of consecutive equal values, returning one value per run.

    Args:
        data: a sequence of any data type.
            Example: data = [0, 1, 4, 1, 5, 5, 5, 6, 10, 6, 6, 5]
        indices: optional precomputed result of get_unique_indices(data); computed if None.

    Returns:
        A list of the values data[i] for each i in indices.
            Example: returns [0, 1, 4, 1, 5, 6, 10, 6, 5]
        Empty data (with indices=None) returns [].
    """
    if indices is None:
        if len(data) == 0:
            return []
        indices = get_unique_indices(data)
    return [data[i] for i in indices]
def most_freq_or_first(arr, debug=False):
    """Return the most frequent value in a 1-D integer array.

    Ties go to the *smallest* of the tied values. np.argmax returns the first maximal bin
    of np.bincount, i.e. the lowest value, not the value that appears first in `arr`.
    NOTE(review): the function name suggests first-occurrence tie-breaking; confirm which
    is intended before relying on it.

    Negative values are handled by shifting the array into the non-negative range that
    np.bincount requires, then shifting the result back.

    Args:
        arr: non-empty 1-D numpy array of integers.
        debug: if True, print the input array.

    Returns:
        The modal value, as a NumPy integer scalar.

    Raises:
        AssertionError: if arr is not 1-D.
        ValueError: if arr is empty.
    """
    assert len(arr.shape) == 1, "arr must be 1D"
    if debug:
        print("most_freq_or_first: arr = ", arr)
    if arr.size == 0:
        raise ValueError("most_freq_or_first: arr must not be empty")
    lowest = arr.min()
    shift = lowest if lowest < 0 else 0  # only shift when bincount would reject negatives
    counts = np.bincount(arr - shift)
    return np.argmax(counts) + shift
def most_freq_or_first_every(arr,
                             every=4,  # pixels per chord label. 4=every quarter note
                             ):
    """Reduce a per-pixel chord-label array to one label per window of `every` pixels.

    Used to grab the most frequent chord labels, assuming we're starting on a beat.
    `arr` holds chord label indices, e.g. in 0..528. Each consecutive window of `every`
    values is reduced with most_freq_or_first.

    If len(arr) is not a multiple of `every`, the final partial window is padded up to
    `every` with arr[-remainder]. That is the first value of the partial window, not arr[-1].

    Returns:
        1-D numpy array with ceil(len(arr) / every) labels. Empty input gives an empty array.
    """
    assert len(arr.shape) == 1, "arr must be 1D"
    if arr.size == 0:
        return np.zeros(0, dtype=np.intp)
    remainder = len(arr) % every
    if remainder != 0:
        arr = np.pad(arr, (0, every - remainder), mode='constant', constant_values=(0, arr[- remainder]))
    windows = arr.reshape((-1, every))
    # each window's mode is one of its own values, so the result never exceeds arr.max()
    return np.array([most_freq_or_first(window) for window in windows])
def chord_str_to_pair(chord_str):
    """Split a chord string such as "C#:min7" into (note index, chord-type index).

    "N" (no chord) returns the sentinel pair (-1, -1). Unknown note or chord-type names,
    or a string without ":", raise ValueError.
    """
    if chord_str != 'N':
        root_name, quality_name = chord_str.split(':')
        return (NOTE_NAMES.index(root_name), CHORD_TYPES.index(quality_name))
    return (-1,-1)
def chords_str_to_pairs(chords_str):
    """Lazily yield (note index, chord-type index) pairs for a comma-separated chord string,
    e.g. "C:maj,N" -> (0, 4), (-1, -1)."""
    yield from map(chord_str_to_pair, chords_str.split(','))
def chords_str_to_inds(chords_str):
    """Lazily yield the POSSIBLE_CHORDS index of each chord in a comma-separated string,
    e.g. "C:aug,N" -> 0, 528. Unknown chord names raise ValueError."""
    yield from map(POSSIBLE_CHORDS.index, chords_str.split(','))
def pair_to_chord_index(pair):
    """Convert a (note index, chord-type index) pair to a single index into POSSIBLE_CHORDS.

    index = note_ind * len(CHORD_TYPES) + chord_type_ind, which is the note-major ordering
    used to build POSSIBLE_CHORDS.

    The "no chord" sentinel (-1, -1) returned by chord_str_to_pair('N') maps to the index
    of 'N', i.e. len(POSSIBLE_CHORDS) - 1. The plain formula would give
    -len(CHORD_TYPES) - 1, which as a list index silently selects an unrelated chord.

    Array/tensor inputs are passed element-wise through the formula unchanged.
    """
    note_ind, chord_type_ind = pair
    is_scalar_pair = np.ndim(note_ind) == 0 and np.ndim(chord_type_ind) == 0
    if is_scalar_pair and note_ind == -1 and chord_type_ind == -1:
        return len(POSSIBLE_CHORDS) - 1
    return note_ind*len(CHORD_TYPES) + chord_type_ind
def chord_index_to_pair(ci):
    """Split a POSSIBLE_CHORDS index into (note index, chord-type index).

    For real chords this undoes note_ind * len(CHORD_TYPES) + chord_type_ind. The 'N' index
    (len(POSSIBLE_CHORDS) - 1) decodes to (len(NOTE_NAMES), 0), not to the (-1, -1)
    sentinel used by chord_str_to_pair.
    """
    n_types = len(CHORD_TYPES)
    return (ci // n_types, ci % n_types)
def chord_index_to_str(ci):
    """Return the chord name stored at POSSIBLE_CHORDS index `ci`, e.g. 0 -> "C:aug"; the last index is "N"."""
    chord_name = POSSIBLE_CHORDS[ci]
    return chord_name
class ChordEmbedding(nn.Module):
    """Embed chord indices as a learned combination of root note and chord type.

    Each index ci is split into a root-note index (ci // len(CHORD_TYPES)) and a chord-type
    index (ci % len(CHORD_TYPES)). The two embeddings are concatenated and projected to
    chord_emb_dim with a Linear layer.
    """

    def __init__(self, chord_emb_dim=8, note_emb_dim=8, type_emb_dim=8, debug=False):
        super().__init__()
        # +1 row for "N" (no chord): its index len(POSSIBLE_CHORDS)-1 decodes to note index len(NOTE_NAMES)
        self.emb_note = nn.Embedding(len(NOTE_NAMES)+1, note_emb_dim)
        self.emb_type = nn.Embedding(len(CHORD_TYPES), type_emb_dim)
        self.compactify = nn.Linear(note_emb_dim + type_emb_dim, chord_emb_dim)
        self.chord_emb_dim = chord_emb_dim
        self.debug = debug
        # NOTE(review): not used by forward, and a plain attribute (not a registered buffer), so it
        # does not follow .to(device). Kept so any code reading it continues to work.
        self.zero_vec = torch.zeros((1, self.chord_emb_dim))

    def forward(self, chord_inds: torch.Tensor, debug=False):
        """Embed integer chord indices.

        Args:
            chord_inds: integer tensor of indices into the chord vocabulary. Callers in this
                file pass a 1-D batch, but any shape works. The output has shape
                chord_inds.shape + (chord_emb_dim,).
            debug: print intermediate indices and shapes. Also enabled by the constructor's
                debug flag.

        Returns:
            Float tensor of chord embeddings.

        Raises:
            ValueError: if any index is greater than len(POSSIBLE_CHORDS).

        NOTE(review): the range check accepts index len(POSSIBLE_CHORDS) itself, one past the
        last chord. That index is embedded like any other (note len(NOTE_NAMES), type 1) rather
        than as a zero vector. Confirm whether it was meant as a zero-vector padding index.
        """
        debug = debug or self.debug
        if chord_inds.numel() > 0 and chord_inds.max() > len(POSSIBLE_CHORDS):
            torch.set_printoptions(threshold=10000)
            print(f"\nchord_inds.max() = {chord_inds.max()} but len(POSSIBLE_CHORDS) = {len(POSSIBLE_CHORDS)}. \nchord_inds = {chord_inds}")
            raise ValueError("chord_inds.max() should be less than len(POSSIBLE_CHORDS)")
        n_types = len(CHORD_TYPES)
        # "N" (index len(POSSIBLE_CHORDS)-1) gives note_inds == len(NOTE_NAMES), type_inds == 0,
        # which is why emb_note has an extra row.
        note_inds, type_inds = chord_inds // n_types, chord_inds % n_types
        if debug:
            print("note_inds, type_inds = ", note_inds, type_inds)
            print("note_inds.max(), type_inds.max() = ", note_inds.max(), type_inds.max())
        note_emb = self.emb_note(note_inds)
        type_emb = self.emb_type(type_inds)
        if debug: print("\nnote_emb.shape, type_emb.shape = ", note_emb.shape, type_emb.shape)
        # concatenate along the feature (last) axis so inputs of any rank work
        combined_emb = torch.cat((note_emb, type_emb), dim=-1)
        if debug: print("combined_emb.shape = ", combined_emb.shape)
        x = self.compactify(combined_emb)
        if debug: print("ce: x.shape, self.chord_emb_dim = ", x.shape, self.chord_emb_dim)
        return x
class ChordAE(nn.Module):
    """Autoencoder over single chord indices (exploratory; maybe not needed).

    The full model only needs the encoder, not a decoder. This autoencoder is useful for
    exploring how few embedding dimensions we can get away with.
    """

    def __init__(self, chord_vocab_size=len(POSSIBLE_CHORDS), chord_emb_dim=8):
        super().__init__()
        self.encoder = ChordEmbedding(chord_emb_dim)
        self.decoder = nn.Linear(chord_emb_dim, chord_vocab_size)  # logits over the vocabulary; could do better maybe

    def forward(self, x, debug=False):
        # `debug` is accepted for signature compatibility but not used
        embedded = self.encoder(x)
        return self.decoder(embedded)
def abs_seq_to_rel_seq(seq:torch.Tensor):
    """Re-express a batch of chord-index sequences relative to each sequence's first chord.

    For every position after the first, the root-note index becomes
    (note - first_note) mod len(NOTE_NAMES). Chord types are unchanged, and the first element
    of each sequence is left as-is.

    'N' chords (note index len(NOTE_NAMES)) keep that note index, so they still map to the
    'N' index. If the first chord is 'N', later roots come out unchanged, because
    (note - len(NOTE_NAMES)) mod len(NOTE_NAMES) == note.
    TODO: decide whether 'N' chords should be encoded differently.

    NOTE(review): the mod-12 subtraction works on NOTE_NAMES list positions. That list puts
    'E' before 'Eb', so the result is not strictly transposition-invariant; confirm before
    relying on it.

    Args:
        seq: 2-D integer tensor (batch, sequence) of POSSIBLE_CHORDS indices.

    Returns:
        Tensor of the same shape holding chord *changes* relative to the first chord.
    """
    assert seq.ndim == 2, f"seq should be 2D, but seq.shape = {seq.shape}"
    n_types, n_notes = len(CHORD_TYPES), len(NOTE_NAMES)
    roots, qualities = seq // n_types, seq % n_types
    rel_roots = roots.clone()
    # roots[:, :1] keeps a (B, 1) column so the first root broadcasts across each row
    rel_roots[:, 1:] = (roots[:, 1:] - roots[:, :1]) % n_notes
    rel_roots[roots == n_notes] = n_notes  # 'N' chords stay 'N'
    return rel_roots * n_types + qualities
class ChordSeqEncoder(nn.Module):
    """Encoder for sequences of chords.

    The first chord of each sequence is embedded as-is. Every later chord is re-expressed
    relative to the first chord's root (mod-12 arithmetic on note indices, see
    abs_seq_to_rel_seq) before embedding. The default seq_len of 128 corresponds to
    4 chords per bar x 32 bars.

    The embedded sequence runs through a 2-layer LSTM (could be swapped for a Transformer
    later). A final LSTM hidden state is returned as a seq_emb_dim-dimensional (default 256)
    embedding of the whole sequence.

    Args:
        chord_emb_dim: per-chord embedding size, i.e. the ChordEmbedding output and LSTM input size.
        seq_len: nominal sequence length; stored but not used by forward.
        seq_emb_dim: LSTM hidden size, which is also the output embedding size.
        hidden_dim: currently unused; accepted for interface compatibility.
        dropout: dropout applied between the two LSTM layers.
        output_layer: which LSTM layer's final hidden state to return.
            - The default 0 is the first (bottom) layer and preserves the existing behavior
              (and presumably existing checkpoints). With it, the second layer does not
              affect the output and receives no gradient.
            - Use -1 for the top layer; this requires retraining.
    """

    def __init__(self, chord_emb_dim=8, seq_len=512//4, seq_emb_dim=256, hidden_dim=512, dropout=0.2,
                 output_layer=0):
        super().__init__()
        # Pass chord_emb_dim through so the embedding width always matches the LSTM input size.
        self.chord_encoder = ChordEmbedding(chord_emb_dim=chord_emb_dim)
        self.seq_encoder = nn.LSTM(chord_emb_dim, seq_emb_dim, batch_first=True, num_layers=2, dropout=dropout)
        self.seq_len = seq_len
        self.output_layer = output_layer

    def forward(self, bs):
        """Encode a batch of chord-index sequences.

        Args:
            bs: integer tensor of shape (B, S) of chord indices.

        Returns:
            Tensor of shape (B, seq_emb_dim).
        """
        B, S = bs.shape
        changes_seq = abs_seq_to_rel_seq(bs)  # chord indices relative to each sequence's first chord
        # embed every chord in the batch at once, then restore the (B, S, E) layout
        x = self.chord_encoder(changes_seq.flatten()).view(B, S, -1)
        _, (hidden, _) = self.seq_encoder(x)  # hidden: (num_layers, B, seq_emb_dim)
        # getattr: instances unpickled from older versions of this class lack output_layer
        return hidden[getattr(self, 'output_layer', 0)]
class ChordSeqAE(nn.Module):
    """Chord-sequence autoencoder, used for pretraining a ChordSeqEncoder.

    The encoder maps a (B, S) batch of chord indices to a (B, seq_emb_dim) embedding. An MLP
    decoder (Linear -> ReLU -> Linear) maps that embedding to flat logits, which are reshaped
    per position. In training mode with vae_scale > 0, Gaussian noise scaled by
    vae_scale * (max - min of the batch's embeddings) is added to the embedding.
    """

    def __init__(self, chord_emb_dim=8, seq_len=512//4, seq_emb_dim=256,
                 hidden_dim=512, chord_vocab_size=len(POSSIBLE_CHORDS),
                 vae_scale=0.1):
        super().__init__()
        self.encoder = ChordSeqEncoder(chord_emb_dim=chord_emb_dim, seq_len=seq_len,
                                       seq_emb_dim=seq_emb_dim, hidden_dim=hidden_dim)
        expand = nn.Linear(seq_emb_dim, hidden_dim)
        to_logits = nn.Linear(hidden_dim, seq_len*chord_vocab_size)  # logits for every position, flattened
        self.decoder = nn.Sequential(expand, nn.ReLU(), to_logits)
        self.chord_vocab_size = chord_vocab_size
        self.vae_scale = vae_scale

    def forward(self, bs, debug=False):
        """Map a (B, S) batch of chord indices to logits of shape (B, S, -1).

        The last dimension equals chord_vocab_size when S == seq_len.
        """
        if debug: print("ChordSeqAE: bs.shape = ", bs.shape)
        batch_size, seq_len = bs.shape
        z = self.encoder(bs)
        if debug: print("ChordSeqAE: encoded x.shape = ", z.shape)
        add_noise = self.training and self.vae_scale > 0
        if add_noise:
            spread = z.max() - z.min()
            z = z + self.vae_scale * spread * torch.randn_like(z)
        logits = self.decoder(z).view(batch_size, seq_len, -1)
        if debug: print("ChordSeqAE: decoded x.shape = ", logits.shape)
        return logits
def chord_seq_from_img(img:Image.Image,
                       every=8,  # was imagining every beat (every=4), but looking at the data the smallest chord label seems to be 8 pixels wide
                       debug=False):
    """Extract a sequence of chord indices from a pianoroll image.

    Chord colors are read from pixel row y = CHORD_BORDER // 2 of the rectangular image; a
    square image is first converted with square_to_rect. Each pixel color is decoded with
    color_to_chord_num. The per-pixel labels are then reduced to one label per `every`
    pixels with most_freq_or_first_every. Hopefully the dataloader means we can process
    one image at a time and it'll batch them.

    Returns:
        1-D torch tensor of chord indices, one per `every`-pixel window.

    Raises:
        ValueError: if any pixel decodes to an index >= len(POSSIBLE_CHORDS).
    """
    if debug: print("img.size, img.min, img.max = ",img.size, np.array(img).min(), np.array(img).max())
    if img.size[0] == img.size[1]:  # if image is square, make it rectangular
        img = square_to_rect(img)
    img_arr = np.array(img)
    # Promote from PIL's uint8: with NumPy >= 2 (NEP 50), uint8 channel arithmetic inside
    # color_to_chord_num would otherwise wrap modulo 256 and mis-decode chords.
    top_row = img_arr[CHORD_BORDER//2].astype(np.int64)  # all x's along y=CHORD_BORDER/2
    if debug:
        img.save("chord_seq_from_img.png")
        print("img_arr.shape = ", img_arr.shape)
        print("top_row.shape = ", top_row.shape)
        print("top_row = ", top_row)
    chord_seq = np.array([color_to_chord_num(tuple(c)) for c in top_row])
    if chord_seq.max() >= len(POSSIBLE_CHORDS):
        print(f"chord_seq.max = {chord_seq.max()} should be less than len(POSSIBLE_CHORDS) = {len(POSSIBLE_CHORDS)}\nchord_seq = {chord_seq}")
        indices = np.where(chord_seq >= len(POSSIBLE_CHORDS))[0]
        print("indices, chord_seq[indices], top_row[indices] = ", indices, chord_seq[indices], top_row[indices])
        raise ValueError("chord_seq.max() should be less than len(POSSIBLE_CHORDS)")
    chord_seq_beats = most_freq_or_first_every(chord_seq, every=every)
    assert chord_seq_beats.max() <= chord_seq.max(), f"chord_seq_beats.max() = {chord_seq_beats.max()} should be less than chord_seq.max() = {chord_seq.max()}"
    if debug: print("chord_seq_beats, len(POSSIBLE_CHORDS) = ", chord_seq_beats, len(POSSIBLE_CHORDS))
    assert chord_seq_beats.max() < len(POSSIBLE_CHORDS), f"chord_seq_beats.max() should be less than len(POSSIBLE_CHORDS) = {len(POSSIBLE_CHORDS)}"
    return torch.tensor(chord_seq_beats)
def chord_seq_from_img_tensor_batch(img_tensor_batch:torch.Tensor, every=8, debug=False):
    """Extract chord-index sequences from a batch of pianoroll image tensors.

    Args:
        img_tensor_batch: float tensor of shape (B, C, H, W) with values nominally in
            [-1, 1]. The test helper in this file produces (B, 3, 256, 256).
        every: pixels per chord label; forwarded to chord_seq_from_img.
        debug: forwarded to chord_seq_from_img.

    Returns:
        Tensor of shape (B, L) of chord indices, on the same device as img_tensor_batch.
    """
    # Rescale -1..1 -> 0..255 once for the whole batch.
    # - detach() lets tensors that require grad (e.g. model outputs) be converted to numpy.
    # - clip before the uint8 cast so out-of-range values saturate instead of wrapping (256 -> 0).
    pixels = ((img_tensor_batch.detach() + 1.0) * 127.5).cpu().numpy()
    pixels = np.clip(np.round(pixels), 0, 255).astype(np.uint8)
    chord_seqs = []
    for chw in pixels:  # TODO: may be a faster way to do this with tensor ops
        # converting to PIL images and back is slow
        img = Image.fromarray(np.ascontiguousarray(chw.transpose(1, 2, 0)))  # CHW -> HWC
        img = square_to_rect(img)
        chord_seqs.append(chord_seq_from_img(img, every=every, debug=debug))
    return torch.stack(chord_seqs).to(img_tensor_batch.device)
def img_batch_to_seq_emb(img_tensor_batch:torch.Tensor, chord_seq_encoder:nn.Module, every=8, debug=False):
    """Encode a batch of pianoroll image tensors into chord-sequence embeddings.

    Chord indices are first extracted with chord_seq_from_img_tensor_batch, then passed to
    `chord_seq_encoder`, whose output is returned unchanged.
    """
    return chord_seq_encoder(
        chord_seq_from_img_tensor_batch(img_tensor_batch, every=every, debug=debug)
    )
# TODO: test it!
if __name__ == '__main__':
    # FOR TESTING/DEV ONLY.
    # Because of the package-relative import at the top of this file, this presumably has to be
    # run as a module (e.g. `python -m pom.chords <some_arg>`) rather than as a plain script.

    def make_image_tensor_batch(batch_size=2):
        """FOR TESTING/DEV ONLY: make a batch of chord-endowed pianoroll (square) images from POP909.

        Lets me iterate on other parts of this faster without spinning up the training code
        every time. shape = (B, 3, 256, 256), normalization = -1.0 to 1.0
        """
        img_batch = torch.zeros((batch_size, 3, 256, 256))
        for i in range(batch_size):
            n = i + 1  # np.random.randint(0, 909)
            img_filename = f"/data/POP909-Dataset/images_128_rg_chords_TOTAL/{n:03}_TOTAL.png"  # place to grab images from
            with Image.open(img_filename) as src:  # close the file once the pixels are loaded
                img = src.convert('RGB')
            img = img.crop((0, 0, 512, 128))  # crop to 512 pixels wide
            img = rect_to_square(img)
            # same -1..1 normalization the dataloader applies
            img_batch[i] = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 127.5 - 1.0
        return img_batch

    # quick check that chord_num_to_color / color_to_chord_num round-trip over the whole vocabulary
    for cn in range(len(POSSIBLE_CHORDS)):
        color = chord_num_to_color(cn)
        print("cn, color = ", cn, color)
        cn2 = color_to_chord_num(color)
        assert cn2 == cn, f"cn2={cn2} should be cn={cn}, color={color}"

    if len(sys.argv) <= 1:
        print("Testing suite, Usage: python chords.py <some_arg>")
        sys.exit(1)

    batch_size = 2
    img_tensor_batch = make_image_tensor_batch(batch_size=batch_size)
    print("img_tensor_batch.shape = ", img_tensor_batch.shape)
    print("img_tensor_batch.min(), img_tensor_batch.max() = ", img_tensor_batch.min(), img_tensor_batch.max())
    chord_seq_batch = chord_seq_from_img_tensor_batch(img_tensor_batch, every=8, debug=False)
    print("chord_seq_batch.shape = ", chord_seq_batch.shape)
    print(f"chord_seq_batch = \n{chord_seq_batch}")
    cse = ChordSeqEncoder()
    cs_emb = cse(chord_seq_batch)
    print("cs_emb.shape = ", cs_emb.shape)
    sys.exit(0)