jerald committed on
Commit 4819bc9 · 1 Parent(s): fcd062e

source dump

Files changed (47)
  1. app.py +62 -0
  2. music_transformer.pth +3 -0
  3. requirements.txt +4 -0
  4. utils/.DS_Store +0 -0
  5. utils/musicautobot/.DS_Store +0 -0
  6. utils/musicautobot/__init__.py +3 -0
  7. utils/musicautobot/__pycache__/__init__.cpython-310.pyc +0 -0
  8. utils/musicautobot/__pycache__/config.cpython-310.pyc +0 -0
  9. utils/musicautobot/__pycache__/numpy_encode.cpython-310.pyc +0 -0
  10. utils/musicautobot/__pycache__/vocab.cpython-310.pyc +0 -0
  11. utils/musicautobot/config.py +47 -0
  12. utils/musicautobot/multitask_transformer/__init__.py +3 -0
  13. utils/musicautobot/multitask_transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  14. utils/musicautobot/multitask_transformer/__pycache__/dataloader.cpython-310.pyc +0 -0
  15. utils/musicautobot/multitask_transformer/__pycache__/learner.cpython-310.pyc +0 -0
  16. utils/musicautobot/multitask_transformer/__pycache__/model.cpython-310.pyc +0 -0
  17. utils/musicautobot/multitask_transformer/__pycache__/transform.cpython-310.pyc +0 -0
  18. utils/musicautobot/multitask_transformer/dataloader.py +146 -0
  19. utils/musicautobot/multitask_transformer/learner.py +340 -0
  20. utils/musicautobot/multitask_transformer/model.py +258 -0
  21. utils/musicautobot/multitask_transformer/transform.py +68 -0
  22. utils/musicautobot/music_transformer/__init__.py +3 -0
  23. utils/musicautobot/music_transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  24. utils/musicautobot/music_transformer/__pycache__/dataloader.cpython-310.pyc +0 -0
  25. utils/musicautobot/music_transformer/__pycache__/learner.cpython-310.pyc +0 -0
  26. utils/musicautobot/music_transformer/__pycache__/model.cpython-310.pyc +0 -0
  27. utils/musicautobot/music_transformer/__pycache__/transform.cpython-310.pyc +0 -0
  28. utils/musicautobot/music_transformer/dataloader.py +229 -0
  29. utils/musicautobot/music_transformer/learner.py +171 -0
  30. utils/musicautobot/music_transformer/model.py +66 -0
  31. utils/musicautobot/music_transformer/transform.py +235 -0
  32. utils/musicautobot/numpy_encode.py +302 -0
  33. utils/musicautobot/utils/__init__.py +0 -0
  34. utils/musicautobot/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  35. utils/musicautobot/utils/__pycache__/attention_mask.cpython-310.pyc +0 -0
  36. utils/musicautobot/utils/__pycache__/file_processing.cpython-310.pyc +0 -0
  37. utils/musicautobot/utils/__pycache__/midifile.cpython-310.pyc +0 -0
  38. utils/musicautobot/utils/__pycache__/setup_musescore.cpython-310.pyc +0 -0
  39. utils/musicautobot/utils/__pycache__/top_k_top_p.cpython-310.pyc +0 -0
  40. utils/musicautobot/utils/attention_mask.py +21 -0
  41. utils/musicautobot/utils/file_processing.py +52 -0
  42. utils/musicautobot/utils/lamb.py +106 -0
  43. utils/musicautobot/utils/midifile.py +107 -0
  44. utils/musicautobot/utils/setup_musescore.py +46 -0
  45. utils/musicautobot/utils/stacked_dataloader.py +70 -0
  46. utils/musicautobot/utils/top_k_top_p.py +35 -0
  47. utils/musicautobot/vocab.py +93 -0
app.py ADDED
@@ -0,0 +1,62 @@
+ from utils.musicautobot.numpy_encode import *
+ from utils.musicautobot.utils.file_processing import process_all, process_file
+ from utils.musicautobot.config import *
+ from utils.musicautobot.music_transformer import *
+
+ import gradio as gr
+ from midi2audio import FluidSynth
+ import tempfile
+ import os
+
+ # Load the pretrained model
+ data_path = Path('./')
+ data = MusicDataBunch.empty(data_path)
+ vocab = data.vocab
+ pretrained_path = './music_transformer.pth'
+ learn = music_model_learner(data, pretrained_path=pretrained_path, config=default_config())
+
+
+ def predict(seed_midi, n_words=400, temperature1=1.1, temperature2=0.4, min_bars=12, top_k=24, top_p=0.7):
+     # Load input MIDI file as MusicItem
+     cutoff_beat = 10
+     item = MusicItem.from_file(seed_midi.name, data.vocab)
+     seed_item = item.trim_to_beat(cutoff_beat)
+
+     # Generate prediction
+     pred, full = learn.predict(seed_item, n_words=n_words, temperatures=(temperature1, temperature2), min_bars=min_bars, top_k=top_k, top_p=top_p)
+
+     # Convert input MIDI to audio
+     with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as seed_audio_temp:
+         FluidSynth("sound_font.sf2").midi_to_audio(seed_midi.name, seed_audio_temp.name)
+
+     # Save generated MIDI as temporary file
+     with tempfile.NamedTemporaryFile(delete=False, suffix='.midi') as pred_midi_temp:
+         pred.stream.write('midi', fp=pred_midi_temp.name)
+
+     # Convert generated MIDI to audio
+     with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as pred_audio_temp:
+         FluidSynth("sound_font.sf2").midi_to_audio(pred_midi_temp.name, pred_audio_temp.name)
+
+     # Cleanup temporary MIDI file
+     os.remove(pred_midi_temp.name)
+
+     return seed_audio_temp.name, pred_audio_temp.name
+
+
+
+ iface = gr.Interface(fn=predict,
+                      inputs=[
+                          gr.inputs.File(label="Seed MIDI"),
+                          gr.inputs.Slider(50, 1000, step=10, default=400, label="Number of Words"),
+                          gr.inputs.Slider(0.0, 2.0, step=0.1, default=1.1, label="Temperature 1"),
+                          gr.inputs.Slider(0.0, 2.0, step=0.1, default=0.4, label="Temperature 2"),
+                          gr.inputs.Slider(1, 32, step=1, default=12, label="Min Bars"),
+                          gr.inputs.Slider(1, 50, step=1, default=24, label="Top K"),
+                          gr.inputs.Slider(0.0, 1.0, step=0.1, default=0.7, label="Top P")
+                      ],
+                      outputs=[
+                          gr.outputs.Audio(type='filepath', label="Seed Audio"),
+                          gr.outputs.Audio(type='filepath', label="Generated Audio")
+                      ],)
+
+ iface.launch()
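A minimal local-usage sketch of the `predict` function above (not part of the committed file). It assumes the definitions in app.py are already loaded in the current session (importing the module as-is would also call `iface.launch()`), that FluidSynth and `sound_font.sf2` are available, and it uses a placeholder MIDI path; Gradio passes file inputs as objects exposing a `.name` path, which `SimpleNamespace` mimics here.

from types import SimpleNamespace

seed = SimpleNamespace(name='example.mid')   # placeholder input MIDI
seed_wav, generated_wav = predict(seed, n_words=400, temperature1=1.1, temperature2=0.4)
print(seed_wav, generated_wav)               # temp-file paths of the rendered .wav outputs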
music_transformer.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9856190b46abee88440104c661349f577eca4754ae485b63cf77030772b0c8cf
+ size 657241884
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio
+ midi2audio
+ music21
+ fastai
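Note that midi2audio is only a thin wrapper around the FluidSynth command-line synthesizer, so rendering audio additionally requires the fluidsynth binary and a SoundFont file on disk; app.py expects one at ./sound_font.sf2. A small rendering sketch with placeholder file names:

from midi2audio import FluidSynth

FluidSynth('sound_font.sf2').midi_to_audio('example.mid', 'example.wav')  # placeholder paths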
utils/.DS_Store ADDED
Binary file (6.15 kB).
 
utils/musicautobot/.DS_Store ADDED
Binary file (6.15 kB).
 
utils/musicautobot/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .utils.setup_musescore import setup_musescore
+
+ setup_musescore()
utils/musicautobot/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (239 Bytes).
 
utils/musicautobot/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.25 kB).
 
utils/musicautobot/__pycache__/numpy_encode.cpython-310.pyc ADDED
Binary file (9.77 kB).
 
utils/musicautobot/__pycache__/vocab.cpython-310.pyc ADDED
Binary file (5.24 kB).
 
utils/musicautobot/config.py ADDED
@@ -0,0 +1,47 @@
+ from fastai.text.models.transformer import tfmerXL_lm_config, Activation
+ # from .vocab import MusicVocab
+
+ def default_config():
+     config = tfmerXL_lm_config.copy()
+     config['act'] = Activation.GeLU
+
+     config['mem_len'] = 512
+     config['d_model'] = 512
+     config['d_inner'] = 2048
+     config['n_layers'] = 16
+
+     config['n_heads'] = 8
+     config['d_head'] = 64
+
+     return config
+
+ def music_config():
+     config = default_config()
+     config['encode_position'] = True
+     return config
+
+ def musicm_config():
+     config = music_config()
+     config['d_model'] = 768
+     config['d_inner'] = 3072
+     config['n_heads'] = 12
+     config['d_head'] = 64
+     config['n_layers'] = 12
+     return config
+
+ def multitask_config():
+     config = default_config()
+     config['bias'] = True
+     config['enc_layers'] = 8
+     config['dec_layers'] = 8
+     del config['n_layers']
+     return config
+
+ def multitaskm_config():
+     config = musicm_config()
+     config['bias'] = True
+     config['enc_layers'] = 12
+     config['dec_layers'] = 12
+     del config['n_layers']
+     return config
+
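The functions above just layer overrides onto fastai's `tfmerXL_lm_config`; a quick sketch of how the variants differ (assuming fastai v1 is installed):

from utils.musicautobot.config import default_config, musicm_config, multitask_config

base = default_config()     # 16-layer TransformerXL-style LM config, d_model=512
mus = musicm_config()       # wider (d_model=768), 12 layers, beat/bar position encoding enabled
mt = multitask_config()     # encoder/decoder depths replace the single n_layers entry

print(base['n_layers'], base['d_model'])        # 16 512
print(mus['encode_position'], mus['d_model'])   # True 768
print('n_layers' in mt, mt['enc_layers'])       # False 8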
utils/musicautobot/multitask_transformer/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .dataloader import *
+ from .model import *
+ from .learner import *
utils/musicautobot/multitask_transformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (257 Bytes).
 
utils/musicautobot/multitask_transformer/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (6.17 kB).
 
utils/musicautobot/multitask_transformer/__pycache__/learner.cpython-310.pyc ADDED
Binary file (11.5 kB).
 
utils/musicautobot/multitask_transformer/__pycache__/model.cpython-310.pyc ADDED
Binary file (11.4 kB).
 
utils/musicautobot/multitask_transformer/__pycache__/transform.cpython-310.pyc ADDED
Binary file (3.72 kB).
 
utils/musicautobot/multitask_transformer/dataloader.py ADDED
@@ -0,0 +1,146 @@
+ from fastai.basics import *
+ from .transform import *
+ from ..music_transformer.dataloader import MusicDataBunch, MusicItemList
+ # Sequence 2 Sequence Translate
+
+ class S2SFileProcessor(PreProcessor):
+     "`PreProcessor` that opens the filenames and read the texts."
+     def process_one(self,item):
+         out = np.load(item, allow_pickle=True)
+         if out.shape != (2,): return None
+         if not 16 < len(out[0]) < 2048: return None
+         if not 16 < len(out[1]) < 2048: return None
+         return out
+
+     def process(self, ds:Collection):
+         ds.items = [self.process_one(item) for item in ds.items]
+         ds.items = [i for i in ds.items if i is not None] # filter out None
+
+ class S2SPartsProcessor(PreProcessor):
+     "Encodes midi file into 2 separate parts - melody and chords."
+
+     def process_one(self, item):
+         m, c = item
+         mtrack = MultitrackItem.from_npenc_parts(m, c, vocab=self.vocab)
+         return mtrack.to_idx()
+
+     def process(self, ds):
+         self.vocab = ds.vocab
+         ds.items = [self.process_one(item) for item in ds.items]
+
+ class Midi2MultitrackProcessor(PreProcessor):
+     "Converts midi files to multitrack items"
+     def process_one(self, midi_file):
+         try:
+             item = MultitrackItem.from_file(midi_file, vocab=self.vocab)
+         except Exception as e:
+             print(e)
+             return None
+         return item.to_idx()
+
+     def process(self, ds):
+         self.vocab = ds.vocab
+         ds.items = [self.process_one(item) for item in ds.items]
+         ds.items = [i for i in ds.items if i is not None]
+
+ class S2SPreloader(Callback):
+     def __init__(self, dataset:LabelList, bptt:int=512,
+                  transpose_range=None, **kwargs):
+         self.dataset,self.bptt = dataset,bptt
+         self.vocab = self.dataset.vocab
+         self.transpose_range = transpose_range
+         self.rand_transpose = partial(rand_transpose_value, rand_range=transpose_range) if transpose_range is not None else None
+
+     def __getitem__(self, k:int):
+         item,empty_label = self.dataset[k]
+
+         if self.rand_transpose is not None:
+             val = self.rand_transpose()
+             item = item.transpose(val)
+         item = item.pad_to(self.bptt+1)
+         ((m_x, m_pos), (c_x, c_pos)) = item.to_idx()
+         return m_x, m_pos, c_x, c_pos
+
+     def __len__(self):
+         return len(self.dataset)
+
+ def rand_transpose_value(rand_range=(0,24), p=0.5):
+     if np.random.rand() < p: return np.random.randint(*rand_range)-rand_range[1]//2
+     return 0
+
+ class S2SItemList(MusicItemList):
+     _bunch = MusicDataBunch
+     def get(self, i):
+         return MultitrackItem.from_idx(self.items[i], self.vocab)
+
+ # DATALOADING AND TRANSFORMATIONS
+ # These transforms happen on batch
+
+ def mask_tfm(b, mask_range, mask_idx, pad_idx, p=0.3):
+     # mask range (min, max)
+     # replacement vals - [x_replace, y_replace]. Usually [mask_idx, pad_idx]
+     # p = replacement probability
+     x,y = b
+     x,y = x.clone(),y.clone()
+     rand = torch.rand(x.shape, device=x.device)
+     rand[x < mask_range[0]] = 1.0
+     rand[x >= mask_range[1]] = 1.0
+
+     # p(15%) of words are replaced. Of those p(15%) - 80% are masked. 10% wrong word. 10% unchanged
+     y[rand > p] = pad_idx # pad unchanged 80%. Remove these from loss/acc metrics
+     x[rand <= (p*.8)] = mask_idx # 80% = mask
+     wrong_word = (rand > (p*.8)) & (rand <= (p*.9)) # 10% = wrong word
+     x[wrong_word] = torch.randint(*mask_range, [wrong_word.sum().item()], device=x.device)
+     return x, y
+
+ def mask_lm_tfm_default(b, vocab, mask_p=0.3):
+     return mask_lm_tfm(b, mask_range=vocab.npenc_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)
+
+ def mask_lm_tfm_pitchdur(b, vocab, mask_p=0.9):
+     mask_range = vocab.dur_range if np.random.rand() < 0.5 else vocab.note_range
+     return mask_lm_tfm(b, mask_range=mask_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)
+
+ def mask_lm_tfm(b, mask_range, mask_idx, pad_idx, mask_p):
+     x,y = b
+     x_lm,x_pos = x[...,0], x[...,1]
+     y_lm,y_pos = y[...,0], y[...,1]
+
+     # Note: masking y_lm instead of x_lm. Just in case we ever do sequential s2s training
+     x_msk, y_msk = mask_tfm((y_lm, y_lm), mask_range=mask_range, mask_idx=mask_idx, pad_idx=pad_idx, p=mask_p)
+     msk_pos = y_pos
+
+     x_dict = {
+         'msk': { 'x': x_msk, 'pos': msk_pos },
+         'lm': { 'x': x_lm, 'pos': msk_pos }
+     }
+     y_dict = { 'msk': y_msk, 'lm': y_lm }
+     return x_dict, y_dict
+
+ def melody_chord_tfm(b):
+     m,m_pos,c,c_pos = b
+
+     # offset x and y for next word prediction
+     y_m = m[:,1:]
+     x_m, m_pos = m[:,:-1], m_pos[:,:-1]
+
+     y_c = c[:,1:]
+     x_c, c_pos = c[:,:-1], c_pos[:,:-1]
+
+     x_dict = {
+         'c2m': {
+             'enc': x_c,
+             'enc_pos': c_pos,
+             'dec': x_m,
+             'dec_pos': m_pos
+         },
+         'm2c': {
+             'enc': x_m,
+             'enc_pos': m_pos,
+             'dec': x_c,
+             'dec_pos': c_pos
+         }
+     }
+     y_dict = {
+         'c2m': y_m, 'm2c': y_c
+     }
+     return x_dict, y_dict
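`mask_tfm` above is a BERT-style corruption applied per batch; a toy sketch of its effect, using made-up token ranges and special indexes (the real values come from `MusicVocab`):

import torch
from utils.musicautobot.multitask_transformer.dataloader import mask_tfm

x = torch.randint(10, 20, (2, 8))   # fake batch of token ids, all inside the maskable range
y = x.clone()
x_msk, y_msk = mask_tfm((x, y), mask_range=(10, 20), mask_idx=1, pad_idx=0, p=0.3)
# Roughly p of the positions are corrupted in x_msk (mostly set to mask_idx, a few to random
# in-range ids); y_msk keeps the original target only at those positions and pads the rest.
print(x_msk)
print(y_msk)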
utils/musicautobot/multitask_transformer/learner.py ADDED
@@ -0,0 +1,340 @@
+ from fastai.basics import *
+ from ..vocab import *
+ from ..utils.top_k_top_p import top_k_top_p
+ from ..utils.midifile import is_empty_midi
+ from ..music_transformer.transform import *
+ from ..music_transformer.learner import filter_invalid_indexes
+ from .model import get_multitask_model
+ from .dataloader import *
+
+ def multitask_model_learner(data:DataBunch, config:dict=None, drop_mult:float=1.,
+                             pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner':
+     "Create a `Learner` with a language model from `data` and `arch`."
+     vocab = data.vocab
+     vocab_size = len(vocab)
+
+     if pretrained_path:
+         state = torch.load(pretrained_path, map_location='cpu')
+         if config is None: config = state['config']
+
+     model = get_multitask_model(vocab_size, config=config, drop_mult=drop_mult, pad_idx=vocab.pad_idx)
+     metrics = [AverageMultiMetric(partial(m, pad_idx=vocab.pad_idx)) for m in [mask_acc, lm_acc, c2m_acc, m2c_acc]]
+     loss_func = MultiLoss(ignore_index=data.vocab.pad_idx)
+     learn = MultitaskLearner(data, model, loss_func=loss_func, metrics=metrics, **learn_kwargs)
+
+     if pretrained_path:
+         get_model(model).load_state_dict(state['model'], strict=False)
+         if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
+         try: learn.opt.load_state_dict(state['opt'])
+         except: pass
+         del state
+         gc.collect()
+
+     return learn
+
+ class MultitaskLearner(Learner):
+     def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
+         "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)"
+         out_path = super().save(file, return_path=True, with_opt=with_opt)
+         if config and out_path:
+             state = torch.load(out_path)
+             state['config'] = config
+             torch.save(state, out_path)
+             del state
+             gc.collect()
+         return out_path
+
+     def predict_nw(self, item:MusicItem, n_words:int=128,
+                    temperatures:float=(1.0,1.0), min_bars=4,
+                    top_k=30, top_p=0.6):
+         "Return the `n_words` that come after `text`."
+         self.model.reset()
+         new_idx = []
+         vocab = self.data.vocab
+         x, pos = item.to_tensor(), item.get_pos_tensor()
+         last_pos = pos[-1] if len(pos) else 0
+         y = torch.tensor([0])
+
+         start_pos = last_pos
+
+         sep_count = 0
+         bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time
+         vocab = self.data.vocab
+
+         repeat_count = 0
+
+         for i in progress_bar(range(n_words), leave=True):
+             batch = { 'lm': { 'x': x[None], 'pos': pos[None] } }, y
+             logits = self.pred_batch(batch=batch)['lm'][-1][-1]
+
+             prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx
+
+             # Temperature
+             # Use first temperatures value if last prediction was duration
+             temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
+             repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
+             temperature += repeat_penalty
+             if temperature != 1.: logits = logits / temperature
+
+
+             # Filter
+             # bar = 16 beats
+             filter_value = -float('Inf')
+             if ((last_pos - start_pos) // 16) <= min_bars: logits[vocab.bos_idx] = filter_value
+
+             logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
+             logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
+
+             # Sample
+             probs = F.softmax(logits, dim=-1)
+             idx = torch.multinomial(probs, 1).item()
+
+             # Update repeat count
+             num_choices = len(probs.nonzero().view(-1))
+             if num_choices <= 2: repeat_count += 1
+             else: repeat_count = repeat_count // 2
+
+             if prev_idx==vocab.sep_idx:
+                 duration = idx - vocab.dur_range[0]
+                 last_pos = last_pos + duration
+
+                 bars_pred = (last_pos - start_pos) // 16
+                 abs_bar = last_pos // 16
+                 # if (bars % 8 == 0) and (bars_pred > min_bars): break
+                 if (i / n_words > 0.80) and (abs_bar % 4 == 0): break
+
+
+             if idx==vocab.bos_idx:
+                 print('Predicted BOS token. Returning prediction...')
+                 break
+
+             new_idx.append(idx)
+             x = x.new_tensor([idx])
+             pos = pos.new_tensor([last_pos])
+
+         pred = vocab.to_music_item(np.array(new_idx))
+         full = item.append(pred)
+         return pred, full
+
+     def predict_mask(self, masked_item:MusicItem,
+                      temperatures:float=(1.0,1.0),
+                      top_k=20, top_p=0.8):
+         x = masked_item.to_tensor()
+         pos = masked_item.get_pos_tensor()
+         y = torch.tensor([0])
+         vocab = self.data.vocab
+         self.model.reset()
+         mask_idxs = (x == vocab.mask_idx).nonzero().view(-1)
+
+         repeat_count = 0
+
+         for midx in progress_bar(mask_idxs, leave=True):
+             prev_idx = x[midx-1]
+
+             # Using original positions, otherwise model gets too off track
+             # pos = torch.tensor(-position_enc(xb[0].cpu().numpy()), device=xb.device)[None]
+
+             # Next Word
+             logits = self.pred_batch(batch=({ 'msk': { 'x': x[None], 'pos': pos[None] } }, y) )['msk'][0][midx]
+
+             # Temperature
+             # Use first temperatures value if last prediction was duration
+             temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
+             repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
+             temperature += repeat_penalty
+             if temperature != 1.: logits = logits / temperature
+
+             # Filter
+             filter_value = -float('Inf')
+             special_idxs = [vocab.bos_idx, vocab.sep_idx, vocab.stoi[EOS]]
+             logits[special_idxs] = filter_value # Don't allow any special tokens (as we are only removing notes and durations)
+             logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
+             logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
+
+             # Sampling
+             probs = F.softmax(logits, dim=-1)
+             idx = torch.multinomial(probs, 1).item()
+
+             # Update repeat count
+             num_choices = len(probs.nonzero().view(-1))
+             if num_choices <= 2: repeat_count += 1
+             else: repeat_count = repeat_count // 2
+
+             x[midx] = idx
+
+         return vocab.to_music_item(x.cpu().numpy())
+
+     def predict_s2s(self, input_item:MusicItem, target_item:MusicItem, n_words:int=256,
+                     temperatures:float=(1.0,1.0), top_k=30, top_p=0.8,
+                     use_memory=True):
+         vocab = self.data.vocab
+
+         # Input doesn't change. We can reuse the encoder output on each prediction
+         with torch.no_grad():
+             inp, inp_pos = input_item.to_tensor(), input_item.get_pos_tensor()
+             x_enc = self.model.encoder(inp[None], inp_pos[None])
+
+         # target
+         targ = target_item.data.tolist()
+         targ_pos = target_item.position.tolist()
+         last_pos = targ_pos[-1]
+         self.model.reset()
+
+         repeat_count = 0
+
+         max_pos = input_item.position[-1] + SAMPLE_FREQ * 4 # Only predict until both tracks/parts have the same length
+         x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos)
+
+         for i in progress_bar(range(n_words), leave=True):
+             # Predict
+             with torch.no_grad():
+                 dec = self.model.decoder(x[None], pos[None], x_enc)
+                 logits = self.model.head(dec)[-1, -1]
+
+             # Temperature
+             # Use first temperatures value if last prediction was duration
+             prev_idx = targ[-1] if len(targ) else vocab.pad_idx
+             temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
+             repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
+             temperature += repeat_penalty
+             if temperature != 1.: logits = logits / temperature
+
+             # Filter
+             filter_value = -float('Inf')
+             logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
+             logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
+
+             # Sample
+             probs = F.softmax(logits, dim=-1)
+             idx = torch.multinomial(probs, 1).item()
+
+             # Update repeat count
+             num_choices = len(probs.nonzero().view(-1))
+             if num_choices <= 2: repeat_count += 1
+             else: repeat_count = repeat_count // 2
+
+             if idx == vocab.bos_idx or idx == vocab.stoi[EOS]:
+                 print('Predicting BOS/EOS')
+                 break
+
+             if prev_idx == vocab.sep_idx:
+                 duration = idx - vocab.dur_range[0]
+                 last_pos = last_pos + duration
+                 if last_pos > max_pos:
+                     print('Predicted past counter-part length. Returning early')
+                     break
+
+             targ_pos.append(last_pos)
+             targ.append(idx)
+
+             if use_memory:
+                 # Relying on memory for kv. Only need last prediction index
+                 x, pos = inp.new_tensor([targ[-1]]), inp_pos.new_tensor([targ_pos[-1]])
+             else:
+                 # Reset memory after each prediction, since we're feeding the whole sequence every time
+                 self.model.reset()
+                 x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos)
+
+         return vocab.to_music_item(np.array(targ))
+
+ # High level prediction functions from midi file
+ def nw_predict_from_midi(learn, midi=None, n_words=400,
+                          temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
+     vocab = learn.data.vocab
+     seed = MusicItem.from_file(midi, vocab) if not is_empty_midi(midi) else MusicItem.empty(vocab)
+     if seed_len is not None: seed = seed.trim_to_beat(seed_len)
+
+     pred, full = learn.predict_nw(seed, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
+     return full
+
+ def s2s_predict_from_midi(learn, midi=None, n_words=200,
+                           temperatures=(1.0,1.0), top_k=24, top_p=0.7, seed_len=None, pred_melody=True, **kwargs):
+     multitrack_item = MultitrackItem.from_file(midi, learn.data.vocab)
+     melody, chords = multitrack_item.melody, multitrack_item.chords
+     inp, targ = (chords, melody) if pred_melody else (melody, chords)
+
+     # if seed_len is passed, cutoff sequence so we can predict the rest
+     if seed_len is not None: targ = targ.trim_to_beat(seed_len)
+     targ = targ.remove_eos()
+
+     pred = learn.predict_s2s(inp, targ, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
+
+     part_order = (pred, inp) if pred_melody else (inp, pred)
+     return MultitrackItem(*part_order)
+
+ def mask_predict_from_midi(learn, midi=None, predict_notes=True,
+                            temperatures=(1.0,1.0), top_k=30, top_p=0.7, section=None, **kwargs):
+     item = MusicItem.from_file(midi, learn.data.vocab)
+     masked_item = item.mask_pitch(section) if predict_notes else item.mask_duration(section)
+     pred = learn.predict_mask(masked_item, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
+     return pred
+
+ # LOSS AND METRICS
+
+ class MultiLoss():
+     def __init__(self, ignore_index=None):
+         "Loss mult - Mask, NextWord, Seq2Seq"
+         self.loss = CrossEntropyFlat(ignore_index=ignore_index)
+
+     def __call__(self, inputs:Dict[str,Tensor], targets:Dict[str,Tensor])->Rank0Tensor:
+         losses = [self.loss(inputs[key], target) for key,target in targets.items()]
+         return sum(losses)
+
+ def acc_ignore_pad(input:Tensor, targ:Tensor, pad_idx)->Rank0Tensor:
+     if input is None or targ is None: return None
+     n = targ.shape[0]
+     input = input.argmax(dim=-1).view(n,-1)
+     targ = targ.view(n,-1)
+     mask = targ != pad_idx
+     return (input[mask]==targ[mask]).float().mean()
+
+ def acc_index(inputs, targets, key, pad_idx):
+     return acc_ignore_pad(inputs.get(key), targets.get(key), pad_idx)
+
+ def mask_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'msk', pad_idx)
+ def lm_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'lm', pad_idx)
+ def c2m_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'c2m', pad_idx)
+ def m2c_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'm2c', pad_idx)
+
+
+ class AverageMultiMetric(AverageMetric):
+     "Updated fastai.AverageMetric to support multi task metrics."
+     def on_batch_end(self, last_output, last_target, **kwargs):
+         "Update metric computation with `last_output` and `last_target`."
+         if not is_listy(last_target): last_target=[last_target]
+         val = self.func(last_output, *last_target)
+         if val is None: return
+         self.count += first_el(last_target).size(0)
+         if self.world:
+             val = val.clone()
+             dist.all_reduce(val, op=dist.ReduceOp.SUM)
+             val /= self.world
+         self.val += first_el(last_target).size(0) * val.detach().cpu()
+
+     def on_epoch_end(self, last_metrics, **kwargs):
+         "Set the final result in `last_metrics`."
+         if self.count == 0: return add_metrics(last_metrics, 0)
+         return add_metrics(last_metrics, self.val/self.count)
+
+
+ # MODEL LOADING
+ class MTTrainer(LearnerCallback):
+     "`Callback` that regroups lr adjustment to seq_len, AR and TAR."
+     def __init__(self, learn:Learner, dataloaders=None, starting_mask_window=1):
+         super().__init__(learn)
+         self.count = 1
+         self.mw_start = starting_mask_window
+         self.dataloaders = dataloaders
+
+     def on_epoch_begin(self, **kwargs):
+         "Reset the hidden state of the model."
+         model = get_model(self.learn.model)
+         model.reset()
+         model.encoder.mask_steps = max(self.count+self.mw_start, 100)
+
+     def on_epoch_end(self, last_metrics, **kwargs):
+         "Finish the computation and sends the result to the Recorder."
+         if self.dataloaders is not None:
+             self.learn.data = self.dataloaders[self.count % len(self.dataloaders)]
+         self.count += 1
+
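The three `*_predict_from_midi` helpers above are the intended entry points around `predict_nw`, `predict_s2s` and `predict_mask`; a usage sketch, assuming a trained multitask checkpoint and MIDI files at placeholder paths (the checkpoint shipped in this commit is a music_transformer model, not a multitask one):

from utils.musicautobot.music_transformer import MusicDataBunch
from utils.musicautobot.multitask_transformer import *

data = MusicDataBunch.empty('.')
learn = multitask_model_learner(data, pretrained_path='multitask_model.pth')   # placeholder checkpoint

full = nw_predict_from_midi(learn, midi='seed.mid', n_words=400, seed_len=16)          # continue a sequence
duet = s2s_predict_from_midi(learn, midi='duet.mid', n_words=200, pred_melody=True)    # chords -> melody
remix = mask_predict_from_midi(learn, midi='seed.mid', predict_notes=True)             # re-sample pitches
full.stream.write('midi', fp='continued.mid')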
utils/musicautobot/multitask_transformer/model.py ADDED
@@ -0,0 +1,258 @@
+ from fastai.basics import *
+ from fastai.text.models.transformer import Activation, PositionalEncoding, feed_forward, init_transformer, _line_shift
+ from fastai.text.models.awd_lstm import RNNDropout
+ from ..utils.attention_mask import *
+
+ def get_multitask_model(vocab_size:int, config:dict=None, drop_mult:float=1., pad_idx=None):
+     "Create a language model from `arch` and its `config`, maybe `pretrained`."
+     for k in config.keys():
+         if k.endswith('_p'): config[k] *= drop_mult
+     n_hid = config['d_model']
+     mem_len = config.pop('mem_len')
+     embed = TransformerEmbedding(vocab_size, n_hid, embed_p=config['embed_p'], mem_len=mem_len, pad_idx=pad_idx)
+     encoder = MTEncoder(embed, n_hid, n_layers=config['enc_layers'], mem_len=0, **config) # encoder doesn't need memory
+     decoder = MTEncoder(embed, n_hid, is_decoder=True, n_layers=config['dec_layers'], mem_len=mem_len, **config)
+     head = MTLinearDecoder(n_hid, vocab_size, tie_encoder=embed.embed, **config)
+     model = MultiTransformer(encoder, decoder, head, mem_len=mem_len)
+     return model.apply(init_transformer)
+
+ class MultiTransformer(nn.Module):
+     "Multitask Transformer for training mask, next word, and sequence 2 sequence"
+     def __init__(self, encoder, decoder, head, mem_len):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.head = head
+         self.default_mem_len = mem_len
+         self.current_mem_len = None
+
+     def forward(self, inp):
+         # data order: mask, next word, melody, chord
+         outputs = {}
+         msk, lm, c2m, m2c = [inp.get(key) for key in ['msk', 'lm', 'c2m', 'm2c']]
+
+         if msk is not None:
+             outputs['msk'] = self.head(self.encoder(msk['x'], msk['pos']))
+         if lm is not None:
+             outputs['lm'] = self.head(self.decoder(lm['x'], lm['pos']))
+
+         if c2m is not None:
+             self.reset()
+             c2m_enc = self.encoder(c2m['enc'], c2m['enc_pos'])
+             c2m_dec = self.decoder(c2m['dec'], c2m['dec_pos'], c2m_enc)
+             outputs['c2m'] = self.head(c2m_dec)
+
+         if m2c is not None:
+             self.reset()
+             m2c_enc = self.encoder(m2c['enc'], m2c['enc_pos'])
+             m2c_dec = self.decoder(m2c['dec'], m2c['dec_pos'], m2c_enc)
+             outputs['m2c'] = self.head(m2c_dec)
+
+         return outputs
+
+     "A sequential module that passes the reset call to its children."
+     def reset(self):
+         for module in self.children():
+             reset_children(module)
+
+ def reset_children(mod):
+     if hasattr(mod, 'reset'): mod.reset()
+     for module in mod.children():
+         reset_children(module)
+
+ # COMPONENTS
+ class TransformerEmbedding(nn.Module):
+     "Embedding + positional encoding + dropout"
+     def __init__(self, vocab_size:int, emb_sz:int, embed_p:float=0., mem_len=512, beat_len=32, max_bar_len=1024, pad_idx=None):
+         super().__init__()
+         self.emb_sz = emb_sz
+         self.pad_idx = pad_idx
+
+         self.embed = nn.Embedding(vocab_size, emb_sz, padding_idx=pad_idx)
+         self.pos_enc = PositionalEncoding(emb_sz)
+         self.beat_len, self.max_bar_len = beat_len, max_bar_len
+         self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0)
+         self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0)
+
+         self.drop = nn.Dropout(embed_p)
+         self.mem_len = mem_len
+
+     def forward(self, inp, pos):
+         beat_enc = self.beat_enc(pos % self.beat_len)
+         bar_pos = pos // self.beat_len % self.max_bar_len
+         bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1
+         bar_enc = self.bar_enc(bar_pos)
+         emb = self.drop(self.embed(inp) + beat_enc + bar_enc)
+         return emb
+
+     def relative_pos_enc(self, emb):
+         # return torch.arange(640-1, -1, -1).float().cuda()
+         seq_len = emb.shape[1] + self.mem_len
+         pos = torch.arange(seq_len-1, -1, -1, device=emb.device, dtype=emb.dtype) # backwards (txl pos encoding)
+         return self.pos_enc(pos)
+
+ class MTLinearDecoder(nn.Module):
+     "To go on top of a RNNCore module and create a Language Model."
+     initrange=0.1
+
+     def __init__(self, n_hid:int, n_out:int, output_p:float, tie_encoder:nn.Module=None, out_bias:bool=True, **kwargs):
+         super().__init__()
+         self.decoder = nn.Linear(n_hid, n_out, bias=out_bias)
+         self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
+         self.output_dp = RNNDropout(output_p)
+         if out_bias: self.decoder.bias.data.zero_()
+         if tie_encoder: self.decoder.weight = tie_encoder.weight
+
+     def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
+         output = self.output_dp(input)
+         decoded = self.decoder(output)
+         return decoded
+
+
+ # DECODER TRANSLATE BLOCK
+ class MTEncoder(nn.Module):
+     def __init__(self, embed:nn.Module, n_hid:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,
+                  resid_p:float=0., attn_p:float=0., ff_p:float=0., bias:bool=True, scale:bool=True,
+                  act:Activation=Activation.ReLU, double_drop:bool=True, mem_len:int=512, is_decoder=False,
+                  mask_steps=1, mask_p=0.3, **kwargs):
+         super().__init__()
+         self.embed = embed
+         self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
+         self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
+         self.n_layers,self.d_model = n_layers,d_model
+         self.layers = nn.ModuleList([MTEncoderBlock(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
+                                      ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, mem_len=mem_len,
+                                      ) for k in range(n_layers)])
+
+         self.mask_steps, self.mask_p = mask_steps, mask_p
+         self.is_decoder = is_decoder
+
+         nn.init.normal_(self.u, 0., 0.02)
+         nn.init.normal_(self.v, 0., 0.02)
+
+     def forward(self, x_lm, lm_pos, msk_emb=None):
+         bs,lm_len = x_lm.size()
+
+         lm_emb = self.embed(x_lm, lm_pos)
+         if msk_emb is not None and msk_emb.shape[1] > lm_emb.shape[1]:
+             pos_enc = self.embed.relative_pos_enc(msk_emb)
+         else:
+             pos_enc = self.embed.relative_pos_enc(lm_emb)
+
+         # Masks
+         if self.is_decoder:
+             lm_mask = rand_window_mask(lm_len, self.embed.mem_len, x_lm.device,
+                                        max_size=self.mask_steps, p=self.mask_p, is_eval=not self.training)
+         else:
+             lm_mask = None
+
+         for i, layer in enumerate(self.layers):
+             lm_emb = layer(lm_emb, msk_emb, lm_mask=lm_mask,
+                            r=pos_enc, g_u=self.u, g_v=self.v)
+         return lm_emb
+
+ class MTEncoderBlock(nn.Module):
+     "Decoder block of a Transformer model."
+     #Can't use Sequential directly cause more than one input...
+     def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0.,
+                  bias:bool=True, scale:bool=True, double_drop:bool=True, mem_len:int=512, mha2_mem_len=0, **kwargs):
+         super().__init__()
+         attn_cls = MemMultiHeadRelativeAttentionKV
+         self.mha1 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mem_len, r_mask=False)
+         self.mha2 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mha2_mem_len, r_mask=True)
+         self.ff = feed_forward(d_model, d_inner, ff_p=ff_p, double_drop=double_drop)
+
+     def forward(self, enc_lm:Tensor, enc_msk:Tensor,
+                 r=None, g_u=None, g_v=None,
+                 msk_mask:Tensor=None, lm_mask:Tensor=None):
+
+         y_lm = self.mha1(enc_lm, enc_lm, enc_lm, r, g_u, g_v, mask=lm_mask)
+         if enc_msk is None: return y_lm
+         return self.ff(self.mha2(y_lm, enc_msk, enc_msk, r, g_u, g_v, mask=msk_mask))
+
+
+ # Attention Layer
+
+
+ # Attn
+
+ class MemMultiHeadRelativeAttentionKV(nn.Module):
+     "Attention Layer monster - relative positioning, keeps track of own memory, separate kv weights to support sequence2sequence decoding."
+     def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True,
+                  scale:bool=True, mem_len:int=512, r_mask=True):
+         super().__init__()
+         d_head = ifnone(d_head, d_model//n_heads)
+         self.n_heads,self.d_head,self.scale = n_heads,d_head,scale
+
+         assert(d_model == d_head * n_heads)
+         self.q_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
+         self.k_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
+         self.v_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
+
+         self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p)
+         self.ln = nn.LayerNorm(d_model)
+         self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)
+         self.r_mask = r_mask
+
+         self.mem_len = mem_len
+         self.prev_k = None
+         self.prev_v = None
+
+     def forward(self, q:Tensor, k:Tensor=None, v:Tensor=None,
+                 r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None,
+                 mask:Tensor=None, **kwargs):
+         if k is None: k = q
+         if v is None: v = q
+         return self.ln(q + self.drop_res(self._apply_attention(q, k, v, r, g_u, g_v, mask=mask, **kwargs)))
+
+     def mem_k(self, k):
+         if self.mem_len == 0: return k
+         if self.prev_k is None or (self.prev_k.shape[0] != k.shape[0]): # reset if wrong batch size
+             self.prev_k = k[:, -self.mem_len:]
+             return k
+         with torch.no_grad():
+             k_ext = torch.cat([self.prev_k, k], dim=1)
+             self.prev_k = k_ext[:, -self.mem_len:]
+         return k_ext.detach()
+
+     def mem_v(self, v):
+         if self.mem_len == 0: return v
+         if self.prev_v is None or (self.prev_v.shape[0] != v.shape[0]): # reset if wrong batch size
+             self.prev_v = v[:, -self.mem_len:]
+             return v
+         with torch.no_grad():
+             v_ext = torch.cat([self.prev_v, v], dim=1)
+             self.prev_v = v_ext[:, -self.mem_len:]
+         return v_ext.detach()
+
+     def reset(self):
+         self.prev_v = None
+         self.prev_k = None
+
+     def _apply_attention(self, q:Tensor, k:Tensor, v:Tensor,
+                          r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None,
+                          mask:Tensor=None, **kwargs):
+         #Notations from the paper: x input, r vector of relative distance between two elements, u et v learnable
+         #parameters of the model common between all layers, mask to avoid cheating and mem the previous hidden states.
+         # bs,x_len,seq_len = q.size(0),q.size(1),r.size(0)
+         k = self.mem_k(k)
+         v = self.mem_v(v)
+         bs,x_len,seq_len = q.size(0),q.size(1),k.size(1)
+         wq,wk,wv = self.q_wgt(q),self.k_wgt(k),self.v_wgt(v)
+         wq = wq[:,-x_len:]
+         wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
+         wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)
+         wkr = self.r_attn(r[-seq_len:])
+         wkr = wkr.view(seq_len, self.n_heads, self.d_head)
+         wkr = wkr.permute(1,2,0)
+         #### compute attention score (AC is (a) + (c) and BD is (b) + (d) in the paper)
+         AC = torch.matmul(wq+g_u,wk)
+         BD = _line_shift(torch.matmul(wq+g_v, wkr), mask=self.r_mask)
+         attn_score = (AC + BD).mul_(1/(self.d_head ** 0.5)) if self.scale else (AC + BD)
+         if mask is not None:
+             mask = mask[...,-seq_len:]
+             if hasattr(mask, 'bool'): mask = mask.bool()
+             attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
+         attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
+         attn_vec = torch.matmul(attn_prob, wv)
+         return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)
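A construction sketch for the model factory above, assuming the config and vocab modules from this commit supply all the keys `get_multitask_model` reads:

from utils.musicautobot.config import multitask_config
from utils.musicautobot.vocab import MusicVocab
from utils.musicautobot.multitask_transformer.model import get_multitask_model

vocab = MusicVocab.create()
model = get_multitask_model(len(vocab), config=multitask_config(), drop_mult=1., pad_idx=vocab.pad_idx)
# forward() takes a dict keyed by task ('msk', 'lm', 'c2m', 'm2c'); any subset of tasks may be
# present, and each entry supplies the token and position tensors that task needs.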
utils/musicautobot/multitask_transformer/transform.py ADDED
@@ -0,0 +1,68 @@
+ from ..music_transformer.transform import *
+
+ class MultitrackItem():
+     def __init__(self, melody:MusicItem, chords:MusicItem, stream=None):
+         self.melody,self.chords = melody, chords
+         self.vocab = melody.vocab
+         self._stream = stream
+
+     @classmethod
+     def from_file(cls, midi_file, vocab):
+         return cls.from_stream(file2stream(midi_file), vocab)
+
+     @classmethod
+     def from_stream(cls, stream, vocab):
+         if not isinstance(stream, music21.stream.Score): stream = stream.voicesToParts()
+         num_parts = len(stream.parts)
+         sort_pitch = False
+         if num_parts > 2:
+             raise ValueError('Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks')
+         elif num_parts == 1:
+             print('Warning: only 1 track found. Inferring melody/chords')
+             stream = separate_melody_chord(stream)
+             sort_pitch = False
+
+         mpart, cpart = stream2npenc_parts(stream, sort_pitch=sort_pitch)
+         return cls.from_npenc_parts(mpart, cpart, vocab, stream)
+
+     @classmethod
+     def from_npenc_parts(cls, mpart, cpart, vocab, stream=None):
+         mpart = npenc2idxenc(mpart, seq_type=SEQType.Melody, vocab=vocab, add_eos=False)
+         cpart = npenc2idxenc(cpart, seq_type=SEQType.Chords, vocab=vocab, add_eos=False)
+         return MultitrackItem(MusicItem(mpart, vocab), MusicItem(cpart, vocab), stream)
+
+     @classmethod
+     def from_idx(cls, item, vocab):
+         m, c = item
+         return MultitrackItem(MusicItem.from_idx(m, vocab), MusicItem.from_idx(c, vocab))
+     def to_idx(self): return np.array((self.melody.to_idx(), self.chords.to_idx()))
+
+     @property
+     def stream(self):
+         self._stream = self.to_stream() if self._stream is None else self._stream
+         return self._stream
+
+     def to_stream(self, bpm=120):
+         ps = self.melody.to_npenc(), self.chords.to_npenc()
+         ps = [npenc2chordarr(p) for p in ps]
+         chordarr = chordarr_combine_parts(ps)
+         return chordarr2stream(chordarr, bpm=bpm)
+
+
+     def show(self, format:str=None):
+         return self.stream.show(format)
+     def play(self): self.stream.show('midi')
+
+     def transpose(self, val):
+         return MultitrackItem(self.melody.transpose(val), self.chords.transpose(val))
+     def pad_to(self, val):
+         return MultitrackItem(self.melody.pad_to(val), self.chords.pad_to(val))
+     def trim_to_beat(self, beat):
+         return MultitrackItem(self.melody.trim_to_beat(beat), self.chords.trim_to_beat(beat))
+
+ def combine2chordarr(np1, np2, vocab):
+     if len(np1.shape) == 1: np1 = idxenc2npenc(np1, vocab)
+     if len(np2.shape) == 1: np2 = idxenc2npenc(np2, vocab)
+     p1 = npenc2chordarr(np1)
+     p2 = npenc2chordarr(np2)
+     return chordarr_combine_parts((p1, p2))
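A round-trip sketch for `MultitrackItem`, assuming a two-track melody/chord MIDI file at a placeholder path:

from utils.musicautobot.vocab import MusicVocab
from utils.musicautobot.multitask_transformer.transform import MultitrackItem

vocab = MusicVocab.create()
item = MultitrackItem.from_file('duet.mid', vocab)   # placeholder 2-track MIDI
edited = item.transpose(2).trim_to_beat(32)          # up 2 semitones, keep the first 32 beats
edited.stream.write('midi', fp='duet_edit.mid')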
utils/musicautobot/music_transformer/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .dataloader import *
+ from .model import *
+ from .learner import *
utils/musicautobot/music_transformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (251 Bytes).
 
utils/musicautobot/music_transformer/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (11.2 kB).
 
utils/musicautobot/music_transformer/__pycache__/learner.cpython-310.pyc ADDED
Binary file (5.94 kB).
 
utils/musicautobot/music_transformer/__pycache__/model.cpython-310.pyc ADDED
Binary file (3 kB).
 
utils/musicautobot/music_transformer/__pycache__/transform.cpython-310.pyc ADDED
Binary file (10.7 kB).
 
utils/musicautobot/music_transformer/dataloader.py ADDED
@@ -0,0 +1,229 @@
+ "Fastai Language Model Databunch modified to work with music"
+ from fastai.basics import *
+ # from fastai.basic_data import DataBunch
+ from fastai.text.data import LMLabelList
+ from .transform import *
+ from ..vocab import MusicVocab
+
+
+ class MusicDataBunch(DataBunch):
+     "Create a `TextDataBunch` suitable for training a language model."
+     @classmethod
+     def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', no_check:bool=False, bs=64, val_bs:int=None,
+                num_workers:int=0, device:torch.device=None, collate_fn:Callable=data_collate,
+                dl_tfms:Optional[Collection[Callable]]=None, bptt:int=70,
+                preloader_cls=None, shuffle_dl=False, transpose_range=(0,12), **kwargs) -> DataBunch:
+         "Create a `TextDataBunch` in `path` from the `datasets` for language modelling."
+         datasets = cls._init_ds(train_ds, valid_ds, test_ds)
+         preloader_cls = MusicPreloader if preloader_cls is None else preloader_cls
+         val_bs = ifnone(val_bs, bs)
+         datasets = [preloader_cls(ds, shuffle=(i==0), bs=(bs if i==0 else val_bs), bptt=bptt, transpose_range=transpose_range, **kwargs)
+                     for i,ds in enumerate(datasets)]
+         val_bs = bs
+         dl_tfms = [partially_apply_vocab(tfm, train_ds.vocab) for tfm in listify(dl_tfms)]
+         dls = [DataLoader(d, b, shuffle=shuffle_dl) for d,b in zip(datasets, (bs,val_bs,val_bs,val_bs)) if d is not None]
+         return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)
+
+     @classmethod
+     def from_folder(cls, path:PathOrStr, extensions='.npy', **kwargs):
+         files = get_files(path, extensions=extensions, recurse=True)
+         return cls.from_files(files, path, **kwargs)
+
+     @classmethod
+     def from_files(cls, files, path, processors=None, split_pct=0.1,
+                    vocab=None, list_cls=None, **kwargs):
+         if vocab is None: vocab = MusicVocab.create()
+         if list_cls is None: list_cls = MusicItemList
+         src = (list_cls(items=files, path=path, processor=processors, vocab=vocab)
+                .split_by_rand_pct(split_pct, seed=6)
+                .label_const(label_cls=LMLabelList))
+         return src.databunch(**kwargs)
+
+     @classmethod
+     def empty(cls, path, **kwargs):
+         vocab = MusicVocab.create()
+         src = MusicItemList([], path=path, vocab=vocab, ignore_empty=True).split_none()
+         return src.label_const(label_cls=LMLabelList).databunch()
+
+ def partially_apply_vocab(tfm, vocab):
+     if 'vocab' in inspect.getfullargspec(tfm).args:
+         return partial(tfm, vocab=vocab)
+     return tfm
+
+ class MusicItemList(ItemList):
+     _bunch = MusicDataBunch
+
+     def __init__(self, items:Iterator, vocab:MusicVocab=None, **kwargs):
+         super().__init__(items, **kwargs)
+         self.vocab = vocab
+         self.copy_new += ['vocab']
+
+     def get(self, i):
+         o = super().get(i)
+         if is_pos_enc(o):
+             return MusicItem.from_idx(o, self.vocab)
+         return MusicItem(o, self.vocab)
+
+ def is_pos_enc(idxenc):
+     if len(idxenc.shape) == 2 and idxenc.shape[0] == 2: return True
+     return idxenc.dtype == object and idxenc.shape == (2,)
+
+ class MusicItemProcessor(PreProcessor):
+     "`PreProcessor` that transforms numpy files to indexes for training"
+     def process_one(self,item):
+         item = MusicItem.from_npenc(item, vocab=self.vocab)
+         return item.to_idx()
+
+     def process(self, ds):
+         self.vocab = ds.vocab
+         super().process(ds)
+
+ class OpenNPFileProcessor(PreProcessor):
+     "`PreProcessor` that opens the filenames and read the texts."
+     def process_one(self,item):
+         return np.load(item, allow_pickle=True) if isinstance(item, Path) else item
+
+ class Midi2ItemProcessor(PreProcessor):
+     "Skips midi preprocessing step. And encodes midi files to MusicItems"
+     def process_one(self,item):
+         item = MusicItem.from_file(item, vocab=self.vocab)
+         return item.to_idx()
+
+     def process(self, ds):
+         self.vocab = ds.vocab
+         super().process(ds)
+
+ ## For npenc dataset
+ class MusicPreloader(Callback):
+     "Transforms the tokens in `dataset` to a stream of contiguous batches for language modelling."
+
+     class CircularIndex():
+         "Handles shuffle, direction of indexing, wraps around to head tail in the ragged array as needed"
+         def __init__(self, length:int, forward:bool): self.idx, self.forward = np.arange(length), forward
+         def __getitem__(self, i):
+             return self.idx[ i%len(self.idx) if self.forward else len(self.idx)-1-i%len(self.idx)]
+         def __len__(self) -> int: return len(self.idx)
+         def shuffle(self): np.random.shuffle(self.idx)
+
+     def __init__(self, dataset:LabelList, lengths:Collection[int]=None, bs:int=32, bptt:int=70, backwards:bool=False,
+                  shuffle:bool=False, y_offset:int=1,
+                  transpose_range=None, transpose_p=0.5,
+                  encode_position=True,
+                  **kwargs):
+         self.dataset,self.bs,self.bptt,self.shuffle,self.backwards,self.lengths = dataset,bs,bptt,shuffle,backwards,lengths
+         self.vocab = self.dataset.vocab
+         self.bs *= num_distrib() or 1
+         self.totalToks,self.ite_len,self.idx = int(0),None,None
+         self.y_offset = y_offset
+
+         self.transpose_range,self.transpose_p = transpose_range,transpose_p
+         self.encode_position = encode_position
+         self.bptt_len = self.bptt
+
+         self.allocate_buffers() # needed for valid_dl on distributed training - otherwise doesn't get initialized on first epoch
+
+     def __len__(self):
+         if self.ite_len is None:
+             if self.lengths is None: self.lengths = np.array([len(item) for item in self.dataset.x])
+             self.totalToks = self.lengths.sum()
+             self.ite_len = self.bs*int( math.ceil( self.totalToks/(self.bptt*self.bs) )) if self.item is None else 1
+         return self.ite_len
+
+     def __getattr__(self,k:str)->Any: return getattr(self.dataset, k)
+
+     def allocate_buffers(self):
+         "Create the ragged array that will be filled when we ask for items."
+         if self.ite_len is None: len(self)
+         self.idx = MusicPreloader.CircularIndex(len(self.dataset.x), not self.backwards)
+
+         # batch shape = (bs, bptt, 2 - [index, pos]) if encode_position. Else - (bs, bptt)
+         buffer_len = (2,) if self.encode_position else ()
+         self.batch = np.zeros((self.bs, self.bptt+self.y_offset) + buffer_len, dtype=np.int64)
+         self.batch_x, self.batch_y = self.batch[:,0:self.bptt], self.batch[:,self.y_offset:self.bptt+self.y_offset]
+         #ro: index of the text we're at inside our datasets for the various batches
+         self.ro = np.zeros(self.bs, dtype=np.int64)
+         #ri: index of the token we're at inside our current text for the various batches
+         self.ri = np.zeros(self.bs, dtype=np.int64)
+
+         # allocate random transpose values. Need to allocate this before hand.
+         self.transpose_values = self.get_random_transpose_values()
+
+     def get_random_transpose_values(self):
+         if self.transpose_range is None: return None
+         n = len(self.dataset)
+         rt_arr = torch.randint(*self.transpose_range, (n,))-self.transpose_range[1]//2
+         mask = torch.rand(rt_arr.shape) > self.transpose_p
+         rt_arr[mask] = 0
+         return rt_arr
+
+     def on_epoch_begin(self, **kwargs):
+         if self.idx is None: self.allocate_buffers()
+         elif self.shuffle:
+             self.ite_len = None
+             self.idx.shuffle()
+             self.transpose_values = self.get_random_transpose_values()
+             self.bptt_len = self.bptt
+         self.idx.forward = not self.backwards
+
+         step = self.totalToks / self.bs
+         ln_rag, countTokens, i_rag = 0, 0, -1
+         for i in range(0,self.bs):
+             #Compute the initial values for ro and ri
+             while ln_rag + countTokens <= int(step * i):
+                 countTokens += ln_rag
+                 i_rag += 1
+                 ln_rag = self.lengths[self.idx[i_rag]]
+             self.ro[i] = i_rag
+             self.ri[i] = ( ln_rag - int(step * i - countTokens) ) if self.backwards else int(step * i - countTokens)
+
+     #Training dl gets on_epoch_begin called, val_dl, on_epoch_end
+     def on_epoch_end(self, **kwargs): self.on_epoch_begin()
+
+     def __getitem__(self, k:int):
+         j = k % self.bs
+         if j==0:
+             if self.item is not None: return self.dataset[0]
+             if self.idx is None: self.on_epoch_begin()
+
+         self.ro[j],self.ri[j] = self.fill_row(not self.backwards, self.dataset.x, self.idx, self.batch[j][:self.bptt_len+self.y_offset],
+                                               self.ro[j], self.ri[j], overlap=1, lengths=self.lengths)
+         return self.batch_x[j][:self.bptt_len], self.batch_y[j][:self.bptt_len]
+
+     def fill_row(self, forward, items, idx, row, ro, ri, overlap, lengths):
+         "Fill the row with tokens from the ragged array. --OBS-- overlap != 1 has not been implemented"
+         ibuf = n = 0
+         ro -= 1
+         while ibuf < row.shape[0]:
+             ro += 1
+             ix = idx[ro]
+
+             item = items[ix]
+             if self.transpose_values is not None:
+                 item = item.transpose(self.transpose_values[ix].item())
+
+             if self.encode_position:
+                 # Positions are column stacked with indexes. This makes it easier to keep in sync
+                 rag = np.stack([item.data, item.position], axis=1)
+             else:
+                 rag = item.data
+
+             if forward:
+                 ri = 0 if ibuf else ri
+                 n = min(lengths[ix] - ri, row.shape[0] - ibuf)
+                 row[ibuf:ibuf+n] = rag[ri:ri+n]
+             else:
+                 ri = lengths[ix] if ibuf else ri
+                 n = min(ri, row.size - ibuf)
+                 row[ibuf:ibuf+n] = rag[ri-n:ri][::-1]
+             ibuf += n
+         return ro, ri + ((n-overlap) if forward else -(n-overlap))
+
+ def batch_position_tfm(b):
+     "Batch transform for training with positional encoding"
+     x,y = b
+     x = {
+         'x': x[...,0],
+         'pos': x[...,1]
+     }
+     return x, y[...,0]
+
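A databunch sketch for the loader above, assuming a folder of MIDI files at a placeholder path; `Midi2ItemProcessor` encodes them on the fly, while pre-encoded `.npy` datasets would go through `OpenNPFileProcessor`/`MusicItemProcessor` instead:

from fastai.basics import *
from utils.musicautobot.music_transformer import *

path = Path('data/midi')                          # placeholder folder
midi_files = get_files(path, extensions=['.mid', '.midi'], recurse=True)
data = MusicDataBunch.from_files(midi_files, path, processors=[Midi2ItemProcessor()],
                                 bs=4, bptt=256, encode_position=True)
xb, yb = data.one_batch()   # batches come from MusicPreloader, with random transposition applied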
utils/musicautobot/music_transformer/learner.py ADDED
@@ -0,0 +1,171 @@
+ from fastai.basics import *
+ from fastai.text.learner import LanguageLearner, get_language_model, _model_meta
+ from .model import *
+ from .transform import MusicItem
+ from ..numpy_encode import SAMPLE_FREQ
+ from ..utils.top_k_top_p import top_k_top_p
+ from ..utils.midifile import is_empty_midi
+
+ _model_meta[MusicTransformerXL] = _model_meta[TransformerXL] # copy over fastai's model metadata
+
+ def music_model_learner(data:DataBunch, arch=MusicTransformerXL, config:dict=None, drop_mult:float=1.,
+                         pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner':
+     "Create a `Learner` with a language model from `data` and `arch`."
+     meta = _model_meta[arch]
+
+     if pretrained_path:
+         state = torch.load(pretrained_path, map_location='cpu')
+         if config is None: config = state['config']
+
+     model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
+     learn = MusicLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)
+
+     if pretrained_path:
+         get_model(model).load_state_dict(state['model'], strict=False)
+         if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
+         try: learn.opt.load_state_dict(state['opt'])
+         except: pass
+         del state
+         gc.collect()
+
+     return learn
+
+ # Predictions
+ from fastai import basic_train # for predictions
+ class MusicLearner(LanguageLearner):
+     def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
+         "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)"
+         out_path = super().save(file, return_path=True, with_opt=with_opt)
+         if config and out_path:
+             state = torch.load(out_path)
+             state['config'] = config
+             torch.save(state, out_path)
+             del state
+             gc.collect()
+         return out_path
+
+     def beam_search(self, xb:Tensor, n_words:int, top_k:int=10, beam_sz:int=10, temperature:float=1.,
+                     ):
+         "Return the `n_words` that come after `text` using beam search."
+         self.model.reset()
+         self.model.eval()
+         xb_length = xb.shape[-1]
+         if xb.shape[0] > 1: xb = xb[0][None]
+         yb = torch.ones_like(xb)
+
+         nodes = None
+         xb = xb.repeat(top_k, 1)
+         nodes = xb.clone()
+         scores = xb.new_zeros(1).float()
+         with torch.no_grad():
+             for k in progress_bar(range(n_words), leave=False):
+                 out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
+                 values, indices = out.topk(top_k, dim=-1)
+                 scores = (-values + scores[:,None]).view(-1)
+                 indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)
+                 sort_idx = scores.argsort()[:beam_sz]
+                 scores = scores[sort_idx]
+                 nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
+                                    indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
+                 nodes = nodes.view(-1, nodes.size(2))[sort_idx]
+                 self.model[0].select_hidden(indices_idx[sort_idx])
+                 xb = nodes[:,-1][:,None]
+         if temperature != 1.: scores.div_(temperature)
+         node_idx = torch.multinomial(torch.exp(-scores), 1).item()
+         return [i.item() for i in nodes[node_idx][xb_length:] ]
+
+     def predict(self, item:MusicItem, n_words:int=128,
+                 temperatures:float=(1.0,1.0), min_bars=4,
+                 top_k=30, top_p=0.6):
+         "Return the `n_words` that come after `text`."
+         self.model.reset()
+         new_idx = []
+         vocab = self.data.vocab
+         x, pos = item.to_tensor(), item.get_pos_tensor()
+         last_pos = pos[-1] if len(pos) else 0
+         y = torch.tensor([0])
+
+         start_pos = last_pos
+
+         sep_count = 0
+         bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time
+         vocab = self.data.vocab
+
+         repeat_count = 0
+         if hasattr(self.model[0], 'encode_position'):
+             encode_position = self.model[0].encode_position
+         else: encode_position = False
+
+         for i in progress_bar(range(n_words), leave=True):
+             with torch.no_grad():
+                 if encode_position:
+                     batch = { 'x': x[None], 'pos': pos[None] }
+                     logits = self.model(batch)[0][-1][-1]
+                 else:
+                     logits = self.model(x[None])[0][-1][-1]
+
+             prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx
+
+             # Temperature
+             # Use first temperatures value if last prediction was duration
+             temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
+             repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
+             temperature += repeat_penalty
+             if temperature != 1.: logits = logits / temperature
+
+
+             # Filter
+             # bar = 16 beats
+             filter_value = -float('Inf')
+             if ((last_pos - start_pos) // 16) <= min_bars: logits[vocab.bos_idx] = filter_value
+
+             logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
+             logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)
124
+
125
+ # Sample
126
+ probs = F.softmax(logits, dim=-1)
127
+ idx = torch.multinomial(probs, 1).item()
128
+
129
+ # Update repeat count
130
+ num_choices = len(probs.nonzero().view(-1))
131
+ if num_choices <= 2: repeat_count += 1
132
+ else: repeat_count = repeat_count // 2
133
+
134
+ if prev_idx==vocab.sep_idx:
135
+ duration = idx - vocab.dur_range[0]
136
+ last_pos = last_pos + duration
137
+
138
+ bars_pred = (last_pos - start_pos) // 16
139
+ abs_bar = last_pos // 16
140
+ # if (bars % 8 == 0) and (bars_pred > min_bars): break
141
+ if (i / n_words > 0.80) and (abs_bar % 4 == 0): break
142
+
143
+
144
+ if idx==vocab.bos_idx:
145
+ print('Predicted BOS token. Returning prediction...')
146
+ break
147
+
148
+ new_idx.append(idx)
149
+ x = x.new_tensor([idx])
150
+ pos = pos.new_tensor([last_pos])
151
+
152
+ pred = vocab.to_music_item(np.array(new_idx))
153
+ full = item.append(pred)
154
+ return pred, full
155
+
156
+ # High level prediction functions from midi file
157
+ def predict_from_midi(learn, midi=None, n_words=400,
158
+ temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
159
+ vocab = learn.data.vocab
160
+ seed = MusicItem.from_file(midi, vocab) if not is_empty_midi(midi) else MusicItem.empty(vocab)
161
+ if seed_len is not None: seed = seed.trim_to_beat(seed_len)
162
+
163
+ pred, full = learn.predict(seed, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
164
+ return full
165
+
166
+ def filter_invalid_indexes(res, prev_idx, vocab, filter_value=-float('Inf')):
167
+ if vocab.is_duration_or_pad(prev_idx):
168
+ res[list(range(*vocab.dur_range))] = filter_value
169
+ else:
170
+ res[list(range(*vocab.note_range))] = filter_value
171
+ return res
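A hedged end-to-end sketch (not part of this commit) of how this learner might be driven. The data bunch, checkpoint and midi paths are placeholders, and MusicDataBunch.empty / default_config are assumed to behave as in the upstream musicautobot project.

from utils.musicautobot.music_transformer import *
from utils.musicautobot.config import default_config

data = MusicDataBunch.empty('data/')                 # empty bunch just to carry the vocab
learn = music_model_learner(data, config=default_config(),
                            pretrained_path='music_transformer.pth')
full = predict_from_midi(learn, midi='seed.mid', n_words=200, seed_len=8)
full.stream.write('midi', fp='generated.mid')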
utils/musicautobot/music_transformer/model.py ADDED
@@ -0,0 +1,66 @@
1
+ from fastai.basics import *
2
+ from fastai.text.models.transformer import TransformerXL
3
+ from ..utils.attention_mask import rand_window_mask
4
+
5
+ class MusicTransformerXL(TransformerXL):
6
+ "Exactly like fastai's TransformerXL, but with more aggressive attention mask: see `rand_window_mask`"
7
+ def __init__(self, *args, encode_position=True, mask_steps=1, **kwargs):
8
+ import inspect
9
+ sig = inspect.signature(TransformerXL)
10
+ arg_params = { k:kwargs[k] for k in sig.parameters if k in kwargs }
11
+ super().__init__(*args, **arg_params)
12
+
13
+ self.encode_position = encode_position
14
+ if self.encode_position: self.beat_enc = BeatPositionEncoder(kwargs['d_model'])
15
+
16
+ self.mask_steps=mask_steps
17
+
18
+
19
+ def forward(self, x):
20
+ # The hidden state has to be initialized in the forward pass for nn.DataParallel
21
+ if self.mem_len > 0 and not self.init:
22
+ self.reset()
23
+ self.init = True
24
+
25
+ benc = 0
26
+ if self.encode_position:
27
+ x,pos = x['x'], x['pos']
28
+ benc = self.beat_enc(pos)
29
+
30
+ bs,x_len = x.size()
31
+ inp = self.drop_emb(self.encoder(x) + benc) #.mul_(self.d_model ** 0.5)
32
+ m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
33
+ seq_len = m_len + x_len
34
+
35
+ mask = rand_window_mask(x_len, m_len, inp.device, max_size=self.mask_steps, is_eval=not self.training) if self.mask else None
36
+ if mask is not None and m_len == 0: mask[...,0,0] = 0 # guard: mask is None when self.mask is False
37
+ #[None,:,:None] for einsum implementation of attention
38
+ hids = []
39
+ pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
40
+ pos_enc = self.pos_enc(pos)
41
+ hids.append(inp)
42
+ for i, layer in enumerate(self.layers):
43
+ mem = self.hidden[i] if self.mem_len > 0 else None
44
+ inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
45
+ hids.append(inp)
46
+ core_out = inp[:,-x_len:]
47
+ if self.mem_len > 0 : self._update_mems(hids)
48
+ return (self.hidden if self.mem_len > 0 else [core_out]),[core_out]
49
+
50
+
51
+ # Beat encoder
52
+ class BeatPositionEncoder(nn.Module):
53
+ "Embedding + positional encoding + dropout"
54
+ def __init__(self, emb_sz:int, beat_len=32, max_bar_len=1024):
55
+ super().__init__()
56
+
57
+ self.beat_len, self.max_bar_len = beat_len, max_bar_len
58
+ self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0)
59
+ self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0)
60
+
61
+ def forward(self, pos):
62
+ beat_enc = self.beat_enc(pos % self.beat_len)
63
+ bar_pos = pos // self.beat_len % self.max_bar_len
64
+ bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1
65
+ bar_enc = self.bar_enc((bar_pos))
66
+ return beat_enc + bar_enc
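A small sketch (not part of this commit) of the beat/bar encoder in isolation: absolute timestep positions are split into a within-bar beat index and a bar index, each with its own embedding. The import assumes the repo root is on sys.path.

import torch
from utils.musicautobot.music_transformer.model import BeatPositionEncoder

enc = BeatPositionEncoder(emb_sz=64)        # beat_len=32 timesteps per bar by default
pos = torch.arange(96).unsqueeze(0)         # one sequence of absolute timestep positions
print(enc(pos).shape)                       # torch.Size([1, 96, 64])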
utils/musicautobot/music_transformer/transform.py ADDED
@@ -0,0 +1,235 @@
1
+ from ..numpy_encode import *
2
+ import numpy as np
3
+ from enum import Enum
4
+ import torch
5
+ from ..vocab import *
6
+ from functools import partial
7
+
8
+ SEQType = Enum('SEQType', 'Mask, Sentence, Melody, Chords, Empty')
9
+
10
+ class MusicItem():
11
+ def __init__(self, data, vocab, stream=None, position=None):
12
+ self.data = data
13
+ self.vocab = vocab
14
+ self._stream = stream
15
+ self._position = position
16
+ def __repr__(self): return '\n'.join([
17
+ f'\n{self.__class__.__name__} - {self.data.shape}',
18
+ f'{self.vocab.textify(self.data[:10])}...'])
19
+ def __len__(self): return len(self.data)
20
+
21
+ @classmethod
22
+ def from_file(cls, midi_file, vocab):
23
+ return cls.from_stream(file2stream(midi_file), vocab)
24
+ @classmethod
25
+ def from_stream(cls, stream, vocab):
26
+ if not isinstance(stream, music21.stream.Score): stream = stream.voicesToParts()
27
+ chordarr = stream2chordarr(stream) # 2.
28
+ npenc = chordarr2npenc(chordarr) # 3.
29
+ return cls.from_npenc(npenc, vocab, stream)
30
+ @classmethod
31
+ def from_npenc(cls, npenc, vocab, stream=None): return MusicItem(npenc2idxenc(npenc, vocab), vocab, stream)
32
+
33
+ @classmethod
34
+ def from_idx(cls, item, vocab):
35
+ idx,pos = item
36
+ return MusicItem(idx, vocab=vocab, position=pos)
37
+ def to_idx(self): return self.data, self.position
38
+
39
+ @classmethod
40
+ def empty(cls, vocab, seq_type=SEQType.Sentence):
41
+ return MusicItem(seq_prefix(seq_type, vocab), vocab)
42
+
43
+ @property
44
+ def stream(self):
45
+ self._stream = self.to_stream() if self._stream is None else self._stream
46
+ return self._stream
47
+
48
+ def to_stream(self, bpm=120):
49
+ return idxenc2stream(self.data, self.vocab, bpm=bpm)
50
+
51
+ def to_tensor(self, device=None):
52
+ return to_tensor(self.data, device)
53
+
54
+ def to_text(self, sep=' '): return self.vocab.textify(self.data, sep)
55
+
56
+ @property
57
+ def position(self):
58
+ self._position = position_enc(self.data, self.vocab) if self._position is None else self._position
59
+ return self._position
60
+
61
+ def get_pos_tensor(self, device=None): return to_tensor(self.position, device)
62
+
63
+ def to_npenc(self):
64
+ return idxenc2npenc(self.data, self.vocab)
65
+
66
+ def show(self, format:str=None):
67
+ return self.stream.show(format)
68
+ def play(self): self.stream.show('midi')
69
+
70
+ #Added by caslabs
71
+ def download(self, filename:str=None, ext:str=None):
72
+ return self.stream.write('midi', fp=filename)
73
+
74
+ @property
75
+ def new(self):
76
+ return partial(type(self), vocab=self.vocab)
77
+
78
+ def trim_to_beat(self, beat, include_last_sep=False):
79
+ return self.new(trim_to_beat(self.data, self.position, self.vocab, beat, include_last_sep))
80
+
81
+ def transpose(self, interval):
82
+ return self.new(tfm_transpose(self.data, interval, self.vocab), position=self._position)
83
+
84
+ def append(self, item):
85
+ return self.new(np.concatenate((self.data, item.data), axis=0))
86
+
87
+ def mask_pitch(self, section=None):
88
+ return self.new(self.mask(self.vocab.note_range, section), position=self.position)
89
+
90
+ def mask_duration(self, section=None, keep_position_enc=True):
91
+ masked_data = self.mask(self.vocab.dur_range, section)
92
+ if keep_position_enc: return self.new(masked_data, position=self.position)
93
+ return self.new(masked_data)
94
+
95
+ def mask(self, token_range, section_range=None):
96
+ return mask_section(self.data, self.position, token_range, self.vocab.mask_idx, section_range=section_range)
97
+
98
+ def pad_to(self, bptt):
99
+ data = pad_seq(self.data, bptt, self.vocab.pad_idx)
100
+ pos = pad_seq(self.position, bptt, 0)
101
+ return self.new(data, stream=self._stream, position=pos)
102
+
103
+ def split_stream_parts(self):
104
+ self._stream = separate_melody_chord(self.stream)
105
+ return self.stream
106
+
107
+ def remove_eos(self):
108
+ if self.data[-1] == self.vocab.stoi[EOS]: return self.new(self.data[:-1], stream=self.stream) # drop the trailing EOS token
109
+ return self
110
+
111
+ def split_parts(self):
112
+ return self.new(self.data, stream=separate_melody_chord(self.stream), position=self.position)
113
+
114
+ def pad_seq(seq, bptt, value):
115
+ pad_len = max(bptt-seq.shape[0], 0)
116
+ return np.pad(seq, (0, pad_len), 'constant', constant_values=value)[:bptt]
117
+
118
+ def to_tensor(t, device=None):
119
+ t = t if isinstance(t, torch.Tensor) else torch.tensor(t)
120
+ if device is None and torch.cuda.is_available(): t = t.cuda()
121
+ else: t.to(device)
122
+ return t.long()
123
+
124
+ def midi2idxenc(midi_file, vocab):
125
+ "Converts midi file to index encoding for training"
126
+ npenc = midi2npenc(midi_file) # 3.
127
+ return npenc2idxenc(npenc, vocab)
128
+
129
+ def idxenc2stream(arr, vocab, bpm=120):
130
+ "Converts index encoding to music21 stream"
131
+ npenc = idxenc2npenc(arr, vocab)
132
+ return npenc2stream(npenc, bpm=bpm)
133
+
134
+ # single stream instead of note,dur
135
+ def npenc2idxenc(t, vocab, seq_type=SEQType.Sentence, add_eos=False):
136
+ "Transforms numpy array from 2 column (note, duration) matrix to a single column"
137
+ "[[n1, d1], [n2, d2], ...] -> [n1, d1, n2, d2]"
138
+ if isinstance(t, (list, tuple)) and len(t) == 2:
139
+ return [npenc2idxenc(x, vocab, seq_type) for x in t]
140
+ t = t.copy()
141
+
142
+ t[:, 0] = t[:, 0] + vocab.note_range[0]
143
+ t[:, 1] = t[:, 1] + vocab.dur_range[0]
144
+
145
+ prefix = seq_prefix(seq_type, vocab)
146
+ suffix = np.array([vocab.stoi[EOS]]) if add_eos else np.empty(0, dtype=int)
147
+ return np.concatenate([prefix, t.reshape(-1), suffix])
148
+
149
+ def seq_prefix(seq_type, vocab):
150
+ if seq_type == SEQType.Empty: return np.empty(0, dtype=int)
151
+ start_token = vocab.bos_idx
152
+ if seq_type == SEQType.Chords: start_token = vocab.stoi[CSEQ]
153
+ if seq_type == SEQType.Melody: start_token = vocab.stoi[MSEQ]
154
+ return np.array([start_token, vocab.pad_idx])
155
+
156
+ def idxenc2npenc(t, vocab, validate=True):
157
+ if validate: t = to_valid_idxenc(t, vocab.npenc_range)
158
+ t = t.copy().reshape(-1, 2)
159
+ if t.shape[0] == 0: return t
160
+
161
+ t[:, 0] = t[:, 0] - vocab.note_range[0]
162
+ t[:, 1] = t[:, 1] - vocab.dur_range[0]
163
+
164
+ if validate: return to_valid_npenc(t)
165
+ return t
166
+
167
+ def to_valid_idxenc(t, valid_range):
168
+ r = valid_range
169
+ t = t[np.where((t >= r[0]) & (t < r[1]))]
170
+ if t.shape[-1] % 2 == 1: t = t[..., :-1]
171
+ return t
172
+
173
+ def to_valid_npenc(t):
174
+ is_note = (t[:, 0] < VALTSEP) | (t[:, 0] >= NOTE_SIZE)
175
+ invalid_note_idx = is_note.argmax()
176
+ invalid_dur_idx = (t[:, 1] < 0).argmax()
177
+
178
+ invalid_idx = max(invalid_dur_idx, invalid_note_idx)
179
+ if invalid_idx > 0:
180
+ if invalid_note_idx > 0 and invalid_dur_idx > 0: invalid_idx = min(invalid_dur_idx, invalid_note_idx)
181
+ print('Non midi note detected. Only returning valid portion. Index, seed', invalid_idx, t.shape)
182
+ return t[:invalid_idx]
183
+ return t
184
+
185
+ def position_enc(idxenc, vocab):
186
+ "Calculates positional beat encoding."
187
+ sep_idxs = (idxenc == vocab.sep_idx).nonzero()[0]
188
+ sep_idxs = sep_idxs[sep_idxs+2 < idxenc.shape[0]] # remove any indexes right before out of bounds (sep_idx+2)
189
+ dur_vals = idxenc[sep_idxs+1]
190
+ dur_vals[dur_vals == vocab.mask_idx] = vocab.dur_range[0] # make sure masked durations are 0
191
+ dur_vals -= vocab.dur_range[0]
192
+
193
+ posenc = np.zeros_like(idxenc)
194
+ posenc[sep_idxs+2] = dur_vals
195
+ return posenc.cumsum()
196
+
197
+ def beat2index(idxenc, pos, vocab, beat, include_last_sep=False):
198
+ cutoff = find_beat(pos, beat)
199
+ if cutoff < 2: return 2 # always leave starter tokens
200
+ if len(idxenc) < 2 or include_last_sep: return cutoff
201
+ if idxenc[cutoff - 2] == vocab.sep_idx: return cutoff - 2
202
+ return cutoff
203
+
204
+ def find_beat(pos, beat, sample_freq=SAMPLE_FREQ, side='left'):
205
+ return np.searchsorted(pos, beat * sample_freq, side=side)
206
+
207
+ # TRANSFORMS
208
+
209
+ def tfm_transpose(x, value, vocab):
210
+ x = x.copy()
211
+ x[(x >= vocab.note_range[0]) & (x < vocab.note_range[1])] += value
212
+ return x
213
+
214
+ def trim_to_beat(idxenc, pos, vocab, to_beat=None, include_last_sep=True):
215
+ if to_beat is None: return idxenc
216
+ cutoff = beat2index(idxenc, pos, vocab, to_beat, include_last_sep=include_last_sep)
217
+ return idxenc[:cutoff]
218
+
219
+ def mask_input(xb, mask_range, replacement_idx):
220
+ xb = xb.copy()
221
+ xb[(xb >= mask_range[0]) & (xb < mask_range[1])] = replacement_idx
222
+ return xb
223
+
224
+ def mask_section(xb, pos, token_range, replacement_idx, section_range=None):
225
+ xb = xb.copy()
226
+ token_mask = (xb >= token_range[0]) & (xb < token_range[1])
227
+
228
+ if section_range is None: section_range = (None, None)
229
+ section_mask = np.zeros_like(xb, dtype=bool)
230
+ start_idx = find_beat(pos, section_range[0]) if section_range[0] is not None else 0
231
+ end_idx = find_beat(pos, section_range[1]) if section_range[1] is not None else xb.shape[0]
232
+ section_mask[start_idx:end_idx] = True
233
+
234
+ xb[token_mask & section_mask] = replacement_idx
235
+ return xb
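A brief sketch (not part of this commit) of the MusicItem round trip these helpers support; example.mid is a placeholder file and the printed tokens are only illustrative.

from utils.musicautobot.music_transformer.transform import MusicItem
from utils.musicautobot.vocab import MusicVocab

vocab = MusicVocab.create()
item = MusicItem.from_file('example.mid', vocab)  # midi -> chordarr -> npenc -> idxenc
seed = item.trim_to_beat(8).transpose(2)          # first 8 beats, shifted up 2 semitones
print(seed.to_text()[:80])                        # e.g. 'xxbos xxpad n62 d4 xxsep d4 ...'
seed.stream.write('midi', fp='seed.mid')          # back to midi via music21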
utils/musicautobot/numpy_encode.py ADDED
@@ -0,0 +1,302 @@
1
+ "Encoding music21 streams -> numpy array -> text"
2
+
3
+ # import re
4
+ import music21
5
+ import numpy as np
6
+ # from pathlib import Path
7
+
8
+ BPB = 4 # beats per bar
9
+ TIMESIG = f'{BPB}/4' # default time signature
10
+ PIANO_RANGE = (21, 108)
11
+ VALTSEP = -1 # separator value for numpy encoding
12
+ VALTCONT = -2 # numpy value for TCONT - needed for compressing chord array
13
+
14
+ SAMPLE_FREQ = 4
15
+ NOTE_SIZE = 128
16
+ DUR_SIZE = (10*BPB*SAMPLE_FREQ)+1 # duration vocab size: up to 10 bars of timesteps; individual notes are capped at MAX_NOTE_DUR (8 bars) below
17
+ MAX_NOTE_DUR = (8*BPB*SAMPLE_FREQ)
18
+
19
+ # Encoding process
20
+ # 1. midi -> music21.Stream
21
+ # 2. Stream -> numpy chord array (timestep X instrument X noterange)
22
+ # 3. numpy array -> List[Timestep][NoteEnc]
23
+ def midi2npenc(midi_file, skip_last_rest=True):
24
+ "Converts midi file to numpy encoding for language model"
25
+ stream = file2stream(midi_file) # 1.
26
+ chordarr = stream2chordarr(stream) # 2.
27
+ return chordarr2npenc(chordarr, skip_last_rest=skip_last_rest) # 3.
28
+
29
+ # Decoding process
30
+ # 1. NoteEnc -> numpy chord array
31
+ # 2. numpy array -> music21.Stream
32
+ def npenc2stream(arr, bpm=120):
33
+ "Converts numpy encoding to music21 stream"
34
+ chordarr = npenc2chordarr(np.array(arr)) # 1.
35
+ return chordarr2stream(chordarr, bpm=bpm) # 2.
36
+
37
+ ##### ENCODING ######
38
+
39
+ # 1. File to Stream
40
+
41
+ def file2stream(fp):
42
+ if isinstance(fp, music21.midi.MidiFile): return music21.midi.translate.midiFileToStream(fp)
43
+ return music21.converter.parse(fp)
44
+
45
+ # 2.
46
+ def stream2chordarr(s, note_size=NOTE_SIZE, sample_freq=SAMPLE_FREQ, max_note_dur=MAX_NOTE_DUR):
47
+ "Converts music21.Stream to a numpy chord array (timestep x part x pitch, values are note durations)"
48
+ # assuming 4/4 time
49
+ # note x instrument x pitch
50
+ # FYI: midi middle C value=60
51
+
52
+ # (AS) TODO: need to order by instruments most played and filter out percussion or include the channel
53
+ highest_time = max(s.flat.getElementsByClass('Note').highestTime, s.flat.getElementsByClass('Chord').highestTime)
54
+ maxTimeStep = round(highest_time * sample_freq)+1
55
+ score_arr = np.zeros((maxTimeStep, len(s.parts), NOTE_SIZE))
56
+
57
+ def note_data(pitch, note):
58
+ return (pitch.midi, int(round(note.offset*sample_freq)), int(round(note.duration.quarterLength*sample_freq)))
59
+
60
+ for idx,part in enumerate(s.parts):
61
+ notes=[]
62
+ for elem in part.flat:
63
+ if isinstance(elem, music21.note.Note):
64
+ notes.append(note_data(elem.pitch, elem))
65
+ if isinstance(elem, music21.chord.Chord):
66
+ for p in elem.pitches:
67
+ notes.append(note_data(p, elem))
68
+
69
+ # sort notes by offset (1), duration (2) so that hits are not overwritten and longer notes have priority
70
+ notes_sorted = sorted(notes, key=lambda x: (x[1], x[2]))
71
+ for n in notes_sorted:
72
+ if n is None: continue
73
+ pitch,offset,duration = n
74
+ if max_note_dur is not None and duration > max_note_dur: duration = max_note_dur
75
+ score_arr[offset, idx, pitch] = duration
76
+ score_arr[offset+1:offset+duration, idx, pitch] = VALTCONT # Continue holding note
77
+ return score_arr
78
+
79
+ def chordarr2npenc(chordarr, skip_last_rest=True):
80
+ # combine instruments
81
+ result = []
82
+ wait_count = 0
83
+ for idx,timestep in enumerate(chordarr):
84
+ flat_time = timestep2npenc(timestep)
85
+ if len(flat_time) == 0:
86
+ wait_count += 1
87
+ else:
88
+ # flush any accumulated rest, then add this timestep's [note, duration] rows
89
+ if wait_count > 0: result.append([VALTSEP, wait_count])
90
+ result.extend(flat_time)
91
+ wait_count = 1
92
+ if wait_count > 0 and not skip_last_rest: result.append([VALTSEP, wait_count])
93
+ return np.array(result, dtype=int).reshape(-1, 2) # reshaping. Just in case result is empty
94
+
95
+ # Note: not worrying about overlaps - as notes will still play. just look tied
96
+ # http://web.mit.edu/music21/doc/moduleReference/moduleStream.html#music21.stream.Stream.getOverlaps
97
+ def timestep2npenc(timestep, note_range=PIANO_RANGE, enc_type=None):
98
+ # inst x pitch
99
+ notes = []
100
+ for i,n in zip(*timestep.nonzero()):
101
+ d = timestep[i,n]
102
+ if d < 0: continue # only supporting short duration encoding for now
103
+ if n < note_range[0] or n >= note_range[1]: continue # must be within midi range
104
+ notes.append([n,d,i])
105
+
106
+ notes = sorted(notes, key=lambda x: x[0], reverse=True) # sort by note (highest to lowest)
107
+
108
+ if enc_type is None:
109
+ # note, duration
110
+ return [n[:2] for n in notes]
111
+ if enc_type == 'parts':
112
+ # note, duration, part
113
+ return [n for n in notes]
114
+ if enc_type == 'full':
115
+ # note_class, duration, octave, instrument
116
+ return [[n%12, d, n//12, i] for n,d,i in notes]
117
+
118
+ ##### DECODING #####
119
+
120
+ # 1.
121
+ def npenc2chordarr(npenc, note_size=NOTE_SIZE):
122
+ num_instruments = 1 if len(npenc.shape) <= 2 else npenc.max(axis=0)[-1]
123
+
124
+ max_len = npenc_len(npenc)
125
+ # score_arr = (steps, inst, note)
126
+ score_arr = np.zeros((max_len, num_instruments, note_size))
127
+
128
+ idx = 0
129
+ for step in npenc:
130
+ n,d,i = (step.tolist()+[0])[:3] # note, duration, instrument (instrument defaults to 0)
131
+ if n < VALTSEP: continue # special token
132
+ if n == VALTSEP:
133
+ idx += d
134
+ continue
135
+ score_arr[idx,i,n] = d
136
+ return score_arr
137
+
138
+ def npenc_len(npenc):
139
+ duration = 0
140
+ for t in npenc:
141
+ if t[0] == VALTSEP: duration += t[1]
142
+ return duration + 1
143
+
144
+
145
+ # 2.
146
+ def chordarr2stream(arr, sample_freq=SAMPLE_FREQ, bpm=120):
147
+ duration = music21.duration.Duration(1. / sample_freq)
148
+ stream = music21.stream.Score()
149
+ stream.append(music21.meter.TimeSignature(TIMESIG))
150
+ stream.append(music21.tempo.MetronomeMark(number=bpm))
151
+ stream.append(music21.key.KeySignature(0))
152
+ for inst in range(arr.shape[1]):
153
+ p = partarr2stream(arr[:,inst,:], duration)
154
+ stream.append(p)
155
+ stream = stream.transpose(0)
156
+ return stream
157
+
158
+ # 2b.
159
+ def partarr2stream(partarr, duration):
160
+ "convert instrument part to music21 chords"
161
+ part = music21.stream.Part()
162
+ part.append(music21.instrument.Piano())
163
+ part_append_duration_notes(partarr, duration, part) # notes already have duration calculated
164
+
165
+ return part
166
+
167
+ def part_append_duration_notes(partarr, duration, stream):
168
+ "convert instrument part to music21 chords"
169
+ for tidx,t in enumerate(partarr):
170
+ note_idxs = np.where(t > 0)[0] # filter out any negative values (continuous mode)
171
+ if len(note_idxs) == 0: continue
172
+ notes = []
173
+ for nidx in note_idxs:
174
+ note = music21.note.Note(nidx)
175
+ note.duration = music21.duration.Duration(partarr[tidx,nidx]*duration.quarterLength)
176
+ notes.append(note)
177
+ for g in group_notes_by_duration(notes):
178
+ if len(g) == 1:
179
+ stream.insert(tidx*duration.quarterLength, g[0])
180
+ else:
181
+ chord = music21.chord.Chord(g)
182
+ stream.insert(tidx*duration.quarterLength, chord)
183
+ return stream
184
+
185
+ from itertools import groupby
186
+ # combining notes with different durations into a single chord may overwrite conflicting durations. Example: aylictal/still-waters-run-deep
187
+ def group_notes_by_duration(notes):
188
+ "separate notes into chord groups"
189
+ keyfunc = lambda n: n.duration.quarterLength
190
+ notes = sorted(notes, key=keyfunc)
191
+ return [list(g) for k,g in groupby(notes, keyfunc)]
192
+
193
+
194
+ # Midi -> npenc Conversion helpers
195
+ def is_valid_npenc(npenc, note_range=PIANO_RANGE, max_dur=DUR_SIZE,
196
+ min_notes=32, input_path=None, verbose=True):
197
+ if len(npenc) < min_notes:
198
+ if verbose: print('Sequence too short:', len(npenc), input_path)
199
+ return False
200
+ if (npenc[:,1] >= max_dur).any():
201
+ if verbose: print(f'npenc exceeds max {max_dur} duration:', npenc[:,1].max(), input_path)
202
+ return False
203
+ # https://en.wikipedia.org/wiki/Scientific_pitch_notation - 88 key range - 21 = A0, 108 = C8
204
+ if ((npenc[...,0] > VALTSEP) & ((npenc[...,0] < note_range[0]) | (npenc[...,0] >= note_range[1]))).any():
205
+ print(f'npenc out of piano note range {note_range}:', input_path)
206
+ return False
207
+ return True
208
+
209
+ # separates overlapping notes to different tracks
210
+ def remove_overlaps(stream, separate_chords=True):
211
+ if not separate_chords:
212
+ return stream.flat.makeVoices().voicesToParts()
213
+ return separate_melody_chord(stream)
214
+
215
+ # separates notes and chords to different tracks
216
+ def separate_melody_chord(stream):
217
+ new_stream = music21.stream.Score()
218
+ if stream.timeSignature: new_stream.append(stream.timeSignature)
219
+ new_stream.append(stream.metronomeMarkBoundaries()[0][-1])
220
+ if stream.keySignature: new_stream.append(stream.keySignature)
221
+
222
+ melody_part = music21.stream.Part(stream.flat.getElementsByClass('Note'))
223
+ melody_part.insert(0, stream.getInstrument())
224
+ chord_part = music21.stream.Part(stream.flat.getElementsByClass('Chord'))
225
+ chord_part.insert(0, stream.getInstrument())
226
+ new_stream.append(melody_part)
227
+ new_stream.append(chord_part)
228
+ return new_stream
229
+
230
+ # processing functions for sanitizing data
231
+
232
+ def compress_chordarr(chordarr):
233
+ return shorten_chordarr_rests(trim_chordarr_rests(chordarr))
234
+
235
+ def trim_chordarr_rests(arr, max_rests=4, sample_freq=SAMPLE_FREQ):
236
+ # max rests is in quarter notes
237
+ # max 1 bar between song start and end
238
+ start_idx = 0
239
+ max_sample = max_rests*sample_freq
240
+ for idx,t in enumerate(arr):
241
+ if (t != 0).any(): break
242
+ start_idx = idx+1
243
+
244
+ end_idx = 0
245
+ for idx,t in enumerate(reversed(arr)):
246
+ if (t != 0).any(): break
247
+ end_idx = idx+1
248
+ start_idx = start_idx - start_idx % max_sample
249
+ end_idx = end_idx - end_idx % max_sample
250
+ # if start_idx > 0 or end_idx > 0: print('Trimming rests. Start, end:', start_idx, len(arr)-end_idx, end_idx)
251
+ return arr[start_idx:(len(arr)-end_idx)]
252
+
253
+ def shorten_chordarr_rests(arr, max_rests=8, sample_freq=SAMPLE_FREQ):
254
+ # max rests is in quarter notes
255
+ # max 2 bar pause
256
+ rest_count = 0
257
+ result = []
258
+ max_sample = max_rests*sample_freq
259
+ for timestep in arr:
260
+ if (timestep==0).all():
261
+ rest_count += 1
262
+ else:
263
+ if rest_count > max_sample:
264
+ # old_count = rest_count
265
+ rest_count = (rest_count % sample_freq) + max_sample
266
+ # print(f'Compressing rests: {old_count} -> {rest_count}')
267
+ for i in range(rest_count): result.append(np.zeros(timestep.shape))
268
+ rest_count = 0
269
+ result.append(timestep)
270
+ for i in range(rest_count): result.append(np.zeros(timestep.shape))
271
+ return np.array(result)
272
+
273
+ # sequence 2 sequence convenience functions
274
+
275
+ def stream2npenc_parts(stream, sort_pitch=True):
276
+ chordarr = stream2chordarr(stream)
277
+ _,num_parts,_ = chordarr.shape
278
+ parts = [part_enc(chordarr, i) for i in range(num_parts)]
279
+ return sorted(parts, key=avg_pitch, reverse=True) if sort_pitch else parts
280
+
281
+ def chordarr_combine_parts(parts):
282
+ max_ts = max([p.shape[0] for p in parts])
283
+ parts_padded = [pad_part_to(p, max_ts) for p in parts]
284
+ chordarr_comb = np.concatenate(parts_padded, axis=1)
285
+ return chordarr_comb
286
+
287
+ def pad_part_to(p, target_size):
288
+ pad_width = ((0,target_size-p.shape[0]),(0,0),(0,0))
289
+ return np.pad(p, pad_width, 'constant')
290
+
291
+ def part_enc(chordarr, part):
292
+ partarr = chordarr[:,part:part+1,:]
293
+ npenc = chordarr2npenc(partarr)
294
+ return npenc
295
+
296
+ def avg_tempo(t, sep_idx=VALTSEP):
297
+ avg = t[t[:, 0] == sep_idx][:, 1].sum()/t.shape[0]
298
+ avg = int(round(avg/SAMPLE_FREQ))
299
+ return 'mt'+str(min(avg, MTEMPO_SIZE-1))
300
+
301
+ def avg_pitch(t, sep_idx=VALTSEP):
302
+ return t[t[:, 0] > sep_idx][:, 0].mean()
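A sketch (not part of this commit) of the midi -> numpy -> midi round trip outlined in the encode/decode comments above; the midi paths are placeholders.

from utils.musicautobot.numpy_encode import midi2npenc, npenc2stream, is_valid_npenc

npenc = midi2npenc('example.mid')          # (N, 2) rows of [note, duration] or [VALTSEP, wait]
if is_valid_npenc(npenc, min_notes=1):
    stream = npenc2stream(npenc, bpm=120)  # back to a music21 Score
    stream.write('midi', fp='roundtrip.mid')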
utils/musicautobot/utils/__init__.py ADDED
File without changes
utils/musicautobot/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (176 Bytes). View file
 
utils/musicautobot/utils/__pycache__/attention_mask.cpython-310.pyc ADDED
Binary file (1.3 kB). View file
 
utils/musicautobot/utils/__pycache__/file_processing.cpython-310.pyc ADDED
Binary file (2.62 kB). View file
 
utils/musicautobot/utils/__pycache__/midifile.cpython-310.pyc ADDED
Binary file (4.5 kB). View file
 
utils/musicautobot/utils/__pycache__/setup_musescore.cpython-310.pyc ADDED
Binary file (1.79 kB). View file
 
utils/musicautobot/utils/__pycache__/top_k_top_p.cpython-310.pyc ADDED
Binary file (1.24 kB). View file
 
utils/musicautobot/utils/attention_mask.py ADDED
@@ -0,0 +1,21 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ def window_mask(x_len, device, m_len=0, size=(1,1)):
5
+ win_size,k = size
6
+ mem_mask = torch.zeros((x_len,m_len), device=device)
7
+ tri_mask = torch.triu(torch.ones((x_len//win_size+1,x_len//win_size+1), device=device),diagonal=k)
8
+ window_mask = tri_mask.repeat_interleave(win_size,dim=0).repeat_interleave(win_size,dim=1)[:x_len,:x_len]
9
+ if x_len: window_mask[...,0] = 0 # Always allowing first index to see. Otherwise you'll get NaN loss
10
+ mask = torch.cat((mem_mask, window_mask), dim=1)[None,None]
11
+ return mask.bool() if hasattr(mask, 'bool') else mask.byte()
12
+
13
+ def rand_window_mask(x_len,m_len,device,max_size:int=None,p:float=0.2,is_eval:bool=False):
14
+ if is_eval or np.random.rand() >= p or max_size is None:
15
+ win_size,k = (1,1)
16
+ else: win_size,k = (np.random.randint(0,max_size)+1,0)
17
+ return window_mask(x_len, device, m_len, size=(win_size,k))
18
+
19
+ def lm_mask(x_len, device):
20
+ mask = torch.triu(torch.ones((x_len, x_len), device=device), diagonal=1)[None,None]
21
+ return mask.bool() if hasattr(mask, 'bool') else mask.byte()
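A quick sketch (not part of this commit) of the mask shapes: both helpers return a (1, 1, query, key) tensor where a truthy entry marks a position the attention should hide.

import torch
from utils.musicautobot.utils.attention_mask import window_mask, lm_mask

device = torch.device('cpu')
causal = lm_mask(4, device)                              # plain upper-triangular mask
windowed = window_mask(4, device, m_len=2, size=(2, 0))  # 2-step windows, memory stays visible
print(causal.shape, windowed.shape)                      # (1, 1, 4, 4) and (1, 1, 4, 6)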
utils/musicautobot/utils/file_processing.py ADDED
@@ -0,0 +1,52 @@
1
+ "Parallel processing for midi files"
2
+ import csv
3
+ from fastprogress.fastprogress import master_bar, progress_bar
4
+ from pathlib import Path
5
+ from pebble import ProcessPool
6
+ from concurrent.futures import TimeoutError
7
+ import numpy as np
8
+
9
+ # https://stackoverflow.com/questions/20991968/asynchronous-multiprocessing-with-a-worker-pool-in-python-how-to-keep-going-aft
10
+ def process_all(func, arr, timeout_func=None, total=None, max_workers=None, timeout=None):
11
+ with ProcessPool() as pool:
12
+ future = pool.map(func, arr, timeout=timeout)
13
+
14
+ iterator = future.result()
15
+ results = []
16
+ for i in progress_bar(range(len(arr)), total=len(arr)):
17
+ try:
18
+ result = next(iterator)
19
+ if result: results.append(result)
20
+ except StopIteration:
21
+ break
22
+ except TimeoutError as error:
23
+ if timeout_func: timeout_func(arr[i], error.args[1])
24
+ return results
25
+
26
+ def process_file(file_path, tfm_func=None, src_path=None, dest_path=None):
27
+ "Utility function that transforms midi file to numpy array."
28
+ output_file = Path(str(file_path).replace(str(src_path), str(dest_path))).with_suffix('.npy')
29
+ if output_file.exists(): return output_file
30
+ output_file.parent.mkdir(parents=True, exist_ok=True)
31
+
32
+ # Call tfm_func and save file
33
+ npenc = tfm_func(file_path)
34
+ if npenc is not None:
35
+ np.save(output_file, npenc)
36
+ return output_file
37
+
38
+ def arr2csv(arr, out_file):
39
+ "Convert metadata array to csv"
40
+ all_keys = {k for d in arr for k in d.keys()}
41
+ arr = [format_values(x) for x in arr]
42
+ with open(out_file, 'w') as f:
43
+ dict_writer = csv.DictWriter(f, list(all_keys))
44
+ dict_writer.writeheader()
45
+ dict_writer.writerows(arr)
46
+
47
+ def format_values(d):
48
+ "Format array values for csv encoding"
49
+ def format_value(v):
50
+ if isinstance(v, list): return ','.join(v)
51
+ return v
52
+ return {k:format_value(v) for k,v in d.items()}
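A sketch (not part of this commit) of combining process_all and process_file to bulk-encode a midi folder; the paths are placeholders.

from functools import partial
from pathlib import Path
from utils.musicautobot.numpy_encode import midi2npenc
from utils.musicautobot.utils.file_processing import process_all, process_file

src, dest = Path('data/midi'), Path('data/npenc')
files = list(src.rglob('*.mid'))
tfm = partial(process_file, tfm_func=midi2npenc, src_path=src, dest_path=dest)
saved = process_all(tfm, files, timeout=120)   # one .npy per midi, skipping files that stall
print(len(saved), 'files encoded')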
utils/musicautobot/utils/lamb.py ADDED
@@ -0,0 +1,106 @@
1
+ # SOURCE: https://github.com/cybertronai/pytorch-lamb/
2
+
3
+ import collections
4
+ import math
5
+
6
+ import torch
7
+ from torch.optim import Optimizer
8
+
9
+
10
+ class Lamb(Optimizer):
11
+ r"""Implements Lamb algorithm.
12
+
13
+ It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`_.
14
+
15
+ Arguments:
16
+ params (iterable): iterable of parameters to optimize or dicts defining
17
+ parameter groups
18
+ lr (float, optional): learning rate (default: 1e-3)
19
+ betas (Tuple[float, float], optional): coefficients used for computing
20
+ running averages of gradient and its square (default: (0.9, 0.999))
21
+ eps (float, optional): term added to the denominator to improve
22
+ numerical stability (default: 1e-8)
23
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
24
+ adam (bool, optional): always use trust ratio = 1, which turns this into
25
+ Adam. Useful for comparison purposes.
26
+
27
+ .. _Reducing BERT Pre-Training Time from 3 Days to 76 Minutes:
28
+ https://arxiv.org/abs/1904.00962
29
+ """
30
+
31
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-4,
32
+ weight_decay=0, adam=False):
33
+ if not 0.0 <= lr:
34
+ raise ValueError("Invalid learning rate: {}".format(lr))
35
+ if not 0.0 <= eps:
36
+ raise ValueError("Invalid epsilon value: {}".format(eps))
37
+ if not 0.0 <= betas[0] < 1.0:
38
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
39
+ if not 0.0 <= betas[1] < 1.0:
40
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
41
+ defaults = dict(lr=lr, betas=betas, eps=eps,
42
+ weight_decay=weight_decay)
43
+ self.adam = adam
44
+ super(Lamb, self).__init__(params, defaults)
45
+
46
+ def step(self, closure=None):
47
+ """Performs a single optimization step.
48
+
49
+ Arguments:
50
+ closure (callable, optional): A closure that reevaluates the model
51
+ and returns the loss.
52
+ """
53
+ loss = None
54
+ if closure is not None:
55
+ loss = closure()
56
+
57
+ for group in self.param_groups:
58
+ for p in group['params']:
59
+ if p.grad is None:
60
+ continue
61
+ grad = p.grad.data
62
+ if grad.is_sparse:
63
+ raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
64
+
65
+ state = self.state[p]
66
+
67
+ # State initialization
68
+ if len(state) == 0:
69
+ state['step'] = 0
70
+ # Exponential moving average of gradient values
71
+ state['exp_avg'] = torch.zeros_like(p.data)
72
+ # Exponential moving average of squared gradient values
73
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
74
+
75
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
76
+ beta1, beta2 = group['betas']
77
+
78
+ state['step'] += 1
79
+
80
+ if group['weight_decay'] != 0:
81
+ grad.add_(group['weight_decay'], p.data)
82
+
83
+ # Decay the first and second moment running average coefficient
84
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
85
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
86
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
87
+
88
+ bias_correction1 = 1 - beta1 ** state['step']
89
+ bias_correction2 = 1 - beta2 ** state['step']
90
+ # Apply bias to lr to avoid broadcast.
91
+ step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
92
+
93
+ adam_step = exp_avg / denom
94
+ # L2 norm uses sum, but here since we're dividing, use mean to avoid overflow.
95
+ r1 = p.data.pow(2).mean().sqrt()
96
+ r2 = adam_step.pow(2).mean().sqrt()
97
+ r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10)
98
+ state['r1'] = r1
99
+ state['r2'] = r2
100
+ state['r'] = r
101
+ if self.adam:
102
+ r = 1
103
+
104
+ p.data.add_(-step_size * r, adam_step)
105
+
106
+ return loss
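A minimal sketch (not part of this commit) of driving the Lamb optimizer on a toy model; the in-place add_ calls above use the old (scalar, tensor) signature, so this assumes an older PyTorch release that still accepts it.

import torch
import torch.nn as nn
from utils.musicautobot.utils.lamb import Lamb

model = nn.Linear(16, 4)
opt = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)
x, y = torch.randn(8, 16), torch.randn(8, 4)
for _ in range(10):
    opt.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    opt.step()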
utils/musicautobot/utils/midifile.py ADDED
@@ -0,0 +1,107 @@
1
+ "Transform functions for raw midi files"
2
+ from enum import Enum
3
+ import music21
4
+
5
+ PIANO_TYPES = list(range(24)) + list(range(80, 96)) # Piano, Synths
6
+ PLUCK_TYPES = list(range(24, 40)) + list(range(104, 112)) # Guitar, Bass, Ethnic
7
+ BRIGHT_TYPES = list(range(40, 56)) + list(range(56, 80))
8
+
9
+ PIANO_RANGE = (21, 109) # https://en.wikipedia.org/wiki/Scientific_pitch_notation
10
+
11
+ class Track(Enum):
12
+ PIANO = 0 # discrete instruments - keyboard, woodwinds
13
+ PLUCK = 1 # continuous instruments with pitch bend: violin, trombone, synths
14
+ BRIGHT = 2
15
+ PERC = 3
16
+ UNDEF = 4
17
+
18
+ type2inst = {
19
+ # use print_music21_instruments() to see supported types
20
+ Track.PIANO: 0, # Piano
21
+ Track.PLUCK: 24, # Guitar
22
+ Track.BRIGHT: 40, # Violin
23
+ Track.PERC: 114, # Steel Drum
24
+ }
25
+
26
+ # INFO_TYPES = set(['TIME_SIGNATURE', 'KEY_SIGNATURE'])
27
+ INFO_TYPES = set(['TIME_SIGNATURE', 'KEY_SIGNATURE', 'SET_TEMPO'])
28
+
29
+ def file2mf(fp):
30
+ mf = music21.midi.MidiFile()
31
+ if isinstance(fp, bytes):
32
+ mf.readstr(fp)
33
+ else:
34
+ mf.open(fp)
35
+ mf.read()
36
+ mf.close()
37
+ return mf
38
+
39
+ def mf2stream(mf): return music21.midi.translate.midiFileToStream(mf)
40
+
41
+ def is_empty_midi(fp):
42
+ if fp is None: return False
43
+ mf = file2mf(fp)
44
+ return not any([t.hasNotes() for t in mf.tracks])
45
+
46
+ def num_piano_tracks(fp):
47
+ music_file = file2mf(fp)
48
+ note_tracks = [t for t in music_file.tracks if t.hasNotes() and get_track_type(t) == Track.PIANO]
49
+ return len(note_tracks)
50
+
51
+ def is_channel(t, c_val):
52
+ return any([c == c_val for c in t.getChannels()])
53
+
54
+ def track_sort(t): # sort by 1. variation of pitch, 2. number of notes
55
+ return len(unique_track_notes(t)), len(t.events)
56
+
57
+ def is_piano_note(pitch):
58
+ return (pitch >= PIANO_RANGE[0]) and (pitch < PIANO_RANGE[1])
59
+
60
+ def unique_track_notes(t):
61
+ return { e.pitch for e in t.events if e.pitch is not None }
62
+
63
+ def compress_midi_file(fp, cutoff=6, min_variation=3, supported_types=set([Track.PIANO, Track.PLUCK, Track.BRIGHT])):
64
+ music_file = file2mf(fp)
65
+
66
+ info_tracks = [t for t in music_file.tracks if not t.hasNotes()]
67
+ note_tracks = [t for t in music_file.tracks if t.hasNotes()]
68
+
69
+ if len(note_tracks) > cutoff:
70
+ note_tracks = sorted(note_tracks, key=track_sort, reverse=True)
71
+
72
+ supported_tracks = []
73
+ for idx,t in enumerate(note_tracks):
74
+ if len(supported_tracks) >= cutoff: break
75
+ track_type = get_track_type(t)
76
+ if track_type not in supported_types: continue
77
+ pitch_set = unique_track_notes(t)
78
+ if (len(pitch_set) < min_variation): continue # must have more than x unique notes
79
+ if not all(map(is_piano_note, pitch_set)): continue # must not contain midi notes outside of piano range
80
+ # if track_type == Track.UNDEF: print('Could not designate track:', fp, t)
81
+ change_track_instrument(t, type2inst[track_type])
82
+ supported_tracks.append(t)
83
+ if not supported_tracks: return None
84
+ music_file.tracks = info_tracks + supported_tracks
85
+ return music_file
86
+
87
+ def get_track_type(t):
88
+ if is_channel(t, 10): return Track.PERC
89
+ i = get_track_instrument(t)
90
+ if i in PIANO_TYPES: return Track.PIANO
91
+ if i in PLUCK_TYPES: return Track.PLUCK
92
+ if i in BRIGHT_TYPES: return Track.BRIGHT
93
+ return Track.UNDEF
94
+
95
+ def get_track_instrument(t):
96
+ for idx,e in enumerate(t.events):
97
+ if e.type == 'PROGRAM_CHANGE': return e.data
98
+ return None
99
+
100
+ def change_track_instrument(t, value):
101
+ for idx,e in enumerate(t.events):
102
+ if e.type == 'PROGRAM_CHANGE': e.data = value
103
+
104
+ def print_music21_instruments():
105
+ for i in range(200):
106
+ try: print(i, music21.instrument.instrumentFromMidiProgram(i))
107
+ except: pass
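A sketch (not part of this commit) of trimming a dense midi file down to at most six supported tracks before encoding; the file path is a placeholder.

from utils.musicautobot.utils.midifile import compress_midi_file, mf2stream, is_empty_midi

if not is_empty_midi('busy_song.mid'):
    mf = compress_midi_file('busy_song.mid', cutoff=6, min_variation=3)
    if mf is not None:
        stream = mf2stream(mf)               # music21 stream with only the kept tracks
        print(len(mf.tracks), 'tracks kept')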
utils/musicautobot/utils/setup_musescore.py ADDED
@@ -0,0 +1,46 @@
1
+ def setup_musescore(musescore_path=None):
2
+ if not is_ipython(): return
3
+
4
+ import platform
5
+ from music21 import environment
6
+ from pathlib import Path
7
+
8
+ system = platform.system()
9
+ if system == 'Linux':
10
+ import os
11
+ os.environ['QT_QPA_PLATFORM']='offscreen' # https://musescore.org/en/node/29041
12
+
13
+ existing_path = environment.get('musicxmlPath')
14
+ if existing_path: return
15
+ if musescore_path is None:
16
+ if system == 'Darwin':
17
+ app_paths = list(Path('/Applications').glob('MuseScore *.app'))
18
+ if len(app_paths): musescore_path = app_paths[-1]/'Contents/MacOS/mscore'
19
+ elif system == 'Linux':
20
+ musescore_path = '/usr/bin/musescore'
21
+
22
+ if musescore_path is None or not Path(musescore_path).exists():
23
+ print('Warning: Could not find musescore installation. Please install musescore (see README) and/or update music21 environment paths')
24
+ else :
25
+ environment.set('musicxmlPath', musescore_path)
26
+ environment.set('musescoreDirectPNGPath', musescore_path)
27
+
28
+ def is_ipython():
29
+ try: get_ipython
30
+ except: return False
31
+ return True
32
+
33
+ def is_colab():
34
+ try: import google.colab
35
+ except: return False
36
+ return True
37
+
38
+ def setup_fluidsynth():
39
+ from midi2audio import FluidSynth
40
+ from IPython.display import Audio
+ from pathlib import Path # play_wav below needs Path
41
+
42
+ def play_wav(stream):
43
+ out_midi = stream.write('midi')
44
+ out_wav = str(Path(out_midi).with_suffix('.wav'))
45
+ FluidSynth("font.sf2").midi_to_audio(out_midi, out_wav)
46
+ return Audio(out_wav)
+ return play_wav # return the nested helper so setup_fluidsynth has an effect
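A one-line sketch (not part of this commit) of pointing music21 at an explicit MuseScore binary when auto-detection fails; the path is a placeholder and the call is a no-op outside IPython.

from utils.musicautobot.utils.setup_musescore import setup_musescore

setup_musescore(musescore_path='/usr/bin/mscore3')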
utils/musicautobot/utils/stacked_dataloader.py ADDED
@@ -0,0 +1,70 @@
1
+ "Dataloader wrapper that can combine and handle multiple dataloaders for multitask training"
2
+ from fastai.callback import Callback
3
+ from typing import Callable
4
+
5
+ __all__ = ['StackedDataBunch']
6
+
7
+ # DataLoading
8
+ class StackedDataBunch():
9
+ def __init__(self, dbs, num_it=100):
10
+ self.dbs = dbs
11
+ self.train_dl = StackedDataloader([db.train_dl for db in self.dbs], num_it)
12
+ self.valid_dl = StackedDataloader([db.valid_dl for db in self.dbs], num_it)
13
+ self.train_ds = None
14
+ self.path = dbs[0].path
15
+ self.device = dbs[0].device
16
+ self.vocab = dbs[0].vocab
17
+ self.empty_val = False
18
+
19
+ def add_tfm(self,tfm:Callable)->None:
20
+ for dl in self.dbs: dl.add_tfm(tfm)
21
+
22
+ def remove_tfm(self,tfm:Callable)->None:
23
+ for dl in self.dbs: dl.remove_tfm(tfm)
24
+
25
+ # Helper functions
26
+ class StackedDataset(Callback):
27
+ def __init__(self, dss):
28
+ self.dss = dss
29
+ def __getattribute__(self, attr):
30
+ if attr == 'dss': return super().__getattribute__(attr)
31
+ def redirected(*args, **kwargs):
32
+ for ds in self.dss:
33
+ if hasattr(ds, attr): getattr(ds, attr)(*args, **kwargs)
34
+ return redirected
35
+ def __len__(self)->int: return sum([len(ds) for ds in self.dss])
36
+ def __repr__(self): return '\n'.join([self.__class__.__name__] + [repr(ds) for ds in self.dss])
37
+
38
+ class StackedDataloader():
39
+ def __init__(self, dls, num_it=100):
40
+ self.dls = dls
41
+ self.dataset = StackedDataset([dl.dataset for dl in dls if hasattr(dl, 'dataset')])
42
+ self.num_it = num_it
43
+ self.dl_idx = -1
44
+
45
+ def __len__(self)->int: return sum([len(dl) for dl in self.dls])
46
+ def __getattr__(self, attr):
47
+ def redirected(*args, **kwargs):
48
+ for dl in self.dls:
49
+ if hasattr(dl, attr):
50
+ getattr(dl, attr)(*args, **kwargs)
51
+ return redirected
52
+
53
+ def __iter__(self):
54
+ "Process and returns items from `DataLoader`."
55
+ iters = [iter(dl) for dl in self.dls]
56
+ self.dl_idx = -1
57
+ while len(iters):
58
+ self.dl_idx = (self.dl_idx+1) % len(iters)
59
+ for b in range(self.num_it):
60
+ try:
61
+ yield next(iters[self.dl_idx])
62
+ except StopIteration as e:
63
+ iters.remove(iters[self.dl_idx])
64
+ break
65
+ # raise StopIteration
66
+
67
+ def new(self, **kwargs):
68
+ "Create a new copy of `self` with `kwargs` replacing current values."
69
+ new_dls = [dl.new(**kwargs) for dl in self.dls]
70
+ return StackedDataloader(new_dls, self.num_it)
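A sketch (not part of this commit) of interleaving two data bunches for multitask training; mask_data and lm_data stand in for pre-built fastai v1 DataBunch objects that share one vocab.

from utils.musicautobot.utils.stacked_dataloader import StackedDataBunch

# batches alternate between the two bunches every num_it iterations
data = StackedDataBunch([mask_data, lm_data], num_it=100)
print(len(data.train_dl), data.vocab)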
utils/musicautobot/utils/top_k_top_p.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ __all__ = ['top_k_top_p']
5
+
6
+ # top_k + nucleus filter - https://twitter.com/thom_wolf/status/1124263861727760384?lang=en
7
+ # https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
8
+ def top_k_top_p(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
9
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
10
+ Args:
11
+ logits: logits distribution shape (vocabulary size)
12
+ top_k >0: keep only top k tokens with highest probability (top-k filtering).
13
+ top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
14
+ """
15
+ logits = logits.clone()
16
+ assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
17
+ top_k = min(top_k, logits.size(-1)) # Safety check
18
+ if top_k > 0:
19
+ # Remove all tokens with a probability less than the last token of the top-k
20
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
21
+ logits[indices_to_remove] = filter_value
22
+
23
+ if top_p > 0.0:
24
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
25
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
26
+
27
+ # Remove tokens with cumulative probability above the threshold
28
+ sorted_indices_to_remove = cumulative_probs > top_p
29
+ # Shift the indices to the right to keep also the first token above the threshold
30
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
31
+ sorted_indices_to_remove[..., 0] = 0
32
+
33
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
34
+ logits[indices_to_remove] = filter_value
35
+ return logits
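A sketch (not part of this commit) of the filter inside a sampling loop, mirroring how the predict method earlier in this dump uses it; the logits are random stand-ins.

import torch
import torch.nn.functional as F
from utils.musicautobot.utils.top_k_top_p import top_k_top_p

logits = torch.randn(300)                           # hypothetical next-token logits
filtered = top_k_top_p(logits, top_k=30, top_p=0.6)
probs = F.softmax(filtered, dim=-1)                 # filtered entries get probability 0
idx = torch.multinomial(probs, 1).item()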
utils/musicautobot/vocab.py ADDED
@@ -0,0 +1,93 @@
1
+ from fastai.basics import *
2
+ from .numpy_encode import *
3
+ from .music_transformer import transform
4
+
5
+ BOS = 'xxbos'
6
+ PAD = 'xxpad'
7
+ EOS = 'xxeos'
8
+ MASK = 'xxmask' # Used for BERT masked language modeling.
9
+ CSEQ = 'xxcseq' # Used for Seq2Seq translation - denotes start of chord sequence
10
+ MSEQ = 'xxmseq' # Used for Seq2Seq translation - denotes start of melody sequence
11
+
12
+ # Deprecated tokens. Kept for compatibility
13
+ S2SCLS = 'xxs2scls' # deprecated
14
+ NSCLS = 'xxnscls' # deprecated
15
+
16
+ SEP = 'xxsep' # Used to denote end of timestep (required for polyphony). separator idx = -1 (part of notes)
17
+
18
+ SPECIAL_TOKS = [BOS, PAD, EOS, S2SCLS, MASK, CSEQ, MSEQ, NSCLS, SEP] # Important: SEP token must be last
19
+
20
+ NOTE_TOKS = [f'n{i}' for i in range(NOTE_SIZE)]
21
+ DUR_TOKS = [f'd{i}' for i in range(DUR_SIZE)]
22
+ NOTE_START, NOTE_END = NOTE_TOKS[0], NOTE_TOKS[-1]
23
+ DUR_START, DUR_END = DUR_TOKS[0], DUR_TOKS[-1]
24
+
25
+ MTEMPO_SIZE = 10
26
+ MTEMPO_OFF = 'mt0'
27
+ MTEMPO_TOKS = [f'mt{i}' for i in range(MTEMPO_SIZE)]
28
+
29
+ # Vocab - token to index mapping
30
+ class MusicVocab():
31
+ "Contain the correspondence between numbers and tokens and numericalize."
32
+ def __init__(self, itos:Collection[str]):
33
+ self.itos = itos
34
+ self.stoi = {v:k for k,v in enumerate(self.itos)}
35
+
36
+ def numericalize(self, t:Collection[str]) -> List[int]:
37
+ "Convert a list of tokens `t` to their ids."
38
+ return [self.stoi[w] for w in t]
39
+
40
+ def textify(self, nums:Collection[int], sep=' ') -> List[str]:
41
+ "Convert a list of `nums` to their tokens."
42
+ items = [self.itos[i] for i in nums]
43
+ return sep.join(items) if sep is not None else items
44
+
45
+ def to_music_item(self, idxenc):
46
+ return transform.MusicItem(idxenc, self)
47
+
48
+ @property
49
+ def mask_idx(self): return self.stoi[MASK]
50
+ @property
51
+ def pad_idx(self): return self.stoi[PAD]
52
+ @property
53
+ def bos_idx(self): return self.stoi[BOS]
54
+ @property
55
+ def sep_idx(self): return self.stoi[SEP]
56
+ @property
57
+ def npenc_range(self): return (self.stoi[SEP], self.stoi[DUR_END]+1)
58
+ @property
59
+ def note_range(self): return self.stoi[NOTE_START], self.stoi[NOTE_END]+1
60
+ @property
61
+ def dur_range(self): return self.stoi[DUR_START], self.stoi[DUR_END]+1
62
+
63
+ def is_duration(self, idx):
64
+ return idx >= self.dur_range[0] and idx < self.dur_range[1]
65
+ def is_duration_or_pad(self, idx):
66
+ return idx == self.pad_idx or self.is_duration(idx)
67
+
68
+ def __getstate__(self):
69
+ return {'itos':self.itos}
70
+
71
+ def __setstate__(self, state:dict):
72
+ self.itos = state['itos']
73
+ self.stoi = {v:k for k,v in enumerate(self.itos)}
74
+
75
+ def __len__(self): return len(self.itos)
76
+
77
+ def save(self, path):
78
+ "Save `self.itos` in `path`"
79
+ pickle.dump(self.itos, open(path, 'wb'))
80
+
81
+ @classmethod
82
+ def create(cls) -> 'Vocab':
83
+ "Create a vocabulary from a set of `tokens`."
84
+ itos = SPECIAL_TOKS + NOTE_TOKS + DUR_TOKS + MTEMPO_TOKS
85
+ if len(itos)%8 != 0:
86
+ itos = itos + [f'dummy{i}' for i in range(len(itos)%8)]
87
+ return cls(itos)
88
+
89
+ @classmethod
90
+ def load(cls, path):
91
+ "Load the `Vocab` contained in `path`"
92
+ itos = pickle.load(open(path, 'rb'))
93
+ return cls(itos)
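A closing sketch (not part of this commit) of the vocab round trip these methods provide.

import numpy as np
from utils.musicautobot.vocab import MusicVocab

vocab = MusicVocab.create()
print(len(vocab), vocab.note_range, vocab.dur_range)
ids = vocab.numericalize(['xxbos', 'xxpad', 'n60', 'd4', 'xxsep', 'd4'])
print(vocab.textify(ids))                   # 'xxbos xxpad n60 d4 xxsep d4'
item = vocab.to_music_item(np.array(ids))   # wrap the ids back into a MusicItem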