#! /usr/bin/env python3
import re
import sys

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from . import rect_to_square, square_to_rect

CHORD_BORDER = 8   # chord border size in pixels

# my distillation of all output from polyffusion's chord finder, run on the POP909 dataset transposed +/- 12 semitones.
NOTE_NAMES = ['C','C#','D','E','Eb','F','F#','G', 'Ab', 'A', 'Bb', 'B'] # these are from polyffusion's chord finder. yes, mixing # & b is weird
#NOTE_NAMES2 = ['A','Ab','B','Bb','C','C#','D','E','Eb','F','F#','G'] # how they are in all_chords.txt file

CHORD_TYPES = ['aug', 'dim', 'dim7', 'hdim7', 
               'maj', 'maj(11)', 'maj13', 'maj/3', 'maj/5', 'maj6', 'maj6(9)', 'maj7', 'maj7/3', 'maj7/5', 'maj7/7', 'maj(9)', 'maj9', 'maj9(11)', 
               'min', 'min(11)', 'min11', 'min13', 'min/5', 'min6', 'min6(9)', 'min7', 'min7/5', 'min7/b7', 'min(9)', 'min9', 'min/b3', 'minmaj7', 
               'sus2', 'sus4', 'sus4(b7)', 'sus4(b7,9)', '7', '7/3', '7/5', '7(#9)', '7/b7', '9', '11', '13']  # 44 chord types

CHORD_IND_PAIRS = [(note, chord) for note in NOTE_NAMES for chord in CHORD_TYPES]
POSSIBLE_CHORDS = [f"{note}:{chord}" for (note, chord) in CHORD_IND_PAIRS]
#POSSIBLE_CHORDS = [f"{note}:{chord}" for note in NOTE_NAMES for chord in CHORD_TYPES]
POSSIBLE_CHORDS += ['N'] # N for no chord
assert len(POSSIBLE_CHORDS) == 12*44+1, f"There should be {12*44+1} possible chords, but there are {len(POSSIBLE_CHORDS)}. Check the NOTE_NAMES and CHORD_TYPES lists."
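
# for illustration: the vocabulary is ordered note-major, e.g. POSSIBLE_CHORDS[0] == 'C:aug',
# POSSIBLE_CHORDS[44] == 'C#:aug', and POSSIBLE_CHORDS[-1] == 'N'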


def to_base_9(n):
    # converts a non-negative integer to a list of base-9 digits, most significant first, zero-padded to 3 digits
    if n == 0: return [0, 0, 0]
    digits = []
    while n:
        digits.append(n % 9)
        n //= 9
    while len(digits) < 3: # add leading zeros
        digits.append(0)
    return digits[::-1]
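
# a couple of spot checks, for illustration (these follow directly from the definition):
#   to_base_9(10)  -> [0, 1, 1]   since 1*9 + 1 == 10
#   to_base_9(529) -> [6, 4, 7]   since 6*81 + 4*9 + 7 == 529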


def chord_num_to_color(cn, scale=30):
    # color "embeddings" for chords: write cn+1 in base 9 and scale each digit, giving RGB values from (0,0,30) up to (240,240,240) in steps of 30
    color = to_base_9(cn+1)
    return tuple(x*scale for x in color)

def color_to_chord_num(color, scale=30, warnings_on=False):
    # inverse of chord_num_to_color; digits are read back least-significant first, hence the reversed color tuple
    out  = sum([x//scale * 9**i for i, x in enumerate(color[::-1])])-1
    if out < 0: # only happens for a (near-)black pixel (all channels < scale), which wraps around to the last index, i.e. 'N'
        if warnings_on: print(f"color_to_chord_num: Warning: out should be equal to or greater than 0: color = {color}, out = {out}. Wrapping around to {len(POSSIBLE_CHORDS)+out}")
        out = len(POSSIBLE_CHORDS) + out
    return out
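
# round-trip sketch, for illustration:
#   chord_num_to_color(0)          -> (0, 0, 30)     # to_base_9(1) == [0, 0, 1], scaled by 30
#   color_to_chord_num((0, 0, 30)) -> 0
#   color_to_chord_num((0, 0, 0))  -> 528            # all-black wraps around to the 'N' index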


def simplify_chord(chord_name):
    """Simplifies chord names by applying a few rules:
    1. drop anything in parentheses, e.g. "A:maj(11)" -> "A:maj"
    2. drop bass notes, e.g. "A:7/3", "A:7/5" and "A:7/b7" all -> "A:7"
    3. drop suspension markings, e.g. sus2, sus4
    4. maybe? also map high-numbered added notes like "G:min11" & "G:min13" -> "G:min"
    """
    chord_name = re.sub(r'\(.*','',chord_name) # 1
    chord_name = re.sub(r'\/.*','',chord_name) # 2
    chord_name = re.sub(r'sus.*','',chord_name) # 3. note this leaves a trailing colon, e.g. "C:sus2" -> "C:"
    return chord_name
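
# for illustration:
#   simplify_chord("A:maj(11)") -> "A:maj"
#   simplify_chord("A:7/b7")    -> "A:7"
#   simplify_chord("C:sus2")    -> "C:"    (trailing colon, per the note above)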




def get_unique_indices(data):
    """Returns the index of the last element of each run of repeated values in a list
    Args:
        data: A list of any data type.
        Example: data = [0, 1, 4, 1, 5, 5, 5, 6, 10, 6, 6, 5]

    Returns:
        A list of indices, one per run of repeated values.
        Example: result = [0, 1, 2, 3, 6, 7, 8, 10, 11]
    """
    return [i for i, (val, next_val) in enumerate(zip(data, data[1:])) if val != next_val] + [len(data) - 1]

def get_nonrepeated_values(data, indices=None):
    """Returns a list with consecutive repeats collapsed to a single value
    Args:
        data: A list of any data type.
        Example: data = [0, 1, 4, 1, 5, 5, 5, 6, 10, 6, 6, 5]

    Returns:
        A list of values with consecutive repeats removed.
        Example: returns [0, 1, 4, 1, 5, 6, 10, 6, 5]
    """
    if indices is None:
        indices = get_unique_indices(data)  
    return [data[i] for i in indices]
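
# for illustration, get_nonrepeated_values(data) is equivalent to collapsing runs with itertools.groupby:
#   import itertools
#   [k for k, _ in itertools.groupby([0, 1, 4, 1, 5, 5, 5, 6, 10, 6, 6, 5])]  # -> [0, 1, 4, 1, 5, 6, 10, 6, 5]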



def most_freq_or_first(arr, debug=False):
    "returns the most frequent value in arr; if multiple values tie, returns the smallest such value (despite the 'first' in the name)"
    assert len(arr.shape) == 1, "arr must be 1D"
    savearr = arr.copy()
    if debug:
        print("most_freq_or_first: arr = ", arr)
    if savearr.min() < 0: # if there are negative values, we need to shift them up to 0 for bincount
        arr = arr - savearr.min()
    bc = np.bincount(arr)
    try:
        bc[bc != bc.max()] = 0  # only interested in most frequent values
    except Exception as e:
        print("Exception ",e)
        print("most_freq_or_first: arr.shape = ", arr.shape)
        print("most_freq_or_first: arr = ", arr )
        print("most_freq_or_first: bc.shape = ", bc.shape)
        raise e
    out = np.argmax(bc)
    # shift numbers back down
    if savearr.min() < 0:
        out = out + savearr.min()
    assert out.max() <= arr.max(), f"out.max() = {out.max()} should be at most arr.max() = {arr.max()}"
    return out
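
# tie-breaking sketch, for illustration:
#   most_freq_or_first(np.array([5, 3, 5, 3]))  -> 3   # tie between 3 and 5; argmax picks the smaller value
#   most_freq_or_first(np.array([-1, -1, 2]))   -> -1  # negatives are shifted up for bincount, then shifted back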


def most_freq_or_first_every(arr,
                             every=4, # pixels per chord label. 4=every quarter note
                             ):
    "used to grab most frequent chord labels, assuming we're starting on a beat. arr=chord label indices, e.g. in 0..528"
    assert len(arr.shape) == 1, "arr must be 1D"
    remainder = len(arr) % every
    if remainder != 0:
        arr = np.pad(arr, (0, every - remainder), mode='constant', constant_values=(0, arr[- remainder]))
        #print("most_freq_or_first_every: Warning: Padding arr with last beat value on end. new arr =",arr)
    check = arr.reshape((-1,every))
    out = np.array( [most_freq_or_first(a) for a in check] )
    if out.max() > arr.max():
        for i, c in enumerate(check):
            mfof = most_freq_or_first(c)
            if mfof > c.max():
                print(f"i={i}, c={c}, most_freq_or_first(c)={mfof}")
        raise ValueError(f"out.max() = {out.max()} should be at most arr.max() = {arr.max()}")

    return out
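
# for illustration, with every=4 (one chord label per quarter note):
#   most_freq_or_first_every(np.array([3, 3, 5, 3,  7, 7, 7, 2]), every=4)  -> array([3, 7])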


def chord_str_to_pair(chord_str):
    "converts a chord string to a pair of (note, chord) indices"
    if chord_str == 'N':
        return (-1,-1)
    note, chord_type = chord_str.split(':')
    note_ind = NOTE_NAMES.index(note)
    chord_type_ind = CHORD_TYPES.index(chord_type)
    return (note_ind, chord_type_ind)

def chords_str_to_pairs(chords_str):
    for chord_str in chords_str.split(','):
        yield chord_str_to_pair(chord_str)

def chords_str_to_inds(chords_str):
    for chord_str in chords_str.split(','):
        yield POSSIBLE_CHORDS.index(chord_str)

def pair_to_chord_index(pair):
    "converts a pair of (note, chord_type) indices to a single chord index"
    note_ind, chord_type_ind = pair
    return note_ind*len(CHORD_TYPES) + chord_type_ind

def chord_index_to_pair(ci):
    "converts a single chord index to a pair of (note, chord) indices"
    note_ind = ci // len(CHORD_TYPES)
    chord_type_ind = ci % len(CHORD_TYPES)
    return (note_ind, chord_type_ind)

def chord_index_to_str(ci):
    "converts a single chord index to a chord string"
    return POSSIBLE_CHORDS[ci]
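
# round-trip sketch, for illustration:
#   pair = chord_str_to_pair('G:min7')   # -> (7, 25)
#   ci = pair_to_chord_index(pair)       # -> 7*44 + 25 == 333
#   chord_index_to_str(ci)               # -> 'G:min7'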


class ChordEmbedding(nn.Module):
    def __init__(self, chord_emb_dim=8, note_emb_dim=8, type_emb_dim=8,  debug=False):
        super(ChordEmbedding, self).__init__()
        self.emb_note = nn.Embedding(len(NOTE_NAMES)+1, note_emb_dim)  # +1 for "N", i.e. no chord
        self.emb_type = nn.Embedding(len(CHORD_TYPES), type_emb_dim)
        self.compactify = nn.Linear(note_emb_dim + type_emb_dim, chord_emb_dim)
        self.chord_emb_dim = chord_emb_dim
        self.debug = debug
        self.zero_vec = torch.zeros((1, self.chord_emb_dim))  # currently unused; see forward() docstring

    def forward(self, chord_inds:torch.Tensor, debug=False):
        """chord_inds should have dimensions (B) where B is the batch size; each value is the index of a chord in the vocabulary.
        TODO: wherever chord_inds equals len(POSSIBLE_CHORDS), we'd like to return a zero vector (self.zero_vec); otherwise return the embedding"""
        if chord_inds.max() > len(POSSIBLE_CHORDS):
            torch.set_printoptions(threshold=10000)
            print(f"\nchord_inds.max() = {chord_inds.max()} but len(POSSIBLE_CHORDS) = {len(POSSIBLE_CHORDS)}. \nchord_inds = {chord_inds}")
            raise ValueError("chord_inds.max() should be less than len(POSSIBLE_CHORDS)")
        note_inds, type_inds = chord_inds // len(CHORD_TYPES), chord_inds % len(CHORD_TYPES)
        # note that for the 'N' chord, where chord_ind==len(POSSIBLE_CHORDS)-1, we get note_inds==len(NOTE_NAMES) and type_inds==0; that's why self.emb_note has len(NOTE_NAMES)+1 entries
        if debug:
            print("note_inds, type_inds = ", note_inds, type_inds)
            print("note_inds.max(), type_inds.max() = ", note_inds.max(), type_inds.max())
        note_emb = self.emb_note(note_inds)
        type_emb = self.emb_type(type_inds)
        if debug: print("\nnote_emb.shape, type_emb.shape = ", note_emb.shape, type_emb.shape)
        combined_emb = torch.cat((note_emb, type_emb), dim=1)
        if debug: print("combined_emb.shape = ", combined_emb.shape)
        x = self.compactify(combined_emb)
        if debug: print("ce: x.shape, self.chord_emb_dim = ", x.shape, self.chord_emb_dim)
        return x
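
# usage sketch (illustrative, not from a test suite):
#   emb = ChordEmbedding(chord_emb_dim=8)
#   emb(torch.tensor([0, 528])).shape  # -> torch.Size([2, 8]); 528 is the 'N' chord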
    

class ChordAE(nn.Module):
    """Maybe not needed: Autoencoder for training chord embeddings? 
    Note: we don't really need an AE for the full model, we can get by with just the encoder (and no decoder)
    but the AE is useful for exploring how few dimensions we can get away with"""
    def __init__(self, chord_vocab_size=len(POSSIBLE_CHORDS), chord_emb_dim=8):
        super(ChordAE, self).__init__()
        self.encoder = ChordEmbedding(chord_emb_dim)
        self.decoder = nn.Linear(chord_emb_dim, chord_vocab_size) # could do better maybe
    def forward(self, x, debug=False):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

def abs_seq_to_rel_seq(seq:torch.Tensor):
    """converts a batch of absolute sequences of chord indices to a batch of relative sequences of chord indices:
       subtract the note of the first element in each sequence from all the other note indices, modulo len(NOTE_NAMES),
       leave the first element unchanged, and overwrite any 'N' chords with... something else? TODO
    """
    assert len(seq.shape)==2, f"seq should be 2D, but seq.shape = {seq.shape}"
    # decompose seq into two tensors, one of notes and one of chord types
    note_inds, type_inds = seq // len(CHORD_TYPES), seq % len(CHORD_TYPES)
    # for note_inds<12, subtract these from the first element in the sequence, modulo len(NOTE_NAMES) i.e. 12
    note_inds2 = note_inds.clone()
    note_inds2[:,1:] = (note_inds2[:,1:] - note_inds2[:,0].unsqueeze(1)) % len(NOTE_NAMES)
    # 'N' chords: wherever note_inds == 12, overwrite note_inds2 with 12
    note_inds2[note_inds == len(NOTE_NAMES)] = len(NOTE_NAMES)
    # recompose seq
    changes_seq = note_inds2 * len(CHORD_TYPES) + type_inds  # now these are no longer chords, they are chord *changes* rel to first chord
    return changes_seq
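
# worked example, for illustration (len(CHORD_TYPES)==44, len(NOTE_NAMES)==12):
#   seq = torch.tensor([[44, 88, 44]])   # notes [1, 2, 1], all type 0
#   abs_seq_to_rel_seq(seq)              # -> tensor([[44, 44, 0]]): first element kept, then note changes +1 and +0 relative to note 1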


    

class ChordSeqEncoder(nn.Module):
    """Encoder for sequences of chords:
    We embed the first chord, then we embed the CHANGES in chords thereafter (using modulo-12 arithmetic on the bass note)
    (4 chords per bar x 32 bars = 128 chords), 
    and then pass the sequence of the chords through some sequence model 
         (LSTM for now, could use a Transformer or something else later)
    to generate a [256]-dimensional embedding of the sequence of chord embeddings
    """
    def __init__(self, chord_emb_dim=8, seq_len=512//4, seq_emb_dim=256, hidden_dim=512, dropout=0.2):
        super(ChordSeqEncoder, self).__init__()
        # note: hidden_dim is currently unused; the LSTM's hidden size is seq_emb_dim
        self.chord_encoder = ChordEmbedding(chord_emb_dim=chord_emb_dim)
        self.seq_encoder = nn.LSTM(chord_emb_dim, seq_emb_dim, batch_first=True, num_layers=2, dropout=dropout)
        self.seq_len = seq_len
    def forward(self, bs):
        "bs should have dimensions (B, S) where B is the batch size and S is the length of the sequence of chord indices"
        B,S = bs.shape
        changes_seq = abs_seq_to_rel_seq(bs)  # convert to relative sequence of chord indices
        # get chord embeddings for every chord in the batch in the sequence
        x = self.chord_encoder(changes_seq.flatten())
        # reshape x into (B, S, E) where B is the batch size, S is the sequence length, and E is the chord embedding dimension
        x = x.view(B, S, -1)
        output, (hidden, cell) = self.seq_encoder(x)
        # output of forward should be a 2-D tensor of shape (B, SE) where SE = seq_emb_dim
        x = hidden[-1, :, :]  # final hidden state of the last LSTM layer, i.e. the last state of the sequence
        return x
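
# shape-check sketch (illustrative): 128 chord labels per image (4 per bar x 32 bars)
#   cse = ChordSeqEncoder()
#   cse(torch.randint(0, len(POSSIBLE_CHORDS), (2, 128))).shape  # -> torch.Size([2, 256])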


class ChordSeqAE(nn.Module):
    """
    Chord Sequence Autoencoder. For pretraining a ChordSeqEncoder
    """
    def __init__(self, chord_emb_dim=8, seq_len=512//4, seq_emb_dim=256, 
                 hidden_dim=512,  chord_vocab_size=len(POSSIBLE_CHORDS),
                 vae_scale=0.1):
        super(ChordSeqAE, self).__init__()
        self.encoder = ChordSeqEncoder(chord_emb_dim=chord_emb_dim, seq_len=seq_len, seq_emb_dim=seq_emb_dim, hidden_dim=hidden_dim)
        # made decoder a sequence of linear layers with a ReLU in between
        self.decoder = nn.Sequential(
            nn.Linear(seq_emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, seq_len*chord_vocab_size)
        )
        self.chord_vocab_size = chord_vocab_size
        self.vae_scale = vae_scale

    def forward(self, bs, debug=False):
        "bs should have dimensions (B, S) where B is the batch size and S is the length of the sequence of chord indices"
        if debug: print("ChordSeqAE: bs.shape = ", bs.shape)
        B,S = bs.shape
        x = self.encoder(bs)
        if debug: print("ChordSeqAE: encoded x.shape = ", x.shape)
        if self.vae_scale > 0 and self.training:
            x = x + self.vae_scale*((x.max()-x.min())) * torch.randn_like(x)
        x = self.decoder(x) 
        x = x.view(B, S, -1)
        if debug: print("ChordSeqAE: decoded x.shape = ", x.shape)
        return x
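
# shape-check sketch (illustrative):
#   csae = ChordSeqAE()
#   csae(torch.randint(0, len(POSSIBLE_CHORDS), (2, 128))).shape  # -> torch.Size([2, 128, 529]) of per-chord logits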

def chord_seq_from_img(img:Image.Image,
                       every=8,  # was imagining every beat (every=4), but looking at the data, the smallest chord label seems to be 8 pixels wide
                       debug=False):
    """extracts a sequence of chord indices from a pianoroll image
       hopefully the dataloader will mean we can just do one image at a time and it'll batch them
    """
    if debug: print("img.size, img.min, img.max = ",img.size, np.array(img).min(), np.array(img).max())
    if img.size[0] == img.size[1]: # if image is square, make it rectangular
        img = square_to_rect(img)
    img_arr = np.array(img)
    top_row = img_arr[CHORD_BORDER//2] # all x's along y=CHORD_BORDER/2
    if debug: 
        img.save("chord_seq_from_img.png")
        print("img_arr.shape = ", img_arr.shape)
        print("top_row.shape = ", top_row.shape)
        print("top_row = ", top_row)
    chord_seq = np.array([color_to_chord_num(tuple(c)) for c in top_row])
    if chord_seq.max() >= len(POSSIBLE_CHORDS):
        print(f"chord_seq.max = {chord_seq.max()} should be less than len(POSSIBLE_CHORDS) = {len(POSSIBLE_CHORDS)}\nchord_seq = {chord_seq}")
        indices = np.where(chord_seq >= len(POSSIBLE_CHORDS))[0]
        print("indices, chord_seq[indices], top_row[indices] = ", indices, chord_seq[indices], top_row[indices])
        raise ValueError("chord_seq.max() should be less than len(POSSIBLE_CHORDS)")
    chord_seq_beats = most_freq_or_first_every(chord_seq, every=every)
    assert chord_seq_beats.max() <= chord_seq.max(), f"chord_seq_beats.max() = {chord_seq_beats.max()} should be at most chord_seq.max() = {chord_seq.max()}"
    if debug: print("chord_seq_beats, len(POSSIBLE_CHORDS) = ", chord_seq_beats, len(POSSIBLE_CHORDS))
    assert chord_seq_beats.max() < len(POSSIBLE_CHORDS), f"chord_seq_beats.max() should be less than len(POSSIBLE_CHORDS) = {len(POSSIBLE_CHORDS)}"
    return torch.tensor(chord_seq_beats)
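
# self-check sketch (illustrative, assumes a 512x128 rectangular pianoroll): paint the top chord
# border with one chord's color and verify the extracted sequence is constant:
#   ci = POSSIBLE_CHORDS.index('C:maj')
#   arr = np.zeros((128, 512, 3), dtype=np.uint8)
#   arr[:CHORD_BORDER] = chord_num_to_color(ci)   # paint the top chord border
#   seq = chord_seq_from_img(Image.fromarray(arr))
#   assert bool((seq == ci).all())                # 512/8 == 64 labels, all 'C:maj'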


def chord_seq_from_img_tensor_batch(img_tensor_batch:torch.Tensor, every=8, debug=False):
    """extracts a sequence of chord indices from a batch of pianoroll images"""
    batch_size = img_tensor_batch.shape[0]
    itb = (img_tensor_batch + 1.0) * 127.5 #rescale from -1..1 to 0..255
    chord_seqs = []
    for i in range(batch_size): # TODO: there may be a faster way to do this with tensor ops
        # converting to images and back is slow
        img = Image.fromarray(np.round( itb[i].cpu().permute(1,2,0).numpy()).astype(np.uint8))
        img = square_to_rect(img)
        chord_seq = chord_seq_from_img(img, every=every, )
        chord_seqs.append(chord_seq)
    return torch.stack(chord_seqs).to(img_tensor_batch.device)
                              
def img_batch_to_seq_emb(img_tensor_batch:torch.Tensor, chord_seq_encoder:nn.Module, every=8, debug=False):
    """converts a batch of pianoroll images to a batch of chord sequence embeddings"""
    chord_seq_batch = chord_seq_from_img_tensor_batch(img_tensor_batch, every=every, debug=debug)
    cs_emb = chord_seq_encoder(chord_seq_batch)
    return cs_emb

# TODO: test it!

if __name__ == '__main__':
    # FOR TESTING/DEV ONLY

    def make_image_tensor_batch(batch_size=2):
        """FOR TESTING/DEV ONLY: makes a batch of random chord-endowed pianoroll (square) images
        So I can iterate on other parts of this faster w/o having to spin up crowson's training code every time while I write code here
        shape = (B, 3, 256, 256), normalization = -1.0 to 1.0
        """
        img_batch = torch.zeros((batch_size, 3, 256, 256))
        for i in range(batch_size):
            n = i+1  # np.random.randint(0, 909)
            img_filename = f"/data/POP909-Dataset/images_128_rg_chords_TOTAL/{n:03}_TOTAL.png" # place to grab images from
            img = Image.open(img_filename).convert('RGB')
            # crop to 512 pixels wide
            img = img.crop((0,0,512,128))
            img = rect_to_square(img)
            img_batch[i] = torch.tensor(np.array(img)).permute(2,0,1).float() / 127.5 - 1.0  # normalization done by dataloader makes images -1 to 1
        return img_batch

    # quick check of the mapping
    for cn in range(len(POSSIBLE_CHORDS)):
        color = chord_num_to_color(cn)
        print("cn, color = ", cn, color)
        cn2 = color_to_chord_num(color)
        assert cn2 == cn, f"cn2={cn2} should be cn={cn}, color={color}"


    if len(sys.argv) <= 1:
        print("Testing suite, Usage: python chords.py <some_arg>")
        sys.exit(1)
    some_arg = sys.argv[1]

    batch_size=2
    img_tensor_batch = make_image_tensor_batch(batch_size=batch_size)
    print("img_tensor_batch.shape = ", img_tensor_batch.shape)
    print("img_tensor_batch.min(), img_tensor_batch.max() = ", img_tensor_batch.min(), img_tensor_batch.max())

    chord_seq_batch = chord_seq_from_img_tensor_batch(img_tensor_batch, every=8, debug=False)

    print("chord_seq_batch.shape = ", chord_seq_batch.shape)
    print(f"chord_seq_batch = \n{chord_seq_batch}")


    cse = ChordSeqEncoder()
    cs_emb = cse(chord_seq_batch)

    print("cs_emb.shape = ", cs_emb.shape)
    #print(f"cs_emb = \n{cs_emb}")
    sys.exit(0)




    # unreachable scratch code below, kept for reference
    img_filename = some_arg
    img = Image.open(img_filename).convert('RGB')
    chord_ind_seq = chord_seq_from_img(img, debug=False)
    print("chord_ind_seq = ", chord_ind_seq)
    print("len(chord_ind_seq) = ", len(chord_ind_seq))
    chord_embedder = ChordEmbedding()  # nb: ChordEmbedding's first positional arg is chord_emb_dim, not the vocab size
    #print("chord_embeddings = ", chord_embedder(chord_ind_seq))
    sys.exit(0)
    #chords_str = some_arg
    #cis = chords_str_to_inds(chords_str)
    cis = chord_ind_seq
    for ci in cis:
        print("\n-------")
        #ci = pair_to_chord_index(pair)
        pair = chord_index_to_pair(ci)
        print(f"Input: chord_str = {chords_str}, pair = {pair}, ci = {ci}")
        color = chord_num_to_color(ci)
        print(color)
        cn2 = color_to_chord_num(color)
        out_str = chord_index_to_str(cn2)
        print(f"Output: cn2  = {cn2}, out_str = {out_str}")

        print("Embedding: ")
        with torch.no_grad():
            x = torch.tensor([ci])
            print(chord_embedder(x))