Spaces · Build error
jerald committed
Commit 4819bc9
1 Parent(s): fcd062e
source dump
- app.py +62 -0
- music_transformer.pth +3 -0
- requirements.txt +4 -0
- utils/.DS_Store +0 -0
- utils/musicautobot/.DS_Store +0 -0
- utils/musicautobot/__init__.py +3 -0
- utils/musicautobot/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/musicautobot/__pycache__/config.cpython-310.pyc +0 -0
- utils/musicautobot/__pycache__/numpy_encode.cpython-310.pyc +0 -0
- utils/musicautobot/__pycache__/vocab.cpython-310.pyc +0 -0
- utils/musicautobot/config.py +47 -0
- utils/musicautobot/multitask_transformer/__init__.py +3 -0
- utils/musicautobot/multitask_transformer/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/musicautobot/multitask_transformer/__pycache__/dataloader.cpython-310.pyc +0 -0
- utils/musicautobot/multitask_transformer/__pycache__/learner.cpython-310.pyc +0 -0
- utils/musicautobot/multitask_transformer/__pycache__/model.cpython-310.pyc +0 -0
- utils/musicautobot/multitask_transformer/__pycache__/transform.cpython-310.pyc +0 -0
- utils/musicautobot/multitask_transformer/dataloader.py +146 -0
- utils/musicautobot/multitask_transformer/learner.py +340 -0
- utils/musicautobot/multitask_transformer/model.py +258 -0
- utils/musicautobot/multitask_transformer/transform.py +68 -0
- utils/musicautobot/music_transformer/__init__.py +3 -0
- utils/musicautobot/music_transformer/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/musicautobot/music_transformer/__pycache__/dataloader.cpython-310.pyc +0 -0
- utils/musicautobot/music_transformer/__pycache__/learner.cpython-310.pyc +0 -0
- utils/musicautobot/music_transformer/__pycache__/model.cpython-310.pyc +0 -0
- utils/musicautobot/music_transformer/__pycache__/transform.cpython-310.pyc +0 -0
- utils/musicautobot/music_transformer/dataloader.py +229 -0
- utils/musicautobot/music_transformer/learner.py +171 -0
- utils/musicautobot/music_transformer/model.py +66 -0
- utils/musicautobot/music_transformer/transform.py +235 -0
- utils/musicautobot/numpy_encode.py +302 -0
- utils/musicautobot/utils/__init__.py +0 -0
- utils/musicautobot/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/musicautobot/utils/__pycache__/attention_mask.cpython-310.pyc +0 -0
- utils/musicautobot/utils/__pycache__/file_processing.cpython-310.pyc +0 -0
- utils/musicautobot/utils/__pycache__/midifile.cpython-310.pyc +0 -0
- utils/musicautobot/utils/__pycache__/setup_musescore.cpython-310.pyc +0 -0
- utils/musicautobot/utils/__pycache__/top_k_top_p.cpython-310.pyc +0 -0
- utils/musicautobot/utils/attention_mask.py +21 -0
- utils/musicautobot/utils/file_processing.py +52 -0
- utils/musicautobot/utils/lamb.py +106 -0
- utils/musicautobot/utils/midifile.py +107 -0
- utils/musicautobot/utils/setup_musescore.py +46 -0
- utils/musicautobot/utils/stacked_dataloader.py +70 -0
- utils/musicautobot/utils/top_k_top_p.py +35 -0
- utils/musicautobot/vocab.py +93 -0
app.py
ADDED
@@ -0,0 +1,62 @@
from utils.musicautobot.numpy_encode import *
from utils.musicautobot.utils.file_processing import process_all, process_file
from utils.musicautobot.config import *
from utils.musicautobot.music_transformer import *

import gradio as gr
from midi2audio import FluidSynth
import tempfile
import os

# Bootloading model
data_path = Path('./')
data = MusicDataBunch.empty(data_path)
vocab = data.vocab
pretrained_path = './music_transformer.pth'
learn = music_model_learner(data, pretrained_path=pretrained_path, config=default_config())


def predict(seed_midi, n_words=400, temperature1=1.1, temperature2=0.4, min_bars=12, top_k=24, top_p=0.7):
    # Load input MIDI file as MusicItem
    cutoff_beat = 10
    item = MusicItem.from_file(seed_midi.name, data.vocab)
    seed_item = item.trim_to_beat(cutoff_beat)

    # Generate prediction
    pred, full = learn.predict(seed_item, n_words=n_words, temperatures=(temperature1, temperature2), min_bars=min_bars, top_k=top_k, top_p=top_p)

    # Convert input MIDI to audio
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as seed_audio_temp:
        FluidSynth("sound_font.sf2").midi_to_audio(seed_midi.name, seed_audio_temp.name)

    # Save generated MIDI as temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix='.midi') as pred_midi_temp:
        pred.stream.write('midi', fp=pred_midi_temp.name)

    # Convert generated MIDI to audio
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as pred_audio_temp:
        FluidSynth("sound_font.sf2").midi_to_audio(pred_midi_temp.name, pred_audio_temp.name)

    # Cleanup temporary MIDI file
    os.remove(pred_midi_temp.name)

    return seed_audio_temp.name, pred_audio_temp.name



iface = gr.Interface(fn=predict,
                     inputs=[
                         gr.inputs.File(label="Seed MIDI"),
                         gr.inputs.Slider(50, 1000, step=10, default=400, label="Number of Words"),
                         gr.inputs.Slider(0.0, 2.0, step=0.1, default=1.1, label="Temperature 1"),
                         gr.inputs.Slider(0.0, 2.0, step=0.1, default=0.4, label="Temperature 2"),
                         gr.inputs.Slider(1, 32, step=1, default=12, label="Min Bars"),
                         gr.inputs.Slider(1, 50, step=1, default=24, label="Top K"),
                         gr.inputs.Slider(0.0, 1.0, step=0.1, default=0.7, label="Top P")
                     ],
                     outputs=[
                         gr.outputs.Audio(type='filepath', label="Seed Audio"),
                         gr.outputs.Audio(type='filepath', label="Generated Audio")
                     ],)

iface.launch()
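For quick local testing it can help to exercise the same generation pipeline outside the Gradio UI. A minimal sketch, not part of this commit, assuming the checkpoint above, a SoundFont at sound_font.sf2, and a local seed file seed.mid (the seed and output file names are hypothetical):

from pathlib import Path
from midi2audio import FluidSynth
from utils.musicautobot.config import default_config
from utils.musicautobot.music_transformer import MusicDataBunch, MusicItem, music_model_learner

data = MusicDataBunch.empty(Path('./'))
learn = music_model_learner(data, pretrained_path='./music_transformer.pth', config=default_config())

seed = MusicItem.from_file('seed.mid', data.vocab).trim_to_beat(10)   # hypothetical input file
pred, full = learn.predict(seed, n_words=400, temperatures=(1.1, 0.4), min_bars=12, top_k=24, top_p=0.7)
pred.stream.write('midi', fp='generated.mid')                         # write the continuation as MIDI
FluidSynth('sound_font.sf2').midi_to_audio('generated.mid', 'generated.wav')  # render it to audio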
music_transformer.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9856190b46abee88440104c661349f577eca4754ae485b63cf77030772b0c8cf
size 657241884
requirements.txt
ADDED
@@ -0,0 +1,4 @@
gradio
midi2audio
music21
fastai
utils/.DS_Store
ADDED
Binary file (6.15 kB)
utils/musicautobot/.DS_Store
ADDED
Binary file (6.15 kB)
utils/musicautobot/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .utils.setup_musescore import setup_musescore

setup_musescore()
utils/musicautobot/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (239 Bytes)
utils/musicautobot/__pycache__/config.cpython-310.pyc
ADDED
Binary file (1.25 kB)
utils/musicautobot/__pycache__/numpy_encode.cpython-310.pyc
ADDED
Binary file (9.77 kB)
utils/musicautobot/__pycache__/vocab.cpython-310.pyc
ADDED
Binary file (5.24 kB)
utils/musicautobot/config.py
ADDED
@@ -0,0 +1,47 @@
from fastai.text.models.transformer import tfmerXL_lm_config, Activation
# from .vocab import MusicVocab

def default_config():
    config = tfmerXL_lm_config.copy()
    config['act'] = Activation.GeLU

    config['mem_len'] = 512
    config['d_model'] = 512
    config['d_inner'] = 2048
    config['n_layers'] = 16

    config['n_heads'] = 8
    config['d_head'] = 64

    return config

def music_config():
    config = default_config()
    config['encode_position'] = True
    return config

def musicm_config():
    config = music_config()
    config['d_model'] = 768
    config['d_inner'] = 3072
    config['n_heads'] = 12
    config['d_head'] = 64
    config['n_layers'] = 12
    return config

def multitask_config():
    config = default_config()
    config['bias'] = True
    config['enc_layers'] = 8
    config['dec_layers'] = 8
    del config['n_layers']
    return config

def multitaskm_config():
    config = musicm_config()
    config['bias'] = True
    config['enc_layers'] = 12
    config['dec_layers'] = 12
    del config['n_layers']
    return config
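These helpers just layer overrides on top of fastai's TransformerXL defaults. A short sketch of how they compose, not part of the commit, with values taken from the definitions above:

from utils.musicautobot.config import default_config, musicm_config, multitask_config

base = default_config()
print(base['d_model'], base['n_layers'], base['n_heads'])    # 512 16 8

large = musicm_config()                                      # wider model with positional beat encoding
print(large['d_model'], large['encode_position'])            # 768 True

mt = multitask_config()                                      # encoder/decoder depths replace n_layers
print(mt['enc_layers'], mt['dec_layers'], 'n_layers' in mt)  # 8 8 False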
utils/musicautobot/multitask_transformer/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .dataloader import *
from .model import *
from .learner import *
utils/musicautobot/multitask_transformer/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (257 Bytes)
utils/musicautobot/multitask_transformer/__pycache__/dataloader.cpython-310.pyc
ADDED
Binary file (6.17 kB)
utils/musicautobot/multitask_transformer/__pycache__/learner.cpython-310.pyc
ADDED
Binary file (11.5 kB)
utils/musicautobot/multitask_transformer/__pycache__/model.cpython-310.pyc
ADDED
Binary file (11.4 kB)
utils/musicautobot/multitask_transformer/__pycache__/transform.cpython-310.pyc
ADDED
Binary file (3.72 kB)
utils/musicautobot/multitask_transformer/dataloader.py
ADDED
@@ -0,0 +1,146 @@
from fastai.basics import *
from .transform import *
from ..music_transformer.dataloader import MusicDataBunch, MusicItemList
# Sequence 2 Sequence Translate

class S2SFileProcessor(PreProcessor):
    "`PreProcessor` that opens the filenames and read the texts."
    def process_one(self,item):
        out = np.load(item, allow_pickle=True)
        if out.shape != (2,): return None
        if not 16 < len(out[0]) < 2048: return None
        if not 16 < len(out[1]) < 2048: return None
        return out

    def process(self, ds:Collection):
        ds.items = [self.process_one(item) for item in ds.items]
        ds.items = [i for i in ds.items if i is not None] # filter out None

class S2SPartsProcessor(PreProcessor):
    "Encodes midi file into 2 separate parts - melody and chords."

    def process_one(self, item):
        m, c = item
        mtrack = MultitrackItem.from_npenc_parts(m, c, vocab=self.vocab)
        return mtrack.to_idx()

    def process(self, ds):
        self.vocab = ds.vocab
        ds.items = [self.process_one(item) for item in ds.items]

class Midi2MultitrackProcessor(PreProcessor):
    "Converts midi files to multitrack items"
    def process_one(self, midi_file):
        try:
            item = MultitrackItem.from_file(midi_file, vocab=self.vocab)
        except Exception as e:
            print(e)
            return None
        return item.to_idx()

    def process(self, ds):
        self.vocab = ds.vocab
        ds.items = [self.process_one(item) for item in ds.items]
        ds.items = [i for i in ds.items if i is not None]

class S2SPreloader(Callback):
    def __init__(self, dataset:LabelList, bptt:int=512,
                 transpose_range=None, **kwargs):
        self.dataset,self.bptt = dataset,bptt
        self.vocab = self.dataset.vocab
        self.transpose_range = transpose_range
        self.rand_transpose = partial(rand_transpose_value, rand_range=transpose_range) if transpose_range is not None else None

    def __getitem__(self, k:int):
        item,empty_label = self.dataset[k]

        if self.rand_transpose is not None:
            val = self.rand_transpose()
            item = item.transpose(val)
        item = item.pad_to(self.bptt+1)
        ((m_x, m_pos), (c_x, c_pos)) = item.to_idx()
        return m_x, m_pos, c_x, c_pos

    def __len__(self):
        return len(self.dataset)

def rand_transpose_value(rand_range=(0,24), p=0.5):
    if np.random.rand() < p: return np.random.randint(*rand_range)-rand_range[1]//2
    return 0

class S2SItemList(MusicItemList):
    _bunch = MusicDataBunch
    def get(self, i):
        return MultitrackItem.from_idx(self.items[i], self.vocab)

# DATALOADING AND TRANSFORMATIONS
# These transforms happen on batch

def mask_tfm(b, mask_range, mask_idx, pad_idx, p=0.3):
    # mask range (min, max)
    # replacement vals - [x_replace, y_replace]. Usually [mask_idx, pad_idx]
    # p = replacement probability
    x,y = b
    x,y = x.clone(),y.clone()
    rand = torch.rand(x.shape, device=x.device)
    rand[x < mask_range[0]] = 1.0
    rand[x >= mask_range[1]] = 1.0

    # p(15%) of words are replaced. Of those p(15%) - 80% are masked. 10% wrong word. 10% unchanged
    y[rand > p] = pad_idx # pad unchanged 80%. Remove these from loss/acc metrics
    x[rand <= (p*.8)] = mask_idx # 80% = mask
    wrong_word = (rand > (p*.8)) & (rand <= (p*.9)) # 10% = wrong word
    x[wrong_word] = torch.randint(*mask_range, [wrong_word.sum().item()], device=x.device)
    return x, y

def mask_lm_tfm_default(b, vocab, mask_p=0.3):
    return mask_lm_tfm(b, mask_range=vocab.npenc_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)

def mask_lm_tfm_pitchdur(b, vocab, mask_p=0.9):
    mask_range = vocab.dur_range if np.random.rand() < 0.5 else vocab.note_range
    return mask_lm_tfm(b, mask_range=mask_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p)

def mask_lm_tfm(b, mask_range, mask_idx, pad_idx, mask_p):
    x,y = b
    x_lm,x_pos = x[...,0], x[...,1]
    y_lm,y_pos = y[...,0], y[...,1]

    # Note: masking y_lm instead of x_lm. Just in case we ever do sequential s2s training
    x_msk, y_msk = mask_tfm((y_lm, y_lm), mask_range=mask_range, mask_idx=mask_idx, pad_idx=pad_idx, p=mask_p)
    msk_pos = y_pos

    x_dict = {
        'msk': { 'x': x_msk, 'pos': msk_pos },
        'lm': { 'x': x_lm, 'pos': msk_pos }
    }
    y_dict = { 'msk': y_msk, 'lm': y_lm }
    return x_dict, y_dict

def melody_chord_tfm(b):
    m,m_pos,c,c_pos = b

    # offset x and y for next word prediction
    y_m = m[:,1:]
    x_m, m_pos = m[:,:-1], m_pos[:,:-1]

    y_c = c[:,1:]
    x_c, c_pos = c[:,:-1], c_pos[:,:-1]

    x_dict = {
        'c2m': {
            'enc': x_c,
            'enc_pos': c_pos,
            'dec': x_m,
            'dec_pos': m_pos
        },
        'm2c': {
            'enc': x_m,
            'enc_pos': m_pos,
            'dec': x_c,
            'dec_pos': c_pos
        }
    }
    y_dict = {
        'c2m': y_m, 'm2c': y_c
    }
    return x_dict, y_dict
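mask_tfm applies the usual BERT-style corruption: of the positions selected with probability p, 80% become the mask token, 10% become a random in-range token, and 10% stay unchanged, while the targets of unselected positions are set to the pad index so they drop out of the loss. A toy sketch, not part of the commit; the token ids and ranges here are made up for illustration:

import torch
from utils.musicautobot.multitask_transformer.dataloader import mask_tfm

x = torch.randint(10, 20, (2, 8))   # pretend ids 10-19 are maskable note/duration tokens
y = x.clone()
x_msk, y_msk = mask_tfm((x, y), mask_range=(10, 20), mask_idx=1, pad_idx=0, p=0.3)
print(x_msk)   # roughly 30% of entries replaced by 1 (mask) or a random id in [10, 20)
print(y_msk)   # pad_idx (0) everywhere except the selected positions, which keep their target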
utils/musicautobot/multitask_transformer/learner.py
ADDED
@@ -0,0 +1,340 @@
from fastai.basics import *
from ..vocab import *
from ..utils.top_k_top_p import top_k_top_p
from ..utils.midifile import is_empty_midi
from ..music_transformer.transform import *
from ..music_transformer.learner import filter_invalid_indexes
from .model import get_multitask_model
from .dataloader import *

def multitask_model_learner(data:DataBunch, config:dict=None, drop_mult:float=1.,
                            pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner':
    "Create a `Learner` with a language model from `data` and `arch`."
    vocab = data.vocab
    vocab_size = len(vocab)

    if pretrained_path:
        state = torch.load(pretrained_path, map_location='cpu')
        if config is None: config = state['config']

    model = get_multitask_model(vocab_size, config=config, drop_mult=drop_mult, pad_idx=vocab.pad_idx)
    metrics = [AverageMultiMetric(partial(m, pad_idx=vocab.pad_idx)) for m in [mask_acc, lm_acc, c2m_acc, m2c_acc]]
    loss_func = MultiLoss(ignore_index=data.vocab.pad_idx)
    learn = MultitaskLearner(data, model, loss_func=loss_func, metrics=metrics, **learn_kwargs)

    if pretrained_path:
        get_model(model).load_state_dict(state['model'], strict=False)
        if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
        try: learn.opt.load_state_dict(state['opt'])
        except: pass
        del state
        gc.collect()

    return learn

class MultitaskLearner(Learner):
    def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
        "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)"
        out_path = super().save(file, return_path=True, with_opt=with_opt)
        if config and out_path:
            state = torch.load(out_path)
            state['config'] = config
            torch.save(state, out_path)
            del state
            gc.collect()
        return out_path

    def predict_nw(self, item:MusicItem, n_words:int=128,
                   temperatures:float=(1.0,1.0), min_bars=4,
                   top_k=30, top_p=0.6):
        "Return the `n_words` that come after `text`."
        self.model.reset()
        new_idx = []
        vocab = self.data.vocab
        x, pos = item.to_tensor(), item.get_pos_tensor()
        last_pos = pos[-1] if len(pos) else 0
        y = torch.tensor([0])

        start_pos = last_pos

        sep_count = 0
        bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time
        vocab = self.data.vocab

        repeat_count = 0

        for i in progress_bar(range(n_words), leave=True):
            batch = { 'lm': { 'x': x[None], 'pos': pos[None] } }, y
            logits = self.pred_batch(batch=batch)['lm'][-1][-1]

            prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx

            # Temperature
            # Use first temperatures value if last prediction was duration
            temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
            repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
            temperature += repeat_penalty
            if temperature != 1.: logits = logits / temperature

            # Filter
            # bar = 16 beats
            filter_value = -float('Inf')
            if ((last_pos - start_pos) // 16) <= min_bars: logits[vocab.bos_idx] = filter_value

            logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
            logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)

            # Sample
            probs = F.softmax(logits, dim=-1)
            idx = torch.multinomial(probs, 1).item()

            # Update repeat count
            num_choices = len(probs.nonzero().view(-1))
            if num_choices <= 2: repeat_count += 1
            else: repeat_count = repeat_count // 2

            if prev_idx==vocab.sep_idx:
                duration = idx - vocab.dur_range[0]
                last_pos = last_pos + duration

                bars_pred = (last_pos - start_pos) // 16
                abs_bar = last_pos // 16
                # if (bars % 8 == 0) and (bars_pred > min_bars): break
                if (i / n_words > 0.80) and (abs_bar % 4 == 0): break

            if idx==vocab.bos_idx:
                print('Predicted BOS token. Returning prediction...')
                break

            new_idx.append(idx)
            x = x.new_tensor([idx])
            pos = pos.new_tensor([last_pos])

        pred = vocab.to_music_item(np.array(new_idx))
        full = item.append(pred)
        return pred, full

    def predict_mask(self, masked_item:MusicItem,
                     temperatures:float=(1.0,1.0),
                     top_k=20, top_p=0.8):
        x = masked_item.to_tensor()
        pos = masked_item.get_pos_tensor()
        y = torch.tensor([0])
        vocab = self.data.vocab
        self.model.reset()
        mask_idxs = (x == vocab.mask_idx).nonzero().view(-1)

        repeat_count = 0

        for midx in progress_bar(mask_idxs, leave=True):
            prev_idx = x[midx-1]

            # Using original positions, otherwise model gets too off track
            # pos = torch.tensor(-position_enc(xb[0].cpu().numpy()), device=xb.device)[None]

            # Next Word
            logits = self.pred_batch(batch=({ 'msk': { 'x': x[None], 'pos': pos[None] } }, y) )['msk'][0][midx]

            # Temperature
            # Use first temperatures value if last prediction was duration
            temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
            repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
            temperature += repeat_penalty
            if temperature != 1.: logits = logits / temperature

            # Filter
            filter_value = -float('Inf')
            special_idxs = [vocab.bos_idx, vocab.sep_idx, vocab.stoi[EOS]]
            logits[special_idxs] = filter_value # Don't allow any special tokens (as we are only removing notes and durations)
            logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
            logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)

            # Sampling
            probs = F.softmax(logits, dim=-1)
            idx = torch.multinomial(probs, 1).item()

            # Update repeat count
            num_choices = len(probs.nonzero().view(-1))
            if num_choices <= 2: repeat_count += 1
            else: repeat_count = repeat_count // 2

            x[midx] = idx

        return vocab.to_music_item(x.cpu().numpy())

    def predict_s2s(self, input_item:MusicItem, target_item:MusicItem, n_words:int=256,
                    temperatures:float=(1.0,1.0), top_k=30, top_p=0.8,
                    use_memory=True):
        vocab = self.data.vocab

        # Input doesn't change. We can reuse the encoder output on each prediction
        with torch.no_grad():
            inp, inp_pos = input_item.to_tensor(), input_item.get_pos_tensor()
            x_enc = self.model.encoder(inp[None], inp_pos[None])

        # target
        targ = target_item.data.tolist()
        targ_pos = target_item.position.tolist()
        last_pos = targ_pos[-1]
        self.model.reset()

        repeat_count = 0

        max_pos = input_item.position[-1] + SAMPLE_FREQ * 4 # Only predict until both tracks/parts have the same length
        x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos)

        for i in progress_bar(range(n_words), leave=True):
            # Predict
            with torch.no_grad():
                dec = self.model.decoder(x[None], pos[None], x_enc)
                logits = self.model.head(dec)[-1, -1]

            # Temperature
            # Use first temperatures value if last prediction was duration
            prev_idx = targ[-1] if len(targ) else vocab.pad_idx
            temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
            repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
            temperature += repeat_penalty
            if temperature != 1.: logits = logits / temperature

            # Filter
            filter_value = -float('Inf')
            logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
            logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)

            # Sample
            probs = F.softmax(logits, dim=-1)
            idx = torch.multinomial(probs, 1).item()

            # Update repeat count
            num_choices = len(probs.nonzero().view(-1))
            if num_choices <= 2: repeat_count += 1
            else: repeat_count = repeat_count // 2

            if idx == vocab.bos_idx or idx == vocab.stoi[EOS]:
                print('Predicting BOS/EOS')
                break

            if prev_idx == vocab.sep_idx:
                duration = idx - vocab.dur_range[0]
                last_pos = last_pos + duration
                if last_pos > max_pos:
                    print('Predicted past counter-part length. Returning early')
                    break

            targ_pos.append(last_pos)
            targ.append(idx)

            if use_memory:
                # Relying on memory for kv. Only need last prediction index
                x, pos = inp.new_tensor([targ[-1]]), inp_pos.new_tensor([targ_pos[-1]])
            else:
                # Reset memory after each prediction, since we feeding the whole sequence every time
                self.model.reset()
                x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos)

        return vocab.to_music_item(np.array(targ))

# High level prediction functions from midi file
def nw_predict_from_midi(learn, midi=None, n_words=400,
                         temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
    vocab = learn.data.vocab
    seed = MusicItem.from_file(midi, vocab) if not is_empty_midi(midi) else MusicItem.empty(vocab)
    if seed_len is not None: seed = seed.trim_to_beat(seed_len)

    pred, full = learn.predict_nw(seed, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
    return full

def s2s_predict_from_midi(learn, midi=None, n_words=200,
                          temperatures=(1.0,1.0), top_k=24, top_p=0.7, seed_len=None, pred_melody=True, **kwargs):
    multitrack_item = MultitrackItem.from_file(midi, learn.data.vocab)
    melody, chords = multitrack_item.melody, multitrack_item.chords
    inp, targ = (chords, melody) if pred_melody else (melody, chords)

    # if seed_len is passed, cutoff sequence so we can predict the rest
    if seed_len is not None: targ = targ.trim_to_beat(seed_len)
    targ = targ.remove_eos()

    pred = learn.predict_s2s(inp, targ, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)

    part_order = (pred, inp) if pred_melody else (inp, pred)
    return MultitrackItem(*part_order)

def mask_predict_from_midi(learn, midi=None, predict_notes=True,
                           temperatures=(1.0,1.0), top_k=30, top_p=0.7, section=None, **kwargs):
    item = MusicItem.from_file(midi, learn.data.vocab)
    masked_item = item.mask_pitch(section) if predict_notes else item.mask_duration(section)
    pred = learn.predict_mask(masked_item, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
    return pred

# LOSS AND METRICS

class MultiLoss():
    def __init__(self, ignore_index=None):
        "Loss mult - Mask, NextWord, Seq2Seq"
        self.loss = CrossEntropyFlat(ignore_index=ignore_index)

    def __call__(self, inputs:Dict[str,Tensor], targets:Dict[str,Tensor])->Rank0Tensor:
        losses = [self.loss(inputs[key], target) for key,target in targets.items()]
        return sum(losses)

def acc_ignore_pad(input:Tensor, targ:Tensor, pad_idx)->Rank0Tensor:
    if input is None or targ is None: return None
    n = targ.shape[0]
    input = input.argmax(dim=-1).view(n,-1)
    targ = targ.view(n,-1)
    mask = targ != pad_idx
    return (input[mask]==targ[mask]).float().mean()

def acc_index(inputs, targets, key, pad_idx):
    return acc_ignore_pad(inputs.get(key), targets.get(key), pad_idx)

def mask_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'msk', pad_idx)
def lm_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'lm', pad_idx)
def c2m_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'c2m', pad_idx)
def m2c_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'm2c', pad_idx)


class AverageMultiMetric(AverageMetric):
    "Updated fastai.AverageMetric to support multi task metrics."
    def on_batch_end(self, last_output, last_target, **kwargs):
        "Update metric computation with `last_output` and `last_target`."
        if not is_listy(last_target): last_target=[last_target]
        val = self.func(last_output, *last_target)
        if val is None: return
        self.count += first_el(last_target).size(0)
        if self.world:
            val = val.clone()
            dist.all_reduce(val, op=dist.ReduceOp.SUM)
            val /= self.world
        self.val += first_el(last_target).size(0) * val.detach().cpu()

    def on_epoch_end(self, last_metrics, **kwargs):
        "Set the final result in `last_metrics`."
        if self.count == 0: return add_metrics(last_metrics, 0)
        return add_metrics(last_metrics, self.val/self.count)


# MODEL LOADING
class MTTrainer(LearnerCallback):
    "`Callback` that regroups lr adjustment to seq_len, AR and TAR."
    def __init__(self, learn:Learner, dataloaders=None, starting_mask_window=1):
        super().__init__(learn)
        self.count = 1
        self.mw_start = starting_mask_window
        self.dataloaders = dataloaders

    def on_epoch_begin(self, **kwargs):
        "Reset the hidden state of the model."
        model = get_model(self.learn.model)
        model.reset()
        model.encoder.mask_steps = max(self.count+self.mw_start, 100)

    def on_epoch_end(self, last_metrics, **kwargs):
        "Finish the computation and sends the result to the Recorder."
        if self.dataloaders is not None:
            self.learn.data = self.dataloaders[self.count % len(self.dataloaders)]
        self.count += 1
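The three *_predict_from_midi helpers at the bottom are the intended entry points. A minimal sketch of next-word continuation, not part of the commit, assuming a multitask checkpoint that was saved with its config via MultitaskLearner.save (the file names here are hypothetical):

from pathlib import Path
from utils.musicautobot.music_transformer import MusicDataBunch
from utils.musicautobot.multitask_transformer import multitask_model_learner, nw_predict_from_midi

data = MusicDataBunch.empty(Path('./'))
learn = multitask_model_learner(data, pretrained_path='multitask_model.pth')  # config is read from the checkpoint
full = nw_predict_from_midi(learn, midi='seed.mid', n_words=200, seed_len=8)  # hypothetical seed file
full.stream.write('midi', fp='continuation.mid')                              # save seed + continuation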
utils/musicautobot/multitask_transformer/model.py
ADDED
@@ -0,0 +1,258 @@
from fastai.basics import *
from fastai.text.models.transformer import Activation, PositionalEncoding, feed_forward, init_transformer, _line_shift
from fastai.text.models.awd_lstm import RNNDropout
from ..utils.attention_mask import *

def get_multitask_model(vocab_size:int, config:dict=None, drop_mult:float=1., pad_idx=None):
    "Create a language model from `arch` and its `config`, maybe `pretrained`."
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    n_hid = config['d_model']
    mem_len = config.pop('mem_len')
    embed = TransformerEmbedding(vocab_size, n_hid, embed_p=config['embed_p'], mem_len=mem_len, pad_idx=pad_idx)
    encoder = MTEncoder(embed, n_hid, n_layers=config['enc_layers'], mem_len=0, **config) # encoder doesn't need memory
    decoder = MTEncoder(embed, n_hid, is_decoder=True, n_layers=config['dec_layers'], mem_len=mem_len, **config)
    head = MTLinearDecoder(n_hid, vocab_size, tie_encoder=embed.embed, **config)
    model = MultiTransformer(encoder, decoder, head, mem_len=mem_len)
    return model.apply(init_transformer)

class MultiTransformer(nn.Module):
    "Multitask Transformer for training mask, next word, and sequence 2 sequence"
    def __init__(self, encoder, decoder, head, mem_len):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.head = head
        self.default_mem_len = mem_len
        self.current_mem_len = None

    def forward(self, inp):
        # data order: mask, next word, melody, chord
        outputs = {}
        msk, lm, c2m, m2c = [inp.get(key) for key in ['msk', 'lm', 'c2m', 'm2c']]

        if msk is not None:
            outputs['msk'] = self.head(self.encoder(msk['x'], msk['pos']))
        if lm is not None:
            outputs['lm'] = self.head(self.decoder(lm['x'], lm['pos']))

        if c2m is not None:
            self.reset()
            c2m_enc = self.encoder(c2m['enc'], c2m['enc_pos'])
            c2m_dec = self.decoder(c2m['dec'], c2m['dec_pos'], c2m_enc)
            outputs['c2m'] = self.head(c2m_dec)

        if m2c is not None:
            self.reset()
            m2c_enc = self.encoder(m2c['enc'], m2c['enc_pos'])
            m2c_dec = self.decoder(m2c['dec'], m2c['dec_pos'], m2c_enc)
            outputs['m2c'] = self.head(m2c_dec)

        return outputs

    "A sequential module that passes the reset call to its children."
    def reset(self):
        for module in self.children():
            reset_children(module)

def reset_children(mod):
    if hasattr(mod, 'reset'): mod.reset()
    for module in mod.children():
        reset_children(module)

# COMPONENTS
class TransformerEmbedding(nn.Module):
    "Embedding + positional encoding + dropout"
    def __init__(self, vocab_size:int, emb_sz:int, embed_p:float=0., mem_len=512, beat_len=32, max_bar_len=1024, pad_idx=None):
        super().__init__()
        self.emb_sz = emb_sz
        self.pad_idx = pad_idx

        self.embed = nn.Embedding(vocab_size, emb_sz, padding_idx=pad_idx)
        self.pos_enc = PositionalEncoding(emb_sz)
        self.beat_len, self.max_bar_len = beat_len, max_bar_len
        self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0)
        self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0)

        self.drop = nn.Dropout(embed_p)
        self.mem_len = mem_len

    def forward(self, inp, pos):
        beat_enc = self.beat_enc(pos % self.beat_len)
        bar_pos = pos // self.beat_len % self.max_bar_len
        bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1
        bar_enc = self.bar_enc((bar_pos))
        emb = self.drop(self.embed(inp) + beat_enc + bar_enc)
        return emb

    def relative_pos_enc(self, emb):
        # return torch.arange(640-1, -1, -1).float().cuda()
        seq_len = emb.shape[1] + self.mem_len
        pos = torch.arange(seq_len-1, -1, -1, device=emb.device, dtype=emb.dtype) # backwards (txl pos encoding)
        return self.pos_enc(pos)

class MTLinearDecoder(nn.Module):
    "To go on top of a RNNCore module and create a Language Model."
    initrange=0.1

    def __init__(self, n_hid:int, n_out:int, output_p:float, tie_encoder:nn.Module=None, out_bias:bool=True, **kwargs):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=out_bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.output_dp = RNNDropout(output_p)
        if out_bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
        output = self.output_dp(input)
        decoded = self.decoder(output)
        return decoded


# DECODER TRANSLATE BLOCK
class MTEncoder(nn.Module):
    def __init__(self, embed:nn.Module, n_hid:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., bias:bool=True, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, mem_len:int=512, is_decoder=False,
                 mask_steps=1, mask_p=0.3, **kwargs):
        super().__init__()
        self.embed = embed
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.n_layers,self.d_model = n_layers,d_model
        self.layers = nn.ModuleList([MTEncoderBlock(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
                      ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, mem_len=mem_len,
                      ) for k in range(n_layers)])

        self.mask_steps, self.mask_p = mask_steps, mask_p
        self.is_decoder = is_decoder

        nn.init.normal_(self.u, 0., 0.02)
        nn.init.normal_(self.v, 0., 0.02)

    def forward(self, x_lm, lm_pos, msk_emb=None):
        bs,lm_len = x_lm.size()

        lm_emb = self.embed(x_lm, lm_pos)
        if msk_emb is not None and msk_emb.shape[1] > lm_emb.shape[1]:
            pos_enc = self.embed.relative_pos_enc(msk_emb)
        else:
            pos_enc = self.embed.relative_pos_enc(lm_emb)

        # Masks
        if self.is_decoder:
            lm_mask = rand_window_mask(lm_len, self.embed.mem_len, x_lm.device,
                                       max_size=self.mask_steps, p=self.mask_p, is_eval=not self.training)
        else:
            lm_mask = None

        for i, layer in enumerate(self.layers):
            lm_emb = layer(lm_emb, msk_emb, lm_mask=lm_mask,
                           r=pos_enc, g_u=self.u, g_v=self.v)
        return lm_emb

class MTEncoderBlock(nn.Module):
    "Decoder block of a Transformer model."
    #Can't use Sequential directly cause more than one input...
    def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0.,
                 bias:bool=True, scale:bool=True, double_drop:bool=True, mem_len:int=512, mha2_mem_len=0, **kwargs):
        super().__init__()
        attn_cls = MemMultiHeadRelativeAttentionKV
        self.mha1 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mem_len, r_mask=False)
        self.mha2 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mha2_mem_len, r_mask=True)
        self.ff = feed_forward(d_model, d_inner, ff_p=ff_p, double_drop=double_drop)

    def forward(self, enc_lm:Tensor, enc_msk:Tensor,
                r=None, g_u=None, g_v=None,
                msk_mask:Tensor=None, lm_mask:Tensor=None):

        y_lm = self.mha1(enc_lm, enc_lm, enc_lm, r, g_u, g_v, mask=lm_mask)
        if enc_msk is None: return y_lm
        return self.ff(self.mha2(y_lm, enc_msk, enc_msk, r, g_u, g_v, mask=msk_mask))


# Attention Layer


# Attn

class MemMultiHeadRelativeAttentionKV(nn.Module):
    "Attention Layer monster - relative positioning, keeps track of own memory, separate kv weights to support sequence2sequence decoding."
    def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True,
                 scale:bool=True, mem_len:int=512, r_mask=True):
        super().__init__()
        d_head = ifnone(d_head, d_model//n_heads)
        self.n_heads,self.d_head,self.scale = n_heads,d_head,scale

        assert(d_model == d_head * n_heads)
        self.q_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.k_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.v_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias)

        self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p)
        self.ln = nn.LayerNorm(d_model)
        self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)
        self.r_mask = r_mask

        self.mem_len = mem_len
        self.prev_k = None
        self.prev_v = None

    def forward(self, q:Tensor, k:Tensor=None, v:Tensor=None,
                r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None,
                mask:Tensor=None, **kwargs):
        if k is None: k = q
        if v is None: v = q
        return self.ln(q + self.drop_res(self._apply_attention(q, k, v, r, g_u, g_v, mask=mask, **kwargs)))

    def mem_k(self, k):
        if self.mem_len == 0: return k
        if self.prev_k is None or (self.prev_k.shape[0] != k.shape[0]): # reset if wrong batch size
            self.prev_k = k[:, -self.mem_len:]
            return k
        with torch.no_grad():
            k_ext = torch.cat([self.prev_k, k], dim=1)
            self.prev_k = k_ext[:, -self.mem_len:]
        return k_ext.detach()

    def mem_v(self, v):
        if self.mem_len == 0: return v
        if self.prev_v is None or (self.prev_v.shape[0] != v.shape[0]): # reset if wrong batch size
            self.prev_v = v[:, -self.mem_len:]
            return v
        with torch.no_grad():
            v_ext = torch.cat([self.prev_v, v], dim=1)
            self.prev_v = v_ext[:, -self.mem_len:]
        return v_ext.detach()

    def reset(self):
        self.prev_v = None
        self.prev_k = None

    def _apply_attention(self, q:Tensor, k:Tensor, v:Tensor,
                         r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None,
                         mask:Tensor=None, **kwargs):
        #Notations from the paper: x input, r vector of relative distance between two elements, u et v learnable
        #parameters of the model common between all layers, mask to avoid cheating and mem the previous hidden states.
        # bs,x_len,seq_len = q.size(0),q.size(1),r.size(0)
        k = self.mem_k(k)
        v = self.mem_v(v)
        bs,x_len,seq_len = q.size(0),q.size(1),k.size(1)
        wq,wk,wv = self.q_wgt(q),self.k_wgt(k),self.v_wgt(v)
        wq = wq[:,-x_len:]
        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)
        wkr = self.r_attn(r[-seq_len:])
        wkr = wkr.view(seq_len, self.n_heads, self.d_head)
        wkr = wkr.permute(1,2,0)
        #### compute attention score (AC is (a) + (c) and BS is (b) + (d) in the paper)
        AC = torch.matmul(wq+g_u,wk)
        BD = _line_shift(torch.matmul(wq+g_v, wkr), mask=self.r_mask)
        if self.scale: attn_score = (AC + BD).mul_(1/(self.d_head ** 0.5))
        if mask is not None:
            mask = mask[...,-seq_len:]
            if hasattr(mask, 'bool'): mask = mask.bool()
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)
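get_multitask_model is normally called through multitask_model_learner, but it can also be built directly with the config helpers defined earlier in this commit. A sketch, not part of the commit:

from utils.musicautobot.vocab import MusicVocab
from utils.musicautobot.config import multitask_config
from utils.musicautobot.multitask_transformer.model import get_multitask_model

vocab = MusicVocab.create()
model = get_multitask_model(len(vocab), config=multitask_config(), drop_mult=1., pad_idx=vocab.pad_idx)
print(model.encoder.n_layers, model.decoder.n_layers)  # 8 8 with multitask_config()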
utils/musicautobot/multitask_transformer/transform.py
ADDED
@@ -0,0 +1,68 @@
from ..music_transformer.transform import *

class MultitrackItem():
    def __init__(self, melody:MusicItem, chords:MusicItem, stream=None):
        self.melody,self.chords = melody, chords
        self.vocab = melody.vocab
        self._stream = stream

    @classmethod
    def from_file(cls, midi_file, vocab):
        return cls.from_stream(file2stream(midi_file), vocab)

    @classmethod
    def from_stream(cls, stream, vocab):
        if not isinstance(stream, music21.stream.Score): stream = stream.voicesToParts()
        num_parts = len(stream.parts)
        sort_pitch = False
        if num_parts > 2:
            raise ValueError('Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks')
        elif num_parts == 1:
            print('Warning: only 1 track found. Inferring melody/chords')
            stream = separate_melody_chord(stream)
            sort_pitch = False

        mpart, cpart = stream2npenc_parts(stream, sort_pitch=sort_pitch)
        return cls.from_npenc_parts(mpart, cpart, vocab, stream)

    @classmethod
    def from_npenc_parts(cls, mpart, cpart, vocab, stream=None):
        mpart = npenc2idxenc(mpart, seq_type=SEQType.Melody, vocab=vocab, add_eos=False)
        cpart = npenc2idxenc(cpart, seq_type=SEQType.Chords, vocab=vocab, add_eos=False)
        return MultitrackItem(MusicItem(mpart, vocab), MusicItem(cpart, vocab), stream)

    @classmethod
    def from_idx(cls, item, vocab):
        m, c = item
        return MultitrackItem(MusicItem.from_idx(m, vocab), MusicItem.from_idx(c, vocab))
    def to_idx(self): return np.array((self.melody.to_idx(), self.chords.to_idx()))

    @property
    def stream(self):
        self._stream = self.to_stream() if self._stream is None else self._stream
        return self._stream

    def to_stream(self, bpm=120):
        ps = self.melody.to_npenc(), self.chords.to_npenc()
        ps = [npenc2chordarr(p) for p in ps]
        chordarr = chordarr_combine_parts(ps)
        return chordarr2stream(chordarr, bpm=bpm)


    def show(self, format:str=None):
        return self.stream.show(format)
    def play(self): self.stream.show('midi')

    def transpose(self, val):
        return MultitrackItem(self.melody.transpose(val), self.chords.transpose(val))
    def pad_to(self, val):
        return MultitrackItem(self.melody.pad_to(val), self.chords.pad_to(val))
    def trim_to_beat(self, beat):
        return MultitrackItem(self.melody.trim_to_beat(beat), self.chords.trim_to_beat(beat))

def combine2chordarr(np1, np2, vocab):
    if len(np1.shape) == 1: np1 = idxenc2npenc(np1, vocab)
    if len(np2.shape) == 1: np2 = idxenc2npenc(np2, vocab)
    p1 = npenc2chordarr(np1)
    p2 = npenc2chordarr(np2)
    return chordarr_combine_parts((p1, p2))
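MultitrackItem pairs a melody MusicItem with a chord MusicItem and round-trips them to the two-row index encoding used by the seq2seq tasks. A sketch, not part of the commit, assuming a two-track MIDI file duet.mid (hypothetical name):

from utils.musicautobot.vocab import MusicVocab
from utils.musicautobot.multitask_transformer.transform import MultitrackItem

vocab = MusicVocab.create()
mt = MultitrackItem.from_file('duet.mid', vocab)   # expects exactly 2 tracks (melody + chords)
melody_idx, chords_idx = mt.to_idx()               # one index sequence per part
mt.transpose(2).stream.write('midi', fp='duet_up2.mid')  # transpose both parts and write MIDI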
utils/musicautobot/music_transformer/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .dataloader import *
from .model import *
from .learner import *
utils/musicautobot/music_transformer/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (251 Bytes)
utils/musicautobot/music_transformer/__pycache__/dataloader.cpython-310.pyc
ADDED
Binary file (11.2 kB)
utils/musicautobot/music_transformer/__pycache__/learner.cpython-310.pyc
ADDED
Binary file (5.94 kB)
utils/musicautobot/music_transformer/__pycache__/model.cpython-310.pyc
ADDED
Binary file (3 kB)
utils/musicautobot/music_transformer/__pycache__/transform.cpython-310.pyc
ADDED
Binary file (10.7 kB)
utils/musicautobot/music_transformer/dataloader.py
ADDED
@@ -0,0 +1,229 @@
"Fastai Language Model Databunch modified to work with music"
from fastai.basics import *
# from fastai.basic_data import DataBunch
from fastai.text.data import LMLabelList
from .transform import *
from ..vocab import MusicVocab


class MusicDataBunch(DataBunch):
    "Create a `TextDataBunch` suitable for training a language model."
    @classmethod
    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', no_check:bool=False, bs=64, val_bs:int=None,
               num_workers:int=0, device:torch.device=None, collate_fn:Callable=data_collate,
               dl_tfms:Optional[Collection[Callable]]=None, bptt:int=70,
               preloader_cls=None, shuffle_dl=False, transpose_range=(0,12), **kwargs) -> DataBunch:
        "Create a `TextDataBunch` in `path` from the `datasets` for language modelling."
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        preloader_cls = MusicPreloader if preloader_cls is None else preloader_cls
        val_bs = ifnone(val_bs, bs)
        datasets = [preloader_cls(ds, shuffle=(i==0), bs=(bs if i==0 else val_bs), bptt=bptt, transpose_range=transpose_range, **kwargs)
                    for i,ds in enumerate(datasets)]
        val_bs = bs
        dl_tfms = [partially_apply_vocab(tfm, train_ds.vocab) for tfm in listify(dl_tfms)]
        dls = [DataLoader(d, b, shuffle=shuffle_dl) for d,b in zip(datasets, (bs,val_bs,val_bs,val_bs)) if d is not None]
        return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)

    @classmethod
    def from_folder(cls, path:PathOrStr, extensions='.npy', **kwargs):
        files = get_files(path, extensions=extensions, recurse=True);
        return cls.from_files(files, path, **kwargs)

    @classmethod
    def from_files(cls, files, path, processors=None, split_pct=0.1,
                   vocab=None, list_cls=None, **kwargs):
        if vocab is None: vocab = MusicVocab.create()
        if list_cls is None: list_cls = MusicItemList
        src = (list_cls(items=files, path=path, processor=processors, vocab=vocab)
                .split_by_rand_pct(split_pct, seed=6)
                .label_const(label_cls=LMLabelList))
        return src.databunch(**kwargs)

    @classmethod
    def empty(cls, path, **kwargs):
        vocab = MusicVocab.create()
        src = MusicItemList([], path=path, vocab=vocab, ignore_empty=True).split_none()
        return src.label_const(label_cls=LMLabelList).databunch()

def partially_apply_vocab(tfm, vocab):
    if 'vocab' in inspect.getfullargspec(tfm).args:
        return partial(tfm, vocab=vocab)
    return tfm

class MusicItemList(ItemList):
    _bunch = MusicDataBunch

    def __init__(self, items:Iterator, vocab:MusicVocab=None, **kwargs):
        super().__init__(items, **kwargs)
        self.vocab = vocab
        self.copy_new += ['vocab']

    def get(self, i):
        o = super().get(i)
        if is_pos_enc(o):
            return MusicItem.from_idx(o, self.vocab)
        return MusicItem(o, self.vocab)

def is_pos_enc(idxenc):
    if len(idxenc.shape) == 2 and idxenc.shape[0] == 2: return True
    return idxenc.dtype == np.object and idxenc.shape == (2,)

class MusicItemProcessor(PreProcessor):
    "`PreProcessor` that transforms numpy files to indexes for training"
    def process_one(self,item):
        item = MusicItem.from_npenc(item, vocab=self.vocab)
        return item.to_idx()

    def process(self, ds):
        self.vocab = ds.vocab
        super().process(ds)

class OpenNPFileProcessor(PreProcessor):
    "`PreProcessor` that opens the filenames and read the texts."
    def process_one(self,item):
        return np.load(item, allow_pickle=True) if isinstance(item, Path) else item

class Midi2ItemProcessor(PreProcessor):
    "Skips midi preprocessing step. And encodes midi files to MusicItems"
    def process_one(self,item):
        item = MusicItem.from_file(item, vocab=self.vocab)
        return item.to_idx()

    def process(self, ds):
        self.vocab = ds.vocab
        super().process(ds)

## For npenc dataset
class MusicPreloader(Callback):
    "Transforms the tokens in `dataset` to a stream of contiguous batches for language modelling."

    class CircularIndex():
        "Handles shuffle, direction of indexing, wraps around to head tail in the ragged array as needed"
        def __init__(self, length:int, forward:bool): self.idx, self.forward = np.arange(length), forward
        def __getitem__(self, i):
            return self.idx[ i%len(self.idx) if self.forward else len(self.idx)-1-i%len(self.idx)]
        def __len__(self) -> int: return len(self.idx)
        def shuffle(self): np.random.shuffle(self.idx)

    def __init__(self, dataset:LabelList, lengths:Collection[int]=None, bs:int=32, bptt:int=70, backwards:bool=False,
                 shuffle:bool=False, y_offset:int=1,
                 transpose_range=None, transpose_p=0.5,
                 encode_position=True,
                 **kwargs):
        self.dataset,self.bs,self.bptt,self.shuffle,self.backwards,self.lengths = dataset,bs,bptt,shuffle,backwards,lengths
        self.vocab = self.dataset.vocab
        self.bs *= num_distrib() or 1
        self.totalToks,self.ite_len,self.idx = int(0),None,None
        self.y_offset = y_offset

        self.transpose_range,self.transpose_p = transpose_range,transpose_p
        self.encode_position = encode_position
        self.bptt_len = self.bptt

        self.allocate_buffers() # needed for valid_dl on distributed training - otherwise doesn't get initialized on first epoch

    def __len__(self):
        if self.ite_len is None:
            if self.lengths is None: self.lengths = np.array([len(item) for item in self.dataset.x])
            self.totalToks = self.lengths.sum()
            self.ite_len = self.bs*int( math.ceil( self.totalToks/(self.bptt*self.bs) )) if self.item is None else 1
        return self.ite_len

    def __getattr__(self,k:str)->Any: return getattr(self.dataset, k)

    def allocate_buffers(self):
        "Create the ragged array that will be filled when we ask for items."
        if self.ite_len is None: len(self)
        self.idx = MusicPreloader.CircularIndex(len(self.dataset.x), not self.backwards)

        # batch shape = (bs, bptt, 2 - [index, pos]) if encode_position. Else - (bs, bptt)
        buffer_len = (2,) if self.encode_position else ()
        self.batch = np.zeros((self.bs, self.bptt+self.y_offset) + buffer_len, dtype=np.int64)
        self.batch_x, self.batch_y = self.batch[:,0:self.bptt], self.batch[:,self.y_offset:self.bptt+self.y_offset]
        #ro: index of the text we're at inside our datasets for the various batches
        self.ro = np.zeros(self.bs, dtype=np.int64)
        #ri: index of the token we're at inside our current text for the various batches
        self.ri = np.zeros(self.bs, dtype=np.int)

        # allocate random transpose values. Need to allocate this before hand.
        self.transpose_values = self.get_random_transpose_values()

    def get_random_transpose_values(self):
        if self.transpose_range is None: return None
        n = len(self.dataset)
        rt_arr = torch.randint(*self.transpose_range, (n,))-self.transpose_range[1]//2
        mask = torch.rand(rt_arr.shape) > self.transpose_p
        rt_arr[mask] = 0
        return rt_arr

    def on_epoch_begin(self, **kwargs):
        if self.idx is None: self.allocate_buffers()
        elif self.shuffle:
            self.ite_len = None
            self.idx.shuffle()
            self.transpose_values = self.get_random_transpose_values()
            self.bptt_len = self.bptt
        self.idx.forward = not self.backwards

        step = self.totalToks / self.bs
        ln_rag, countTokens, i_rag = 0, 0, -1
        for i in range(0,self.bs):
            #Compute the initial values for ro and ri
            while ln_rag + countTokens <= int(step * i):
                countTokens += ln_rag
                i_rag += 1
                ln_rag = self.lengths[self.idx[i_rag]]
            self.ro[i] = i_rag
            self.ri[i] = ( ln_rag - int(step * i - countTokens) ) if self.backwards else int(step * i - countTokens)

    #Training dl gets on_epoch_begin called, val_dl, on_epoch_end
    def on_epoch_end(self, **kwargs): self.on_epoch_begin()

    def __getitem__(self, k:int):
        j = k % self.bs
        if j==0:
            if self.item is not None: return self.dataset[0]
            if self.idx is None: self.on_epoch_begin()

        self.ro[j],self.ri[j] = self.fill_row(not self.backwards, self.dataset.x, self.idx, self.batch[j][:self.bptt_len+self.y_offset],
|
189 |
+
self.ro[j], self.ri[j], overlap=1, lengths=self.lengths)
|
190 |
+
return self.batch_x[j][:self.bptt_len], self.batch_y[j][:self.bptt_len]
|
191 |
+
|
192 |
+
def fill_row(self, forward, items, idx, row, ro, ri, overlap, lengths):
|
193 |
+
"Fill the row with tokens from the ragged array. --OBS-- overlap != 1 has not been implemented"
|
194 |
+
ibuf = n = 0
|
195 |
+
ro -= 1
|
196 |
+
while ibuf < row.shape[0]:
|
197 |
+
ro += 1
|
198 |
+
ix = idx[ro]
|
199 |
+
|
200 |
+
item = items[ix]
|
201 |
+
if self.transpose_values is not None:
|
202 |
+
item = item.transpose(self.transpose_values[ix].item())
|
203 |
+
|
204 |
+
if self.encode_position:
|
205 |
+
# Positions are colomn stacked with indexes. This makes it easier to keep in sync
|
206 |
+
rag = np.stack([item.data, item.position], axis=1)
|
207 |
+
else:
|
208 |
+
rag = item.data
|
209 |
+
|
210 |
+
if forward:
|
211 |
+
ri = 0 if ibuf else ri
|
212 |
+
n = min(lengths[ix] - ri, row.shape[0] - ibuf)
|
213 |
+
row[ibuf:ibuf+n] = rag[ri:ri+n]
|
214 |
+
else:
|
215 |
+
ri = lengths[ix] if ibuf else ri
|
216 |
+
n = min(ri, row.size - ibuf)
|
217 |
+
row[ibuf:ibuf+n] = rag[ri-n:ri][::-1]
|
218 |
+
ibuf += n
|
219 |
+
return ro, ri + ((n-overlap) if forward else -(n-overlap))
|
220 |
+
|
221 |
+
def batch_position_tfm(b):
|
222 |
+
"Batch transform for training with positional encoding"
|
223 |
+
x,y = b
|
224 |
+
x = {
|
225 |
+
'x': x[...,0],
|
226 |
+
'pos': x[...,1]
|
227 |
+
}
|
228 |
+
return x, y[...,0]
|
229 |
+
|
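A quick note on the batch layout built by MusicPreloader above: when encode_position is True, each batch stacks token indexes and beat positions along a trailing dimension, and batch_position_tfm unpacks that into the {'x', 'pos'} dict the model's forward expects. A minimal sketch with toy tensors (not tied to a real dataset; the import path simply mirrors this repo's layout):

import torch
from utils.musicautobot.music_transformer.dataloader import batch_position_tfm

bs, bptt = 2, 8
# Shape (bs, bptt, 2): the last dim holds [token index, beat position], mirroring MusicPreloader.batch
x = torch.randint(0, 100, (bs, bptt, 2))
y = torch.randint(0, 100, (bs, bptt, 2))

inp, target = batch_position_tfm((x, y))
print(inp['x'].shape, inp['pos'].shape, target.shape)   # each is torch.Size([2, 8])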
utils/musicautobot/music_transformer/learner.py
ADDED
@@ -0,0 +1,171 @@
from fastai.basics import *
from fastai.text.learner import LanguageLearner, get_language_model, _model_meta
from .model import *
from .transform import MusicItem
from ..numpy_encode import SAMPLE_FREQ
from ..utils.top_k_top_p import top_k_top_p
from ..utils.midifile import is_empty_midi

_model_meta[MusicTransformerXL] = _model_meta[TransformerXL] # copy over fastai's model metadata

def music_model_learner(data:DataBunch, arch=MusicTransformerXL, config:dict=None, drop_mult:float=1.,
                        pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner':
    "Create a `Learner` with a language model from `data` and `arch`."
    meta = _model_meta[arch]

    if pretrained_path:
        state = torch.load(pretrained_path, map_location='cpu')
        if config is None: config = state['config']

    model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
    learn = MusicLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)

    if pretrained_path:
        get_model(model).load_state_dict(state['model'], strict=False)
        if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
        try: learn.opt.load_state_dict(state['opt'])
        except: pass
        del state
        gc.collect()

    return learn

# Predictions
from fastai import basic_train # for predictions
class MusicLearner(LanguageLearner):
    def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
        "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)"
        out_path = super().save(file, return_path=True, with_opt=with_opt)
        if config and out_path:
            state = torch.load(out_path)
            state['config'] = config
            torch.save(state, out_path)
            del state
            gc.collect()
        return out_path

    def beam_search(self, xb:Tensor, n_words:int, top_k:int=10, beam_sz:int=10, temperature:float=1.):
        "Return the `n_words` that come after `text` using beam search."
        self.model.reset()
        self.model.eval()
        xb_length = xb.shape[-1]
        if xb.shape[0] > 1: xb = xb[0][None]
        yb = torch.ones_like(xb)

        nodes = None
        xb = xb.repeat(top_k, 1)
        nodes = xb.clone()
        scores = xb.new_zeros(1).float()
        with torch.no_grad():
            for k in progress_bar(range(n_words), leave=False):
                out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
                values, indices = out.topk(top_k, dim=-1)
                scores = (-values + scores[:,None]).view(-1)
                indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)
                sort_idx = scores.argsort()[:beam_sz]
                scores = scores[sort_idx]
                nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
                                   indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
                nodes = nodes.view(-1, nodes.size(2))[sort_idx]
                self.model[0].select_hidden(indices_idx[sort_idx])
                xb = nodes[:,-1][:,None]
        if temperature != 1.: scores.div_(temperature)
        node_idx = torch.multinomial(torch.exp(-scores), 1).item()
        return [i.item() for i in nodes[node_idx][xb_length:]]

    def predict(self, item:MusicItem, n_words:int=128,
                temperatures:float=(1.0,1.0), min_bars=4,
                top_k=30, top_p=0.6):
        "Return the `n_words` that come after `text`."
        self.model.reset()
        new_idx = []
        vocab = self.data.vocab
        x, pos = item.to_tensor(), item.get_pos_tensor()
        last_pos = pos[-1] if len(pos) else 0
        y = torch.tensor([0])

        start_pos = last_pos

        sep_count = 0
        bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time
        vocab = self.data.vocab

        repeat_count = 0
        if hasattr(self.model[0], 'encode_position'):
            encode_position = self.model[0].encode_position
        else: encode_position = False

        for i in progress_bar(range(n_words), leave=True):
            with torch.no_grad():
                if encode_position:
                    batch = { 'x': x[None], 'pos': pos[None] }
                    logits = self.model(batch)[0][-1][-1]
                else:
                    logits = self.model(x[None])[0][-1][-1]

            prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx

            # Temperature
            # Use first temperatures value if last prediction was duration
            temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
            repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
            temperature += repeat_penalty
            if temperature != 1.: logits = logits / temperature

            # Filter
            # bar = 16 beats
            filter_value = -float('Inf')
            if ((last_pos - start_pos) // 16) <= min_bars: logits[vocab.bos_idx] = filter_value

            logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
            logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)

            # Sample
            probs = F.softmax(logits, dim=-1)
            idx = torch.multinomial(probs, 1).item()

            # Update repeat count
            num_choices = len(probs.nonzero().view(-1))
            if num_choices <= 2: repeat_count += 1
            else: repeat_count = repeat_count // 2

            if prev_idx==vocab.sep_idx:
                duration = idx - vocab.dur_range[0]
                last_pos = last_pos + duration

                bars_pred = (last_pos - start_pos) // 16
                abs_bar = last_pos // 16
                # if (bars % 8 == 0) and (bars_pred > min_bars): break
                if (i / n_words > 0.80) and (abs_bar % 4 == 0): break

            if idx==vocab.bos_idx:
                print('Predicted BOS token. Returning prediction...')
                break

            new_idx.append(idx)
            x = x.new_tensor([idx])
            pos = pos.new_tensor([last_pos])

        pred = vocab.to_music_item(np.array(new_idx))
        full = item.append(pred)
        return pred, full

# High level prediction functions from midi file
def predict_from_midi(learn, midi=None, n_words=400,
                      temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
    vocab = learn.data.vocab
    seed = MusicItem.from_file(midi, vocab) if not is_empty_midi(midi) else MusicItem.empty(vocab)
    if seed_len is not None: seed = seed.trim_to_beat(seed_len)

    pred, full = learn.predict(seed, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
    return full

def filter_invalid_indexes(res, prev_idx, vocab, filter_value=-float('Inf')):
    if vocab.is_duration_or_pad(prev_idx):
        res[list(range(*vocab.dur_range))] = filter_value
    else:
        res[list(range(*vocab.note_range))] = filter_value
    return res
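For orientation, a minimal end-to-end sketch of how the learner above gets used for generation. The checkpoint name matches the music_transformer.pth file added in this commit; the seed midi path is illustrative, and it assumes the package __init__ re-exports these names:

from utils.musicautobot.music_transformer import *

data = MusicDataBunch.empty('.')                 # empty databunch just to supply the vocab
learn = music_model_learner(data, pretrained_path='music_transformer.pth')

# Continue a seed midi file for ~400 tokens and write the result back out
full = predict_from_midi(learn, midi='seed.mid', n_words=400, temperatures=(1.1, 0.4))
full.stream.write('midi', fp='generated.mid')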
utils/musicautobot/music_transformer/model.py
ADDED
@@ -0,0 +1,66 @@
from fastai.basics import *
from fastai.text.models.transformer import TransformerXL
from ..utils.attention_mask import rand_window_mask

class MusicTransformerXL(TransformerXL):
    "Exactly like fastai's TransformerXL, but with more aggressive attention mask: see `rand_window_mask`"
    def __init__(self, *args, encode_position=True, mask_steps=1, **kwargs):
        import inspect
        sig = inspect.signature(TransformerXL)
        arg_params = { k:kwargs[k] for k in sig.parameters if k in kwargs }
        super().__init__(*args, **arg_params)

        self.encode_position = encode_position
        if self.encode_position: self.beat_enc = BeatPositionEncoder(kwargs['d_model'])

        self.mask_steps = mask_steps

    def forward(self, x):
        # The hidden state has to be initialized in the forward pass for nn.DataParallel
        if self.mem_len > 0 and not self.init:
            self.reset()
            self.init = True

        benc = 0
        if self.encode_position:
            x, pos = x['x'], x['pos']
            benc = self.beat_enc(pos)

        bs, x_len = x.size()
        inp = self.drop_emb(self.encoder(x) + benc) #.mul_(self.d_model ** 0.5)
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len

        mask = rand_window_mask(x_len, m_len, inp.device, max_size=self.mask_steps, is_eval=not self.training) if self.mask else None
        if m_len == 0: mask[...,0,0] = 0
        #[None,:,:None] for einsum implementation of attention
        hids = []
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
        pos_enc = self.pos_enc(pos)
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i] if self.mem_len > 0 else None
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        if self.mem_len > 0: self._update_mems(hids)
        return (self.hidden if self.mem_len > 0 else [core_out]), [core_out]


# Beat encoder
class BeatPositionEncoder(nn.Module):
    "Embedding + positional encoding + dropout"
    def __init__(self, emb_sz:int, beat_len=32, max_bar_len=1024):
        super().__init__()

        self.beat_len, self.max_bar_len = beat_len, max_bar_len
        self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0)
        self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0)

    def forward(self, pos):
        beat_enc = self.beat_enc(pos % self.beat_len)
        bar_pos = pos // self.beat_len % self.max_bar_len
        bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1
        bar_enc = self.bar_enc(bar_pos)
        return beat_enc + bar_enc
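To make the positional scheme above concrete: BeatPositionEncoder splits an absolute step position into a beat-within-bar index and a bar index, embeds each, and sums them. A small standalone sketch of the shapes involved:

import torch
from utils.musicautobot.music_transformer.model import BeatPositionEncoder

enc = BeatPositionEncoder(emb_sz=16, beat_len=32, max_bar_len=1024)
pos = torch.arange(0, 96).view(1, -1)   # absolute step positions for one sequence (bs=1, seq_len=96)
out = enc(pos)
print(out.shape)                        # torch.Size([1, 96, 16]): beat embedding + bar embedding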
utils/musicautobot/music_transformer/transform.py
ADDED
@@ -0,0 +1,235 @@
1 |
+
from ..numpy_encode import *
|
2 |
+
import numpy as np
|
3 |
+
from enum import Enum
|
4 |
+
import torch
|
5 |
+
from ..vocab import *
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
SEQType = Enum('SEQType', 'Mask, Sentence, Melody, Chords, Empty')
|
9 |
+
|
10 |
+
class MusicItem():
|
11 |
+
def __init__(self, data, vocab, stream=None, position=None):
|
12 |
+
self.data = data
|
13 |
+
self.vocab = vocab
|
14 |
+
self._stream = stream
|
15 |
+
self._position = position
|
16 |
+
def __repr__(self): return '\n'.join([
|
17 |
+
f'\n{self.__class__.__name__} - {self.data.shape}',
|
18 |
+
f'{self.vocab.textify(self.data[:10])}...'])
|
19 |
+
def __len__(self): return len(self.data)
|
20 |
+
|
21 |
+
@classmethod
|
22 |
+
def from_file(cls, midi_file, vocab):
|
23 |
+
return cls.from_stream(file2stream(midi_file), vocab)
|
24 |
+
@classmethod
|
25 |
+
def from_stream(cls, stream, vocab):
|
26 |
+
if not isinstance(stream, music21.stream.Score): stream = stream.voicesToParts()
|
27 |
+
chordarr = stream2chordarr(stream) # 2.
|
28 |
+
npenc = chordarr2npenc(chordarr) # 3.
|
29 |
+
return cls.from_npenc(npenc, vocab, stream)
|
30 |
+
@classmethod
|
31 |
+
def from_npenc(cls, npenc, vocab, stream=None): return MusicItem(npenc2idxenc(npenc, vocab), vocab, stream)
|
32 |
+
|
33 |
+
@classmethod
|
34 |
+
def from_idx(cls, item, vocab):
|
35 |
+
idx,pos = item
|
36 |
+
return MusicItem(idx, vocab=vocab, position=pos)
|
37 |
+
def to_idx(self): return self.data, self.position
|
38 |
+
|
39 |
+
@classmethod
|
40 |
+
def empty(cls, vocab, seq_type=SEQType.Sentence):
|
41 |
+
return MusicItem(seq_prefix(seq_type, vocab), vocab)
|
42 |
+
|
43 |
+
@property
|
44 |
+
def stream(self):
|
45 |
+
self._stream = self.to_stream() if self._stream is None else self._stream
|
46 |
+
return self._stream
|
47 |
+
|
48 |
+
def to_stream(self, bpm=120):
|
49 |
+
return idxenc2stream(self.data, self.vocab, bpm=bpm)
|
50 |
+
|
51 |
+
def to_tensor(self, device=None):
|
52 |
+
return to_tensor(self.data, device)
|
53 |
+
|
54 |
+
def to_text(self, sep=' '): return self.vocab.textify(self.data, sep)
|
55 |
+
|
56 |
+
@property
|
57 |
+
def position(self):
|
58 |
+
self._position = position_enc(self.data, self.vocab) if self._position is None else self._position
|
59 |
+
return self._position
|
60 |
+
|
61 |
+
def get_pos_tensor(self, device=None): return to_tensor(self.position, device)
|
62 |
+
|
63 |
+
def to_npenc(self):
|
64 |
+
return idxenc2npenc(self.data, self.vocab)
|
65 |
+
|
66 |
+
def show(self, format:str=None):
|
67 |
+
return self.stream.show(format)
|
68 |
+
def play(self): self.stream.show('midi')
|
69 |
+
|
70 |
+
#Added by caslabs
|
71 |
+
def download(self, filename:str=None, ext:str=None):
|
72 |
+
return self.stream.write('midi', fp=filename)
|
73 |
+
|
74 |
+
@property
|
75 |
+
def new(self):
|
76 |
+
return partial(type(self), vocab=self.vocab)
|
77 |
+
|
78 |
+
def trim_to_beat(self, beat, include_last_sep=False):
|
79 |
+
return self.new(trim_to_beat(self.data, self.position, self.vocab, beat, include_last_sep))
|
80 |
+
|
81 |
+
def transpose(self, interval):
|
82 |
+
return self.new(tfm_transpose(self.data, interval, self.vocab), position=self._position)
|
83 |
+
|
84 |
+
def append(self, item):
|
85 |
+
return self.new(np.concatenate((self.data, item.data), axis=0))
|
86 |
+
|
87 |
+
def mask_pitch(self, section=None):
|
88 |
+
return self.new(self.mask(self.vocab.note_range, section), position=self.position)
|
89 |
+
|
90 |
+
def mask_duration(self, section=None, keep_position_enc=True):
|
91 |
+
masked_data = self.mask(self.vocab.dur_range, section)
|
92 |
+
if keep_position_enc: return self.new(masked_data, position=self.position)
|
93 |
+
return self.new(masked_data)
|
94 |
+
|
95 |
+
def mask(self, token_range, section_range=None):
|
96 |
+
return mask_section(self.data, self.position, token_range, self.vocab.mask_idx, section_range=section_range)
|
97 |
+
|
98 |
+
def pad_to(self, bptt):
|
99 |
+
data = pad_seq(self.data, bptt, self.vocab.pad_idx)
|
100 |
+
pos = pad_seq(self.position, bptt, 0)
|
101 |
+
return self.new(data, stream=self._stream, position=pos)
|
102 |
+
|
103 |
+
def split_stream_parts(self):
|
104 |
+
self._stream = separate_melody_chord(self.stream)
|
105 |
+
return self.stream
|
106 |
+
|
107 |
+
def remove_eos(self):
|
108 |
+
if self.data[-1] == self.vocab.stoi[EOS]: return self.new(self.data, stream=self.stream)
|
109 |
+
return self
|
110 |
+
|
111 |
+
def split_parts(self):
|
112 |
+
return self.new(self.data, stream=separate_melody_chord(self.stream), position=self.position)
|
113 |
+
|
114 |
+
def pad_seq(seq, bptt, value):
|
115 |
+
pad_len = max(bptt-seq.shape[0], 0)
|
116 |
+
return np.pad(seq, (0, pad_len), 'constant', constant_values=value)[:bptt]
|
117 |
+
|
118 |
+
def to_tensor(t, device=None):
|
119 |
+
t = t if isinstance(t, torch.Tensor) else torch.tensor(t)
|
120 |
+
if device is None and torch.cuda.is_available(): t = t.cuda()
|
121 |
+
else: t.to(device)
|
122 |
+
return t.long()
|
123 |
+
|
124 |
+
def midi2idxenc(midi_file, vocab):
|
125 |
+
"Converts midi file to index encoding for training"
|
126 |
+
npenc = midi2npenc(midi_file) # 3.
|
127 |
+
return npenc2idxenc(npenc, vocab)
|
128 |
+
|
129 |
+
def idxenc2stream(arr, vocab, bpm=120):
|
130 |
+
"Converts index encoding to music21 stream"
|
131 |
+
npenc = idxenc2npenc(arr, vocab)
|
132 |
+
return npenc2stream(npenc, bpm=bpm)
|
133 |
+
|
134 |
+
# single stream instead of note,dur
|
135 |
+
def npenc2idxenc(t, vocab, seq_type=SEQType.Sentence, add_eos=False):
|
136 |
+
"Transforms numpy array from 2 column (note, duration) matrix to a single column"
|
137 |
+
"[[n1, d1], [n2, d2], ...] -> [n1, d1, n2, d2]"
|
138 |
+
if isinstance(t, (list, tuple)) and len(t) == 2:
|
139 |
+
return [npenc2idxenc(x, vocab, start_seq) for x in t]
|
140 |
+
t = t.copy()
|
141 |
+
|
142 |
+
t[:, 0] = t[:, 0] + vocab.note_range[0]
|
143 |
+
t[:, 1] = t[:, 1] + vocab.dur_range[0]
|
144 |
+
|
145 |
+
prefix = seq_prefix(seq_type, vocab)
|
146 |
+
suffix = np.array([vocab.stoi[EOS]]) if add_eos else np.empty(0, dtype=int)
|
147 |
+
return np.concatenate([prefix, t.reshape(-1), suffix])
|
148 |
+
|
149 |
+
def seq_prefix(seq_type, vocab):
|
150 |
+
if seq_type == SEQType.Empty: return np.empty(0, dtype=int)
|
151 |
+
start_token = vocab.bos_idx
|
152 |
+
if seq_type == SEQType.Chords: start_token = vocab.stoi[CSEQ]
|
153 |
+
if seq_type == SEQType.Melody: start_token = vocab.stoi[MSEQ]
|
154 |
+
return np.array([start_token, vocab.pad_idx])
|
155 |
+
|
156 |
+
def idxenc2npenc(t, vocab, validate=True):
|
157 |
+
if validate: t = to_valid_idxenc(t, vocab.npenc_range)
|
158 |
+
t = t.copy().reshape(-1, 2)
|
159 |
+
if t.shape[0] == 0: return t
|
160 |
+
|
161 |
+
t[:, 0] = t[:, 0] - vocab.note_range[0]
|
162 |
+
t[:, 1] = t[:, 1] - vocab.dur_range[0]
|
163 |
+
|
164 |
+
if validate: return to_valid_npenc(t)
|
165 |
+
return t
|
166 |
+
|
167 |
+
def to_valid_idxenc(t, valid_range):
|
168 |
+
r = valid_range
|
169 |
+
t = t[np.where((t >= r[0]) & (t < r[1]))]
|
170 |
+
if t.shape[-1] % 2 == 1: t = t[..., :-1]
|
171 |
+
return t
|
172 |
+
|
173 |
+
def to_valid_npenc(t):
|
174 |
+
is_note = (t[:, 0] < VALTSEP) | (t[:, 0] >= NOTE_SIZE)
|
175 |
+
invalid_note_idx = is_note.argmax()
|
176 |
+
invalid_dur_idx = (t[:, 1] < 0).argmax()
|
177 |
+
|
178 |
+
invalid_idx = max(invalid_dur_idx, invalid_note_idx)
|
179 |
+
if invalid_idx > 0:
|
180 |
+
if invalid_note_idx > 0 and invalid_dur_idx > 0: invalid_idx = min(invalid_dur_idx, invalid_note_idx)
|
181 |
+
print('Non midi note detected. Only returning valid portion. Index, seed', invalid_idx, t.shape)
|
182 |
+
return t[:invalid_idx]
|
183 |
+
return t
|
184 |
+
|
185 |
+
def position_enc(idxenc, vocab):
|
186 |
+
"Calculates positional beat encoding."
|
187 |
+
sep_idxs = (idxenc == vocab.sep_idx).nonzero()[0]
|
188 |
+
sep_idxs = sep_idxs[sep_idxs+2 < idxenc.shape[0]] # remove any indexes right before out of bounds (sep_idx+2)
|
189 |
+
dur_vals = idxenc[sep_idxs+1]
|
190 |
+
dur_vals[dur_vals == vocab.mask_idx] = vocab.dur_range[0] # make sure masked durations are 0
|
191 |
+
dur_vals -= vocab.dur_range[0]
|
192 |
+
|
193 |
+
posenc = np.zeros_like(idxenc)
|
194 |
+
posenc[sep_idxs+2] = dur_vals
|
195 |
+
return posenc.cumsum()
|
196 |
+
|
197 |
+
def beat2index(idxenc, pos, vocab, beat, include_last_sep=False):
|
198 |
+
cutoff = find_beat(pos, beat)
|
199 |
+
if cutoff < 2: return 2 # always leave starter tokens
|
200 |
+
if len(idxenc) < 2 or include_last_sep: return cutoff
|
201 |
+
if idxenc[cutoff - 2] == vocab.sep_idx: return cutoff - 2
|
202 |
+
return cutoff
|
203 |
+
|
204 |
+
def find_beat(pos, beat, sample_freq=SAMPLE_FREQ, side='left'):
|
205 |
+
return np.searchsorted(pos, beat * sample_freq, side=side)
|
206 |
+
|
207 |
+
# TRANSFORMS
|
208 |
+
|
209 |
+
def tfm_transpose(x, value, vocab):
|
210 |
+
x = x.copy()
|
211 |
+
x[(x >= vocab.note_range[0]) & (x < vocab.note_range[1])] += value
|
212 |
+
return x
|
213 |
+
|
214 |
+
def trim_to_beat(idxenc, pos, vocab, to_beat=None, include_last_sep=True):
|
215 |
+
if to_beat is None: return idxenc
|
216 |
+
cutoff = beat2index(idxenc, pos, vocab, to_beat, include_last_sep=include_last_sep)
|
217 |
+
return idxenc[:cutoff]
|
218 |
+
|
219 |
+
def mask_input(xb, mask_range, replacement_idx):
|
220 |
+
xb = xb.copy()
|
221 |
+
xb[(xb >= mask_range[0]) & (xb < mask_range[1])] = replacement_idx
|
222 |
+
return xb
|
223 |
+
|
224 |
+
def mask_section(xb, pos, token_range, replacement_idx, section_range=None):
|
225 |
+
xb = xb.copy()
|
226 |
+
token_mask = (xb >= token_range[0]) & (xb < token_range[1])
|
227 |
+
|
228 |
+
if section_range is None: section_range = (None, None)
|
229 |
+
section_mask = np.zeros_like(xb, dtype=bool)
|
230 |
+
start_idx = find_beat(pos, section_range[0]) if section_range[0] is not None else 0
|
231 |
+
end_idx = find_beat(pos, section_range[1]) if section_range[1] is not None else xb.shape[0]
|
232 |
+
section_mask[start_idx:end_idx] = True
|
233 |
+
|
234 |
+
xb[token_mask & section_mask] = replacement_idx
|
235 |
+
return xb
|
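A short sketch of how the MusicItem transforms above chain together when preparing a seed for prediction (the midi path is illustrative):

from utils.musicautobot.music_transformer.transform import MusicItem
from utils.musicautobot.vocab import MusicVocab

vocab = MusicVocab.create()
item = MusicItem.from_file('example.mid', vocab)   # midi -> single-column index encoding

seed = item.trim_to_beat(8)                        # keep the first 8 beats as a prompt
seed = seed.transpose(2)                           # shift all pitch tokens up 2 semitones
x, pos = seed.to_tensor(), seed.get_pos_tensor()
print(len(seed), x.shape, pos.shape)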
utils/musicautobot/numpy_encode.py
ADDED
@@ -0,0 +1,302 @@
1 |
+
"Encoding music21 streams -> numpy array -> text"
|
2 |
+
|
3 |
+
# import re
|
4 |
+
import music21
|
5 |
+
import numpy as np
|
6 |
+
# from pathlib import Path
|
7 |
+
|
8 |
+
BPB = 4 # beats per bar
|
9 |
+
TIMESIG = f'{BPB}/4' # default time signature
|
10 |
+
PIANO_RANGE = (21, 108)
|
11 |
+
VALTSEP = -1 # separator value for numpy encoding
|
12 |
+
VALTCONT = -2 # numpy value for TCONT - needed for compressing chord array
|
13 |
+
|
14 |
+
SAMPLE_FREQ = 4
|
15 |
+
NOTE_SIZE = 128
|
16 |
+
DUR_SIZE = (10*BPB*SAMPLE_FREQ)+1 # Max length - 8 bars. Or 16 beats/quarternotes
|
17 |
+
MAX_NOTE_DUR = (8*BPB*SAMPLE_FREQ)
|
18 |
+
|
19 |
+
# Encoding process
|
20 |
+
# 1. midi -> music21.Stream
|
21 |
+
# 2. Stream -> numpy chord array (timestep X instrument X noterange)
|
22 |
+
# 3. numpy array -> List[Timestep][NoteEnc]
|
23 |
+
def midi2npenc(midi_file, skip_last_rest=True):
|
24 |
+
"Converts midi file to numpy encoding for language model"
|
25 |
+
stream = file2stream(midi_file) # 1.
|
26 |
+
chordarr = stream2chordarr(stream) # 2.
|
27 |
+
return chordarr2npenc(chordarr, skip_last_rest=skip_last_rest) # 3.
|
28 |
+
|
29 |
+
# Decoding process
|
30 |
+
# 1. NoteEnc -> numpy chord array
|
31 |
+
# 2. numpy array -> music21.Stream
|
32 |
+
def npenc2stream(arr, bpm=120):
|
33 |
+
"Converts numpy encoding to music21 stream"
|
34 |
+
chordarr = npenc2chordarr(np.array(arr)) # 1.
|
35 |
+
return chordarr2stream(chordarr, bpm=bpm) # 2.
|
36 |
+
|
37 |
+
##### ENCODING ######
|
38 |
+
|
39 |
+
# 1. File To STream
|
40 |
+
|
41 |
+
def file2stream(fp):
|
42 |
+
if isinstance(fp, music21.midi.MidiFile): return music21.midi.translate.midiFileToStream(fp)
|
43 |
+
return music21.converter.parse(fp)
|
44 |
+
|
45 |
+
# 2.
|
46 |
+
def stream2chordarr(s, note_size=NOTE_SIZE, sample_freq=SAMPLE_FREQ, max_note_dur=MAX_NOTE_DUR):
|
47 |
+
"Converts music21.Stream to 1-hot numpy array"
|
48 |
+
# assuming 4/4 time
|
49 |
+
# note x instrument x pitch
|
50 |
+
# FYI: midi middle C value=60
|
51 |
+
|
52 |
+
# (AS) TODO: need to order by instruments most played and filter out percussion or include the channel
|
53 |
+
highest_time = max(s.flat.getElementsByClass('Note').highestTime, s.flat.getElementsByClass('Chord').highestTime)
|
54 |
+
maxTimeStep = round(highest_time * sample_freq)+1
|
55 |
+
score_arr = np.zeros((maxTimeStep, len(s.parts), NOTE_SIZE))
|
56 |
+
|
57 |
+
def note_data(pitch, note):
|
58 |
+
return (pitch.midi, int(round(note.offset*sample_freq)), int(round(note.duration.quarterLength*sample_freq)))
|
59 |
+
|
60 |
+
for idx,part in enumerate(s.parts):
|
61 |
+
notes=[]
|
62 |
+
for elem in part.flat:
|
63 |
+
if isinstance(elem, music21.note.Note):
|
64 |
+
notes.append(note_data(elem.pitch, elem))
|
65 |
+
if isinstance(elem, music21.chord.Chord):
|
66 |
+
for p in elem.pitches:
|
67 |
+
notes.append(note_data(p, elem))
|
68 |
+
|
69 |
+
# sort notes by offset (1), duration (2) so that hits are not overwritten and longer notes have priority
|
70 |
+
notes_sorted = sorted(notes, key=lambda x: (x[1], x[2]))
|
71 |
+
for n in notes_sorted:
|
72 |
+
if n is None: continue
|
73 |
+
pitch,offset,duration = n
|
74 |
+
if max_note_dur is not None and duration > max_note_dur: duration = max_note_dur
|
75 |
+
score_arr[offset, idx, pitch] = duration
|
76 |
+
score_arr[offset+1:offset+duration, idx, pitch] = VALTCONT # Continue holding note
|
77 |
+
return score_arr
|
78 |
+
|
79 |
+
def chordarr2npenc(chordarr, skip_last_rest=True):
|
80 |
+
# combine instruments
|
81 |
+
result = []
|
82 |
+
wait_count = 0
|
83 |
+
for idx,timestep in enumerate(chordarr):
|
84 |
+
flat_time = timestep2npenc(timestep)
|
85 |
+
if len(flat_time) == 0:
|
86 |
+
wait_count += 1
|
87 |
+
else:
|
88 |
+
# pitch, octave, duration, instrument
|
89 |
+
if wait_count > 0: result.append([VALTSEP, wait_count])
|
90 |
+
result.extend(flat_time)
|
91 |
+
wait_count = 1
|
92 |
+
if wait_count > 0 and not skip_last_rest: result.append([VALTSEP, wait_count])
|
93 |
+
return np.array(result, dtype=int).reshape(-1, 2) # reshaping. Just in case result is empty
|
94 |
+
|
95 |
+
# Note: not worrying about overlaps - as notes will still play. just look tied
|
96 |
+
# http://web.mit.edu/music21/doc/moduleReference/moduleStream.html#music21.stream.Stream.getOverlaps
|
97 |
+
def timestep2npenc(timestep, note_range=PIANO_RANGE, enc_type=None):
|
98 |
+
# inst x pitch
|
99 |
+
notes = []
|
100 |
+
for i,n in zip(*timestep.nonzero()):
|
101 |
+
d = timestep[i,n]
|
102 |
+
if d < 0: continue # only supporting short duration encoding for now
|
103 |
+
if n < note_range[0] or n >= note_range[1]: continue # must be within midi range
|
104 |
+
notes.append([n,d,i])
|
105 |
+
|
106 |
+
notes = sorted(notes, key=lambda x: x[0], reverse=True) # sort by note (highest to lowest)
|
107 |
+
|
108 |
+
if enc_type is None:
|
109 |
+
# note, duration
|
110 |
+
return [n[:2] for n in notes]
|
111 |
+
if enc_type == 'parts':
|
112 |
+
# note, duration, part
|
113 |
+
return [n for n in notes]
|
114 |
+
if enc_type == 'full':
|
115 |
+
# note_class, duration, octave, instrument
|
116 |
+
return [[n%12, d, n//12, i] for n,d,i in notes]
|
117 |
+
|
118 |
+
##### DECODING #####
|
119 |
+
|
120 |
+
# 1.
|
121 |
+
def npenc2chordarr(npenc, note_size=NOTE_SIZE):
|
122 |
+
num_instruments = 1 if len(npenc.shape) <= 2 else npenc.max(axis=0)[-1]
|
123 |
+
|
124 |
+
max_len = npenc_len(npenc)
|
125 |
+
# score_arr = (steps, inst, note)
|
126 |
+
score_arr = np.zeros((max_len, num_instruments, note_size))
|
127 |
+
|
128 |
+
idx = 0
|
129 |
+
for step in npenc:
|
130 |
+
n,d,i = (step.tolist()+[0])[:3] # or n,d,i
|
131 |
+
if n < VALTSEP: continue # special token
|
132 |
+
if n == VALTSEP:
|
133 |
+
idx += d
|
134 |
+
continue
|
135 |
+
score_arr[idx,i,n] = d
|
136 |
+
return score_arr
|
137 |
+
|
138 |
+
def npenc_len(npenc):
|
139 |
+
duration = 0
|
140 |
+
for t in npenc:
|
141 |
+
if t[0] == VALTSEP: duration += t[1]
|
142 |
+
return duration + 1
|
143 |
+
|
144 |
+
|
145 |
+
# 2.
|
146 |
+
def chordarr2stream(arr, sample_freq=SAMPLE_FREQ, bpm=120):
|
147 |
+
duration = music21.duration.Duration(1. / sample_freq)
|
148 |
+
stream = music21.stream.Score()
|
149 |
+
stream.append(music21.meter.TimeSignature(TIMESIG))
|
150 |
+
stream.append(music21.tempo.MetronomeMark(number=bpm))
|
151 |
+
stream.append(music21.key.KeySignature(0))
|
152 |
+
for inst in range(arr.shape[1]):
|
153 |
+
p = partarr2stream(arr[:,inst,:], duration)
|
154 |
+
stream.append(p)
|
155 |
+
stream = stream.transpose(0)
|
156 |
+
return stream
|
157 |
+
|
158 |
+
# 2b.
|
159 |
+
def partarr2stream(partarr, duration):
|
160 |
+
"convert instrument part to music21 chords"
|
161 |
+
part = music21.stream.Part()
|
162 |
+
part.append(music21.instrument.Piano())
|
163 |
+
part_append_duration_notes(partarr, duration, part) # notes already have duration calculated
|
164 |
+
|
165 |
+
return part
|
166 |
+
|
167 |
+
def part_append_duration_notes(partarr, duration, stream):
|
168 |
+
"convert instrument part to music21 chords"
|
169 |
+
for tidx,t in enumerate(partarr):
|
170 |
+
note_idxs = np.where(t > 0)[0] # filter out any negative values (continuous mode)
|
171 |
+
if len(note_idxs) == 0: continue
|
172 |
+
notes = []
|
173 |
+
for nidx in note_idxs:
|
174 |
+
note = music21.note.Note(nidx)
|
175 |
+
note.duration = music21.duration.Duration(partarr[tidx,nidx]*duration.quarterLength)
|
176 |
+
notes.append(note)
|
177 |
+
for g in group_notes_by_duration(notes):
|
178 |
+
if len(g) == 1:
|
179 |
+
stream.insert(tidx*duration.quarterLength, g[0])
|
180 |
+
else:
|
181 |
+
chord = music21.chord.Chord(g)
|
182 |
+
stream.insert(tidx*duration.quarterLength, chord)
|
183 |
+
return stream
|
184 |
+
|
185 |
+
from itertools import groupby
|
186 |
+
# combining notes with different durations into a single chord may overwrite conflicting durations. Example: aylictal/still-waters-run-deep
|
187 |
+
def group_notes_by_duration(notes):
|
188 |
+
"separate notes into chord groups"
|
189 |
+
keyfunc = lambda n: n.duration.quarterLength
|
190 |
+
notes = sorted(notes, key=keyfunc)
|
191 |
+
return [list(g) for k,g in groupby(notes, keyfunc)]
|
192 |
+
|
193 |
+
|
194 |
+
# Midi -> npenc Conversion helpers
|
195 |
+
def is_valid_npenc(npenc, note_range=PIANO_RANGE, max_dur=DUR_SIZE,
|
196 |
+
min_notes=32, input_path=None, verbose=True):
|
197 |
+
if len(npenc) < min_notes:
|
198 |
+
if verbose: print('Sequence too short:', len(npenc), input_path)
|
199 |
+
return False
|
200 |
+
if (npenc[:,1] >= max_dur).any():
|
201 |
+
if verbose: print(f'npenc exceeds max {max_dur} duration:', npenc[:,1].max(), input_path)
|
202 |
+
return False
|
203 |
+
# https://en.wikipedia.org/wiki/Scientific_pitch_notation - 88 key range - 21 = A0, 108 = C8
|
204 |
+
if ((npenc[...,0] > VALTSEP) & ((npenc[...,0] < note_range[0]) | (npenc[...,0] >= note_range[1]))).any():
|
205 |
+
print(f'npenc out of piano note range {note_range}:', input_path)
|
206 |
+
return False
|
207 |
+
return True
|
208 |
+
|
209 |
+
# seperates overlapping notes to different tracks
|
210 |
+
def remove_overlaps(stream, separate_chords=True):
|
211 |
+
if not separate_chords:
|
212 |
+
return stream.flat.makeVoices().voicesToParts()
|
213 |
+
return separate_melody_chord(stream)
|
214 |
+
|
215 |
+
# seperates notes and chords to different tracks
|
216 |
+
def separate_melody_chord(stream):
|
217 |
+
new_stream = music21.stream.Score()
|
218 |
+
if stream.timeSignature: new_stream.append(stream.timeSignature)
|
219 |
+
new_stream.append(stream.metronomeMarkBoundaries()[0][-1])
|
220 |
+
if stream.keySignature: new_stream.append(stream.keySignature)
|
221 |
+
|
222 |
+
melody_part = music21.stream.Part(stream.flat.getElementsByClass('Note'))
|
223 |
+
melody_part.insert(0, stream.getInstrument())
|
224 |
+
chord_part = music21.stream.Part(stream.flat.getElementsByClass('Chord'))
|
225 |
+
chord_part.insert(0, stream.getInstrument())
|
226 |
+
new_stream.append(melody_part)
|
227 |
+
new_stream.append(chord_part)
|
228 |
+
return new_stream
|
229 |
+
|
230 |
+
# processing functions for sanitizing data
|
231 |
+
|
232 |
+
def compress_chordarr(chordarr):
|
233 |
+
return shorten_chordarr_rests(trim_chordarr_rests(chordarr))
|
234 |
+
|
235 |
+
def trim_chordarr_rests(arr, max_rests=4, sample_freq=SAMPLE_FREQ):
|
236 |
+
# max rests is in quarter notes
|
237 |
+
# max 1 bar between song start and end
|
238 |
+
start_idx = 0
|
239 |
+
max_sample = max_rests*sample_freq
|
240 |
+
for idx,t in enumerate(arr):
|
241 |
+
if (t != 0).any(): break
|
242 |
+
start_idx = idx+1
|
243 |
+
|
244 |
+
end_idx = 0
|
245 |
+
for idx,t in enumerate(reversed(arr)):
|
246 |
+
if (t != 0).any(): break
|
247 |
+
end_idx = idx+1
|
248 |
+
start_idx = start_idx - start_idx % max_sample
|
249 |
+
end_idx = end_idx - end_idx % max_sample
|
250 |
+
# if start_idx > 0 or end_idx > 0: print('Trimming rests. Start, end:', start_idx, len(arr)-end_idx, end_idx)
|
251 |
+
return arr[start_idx:(len(arr)-end_idx)]
|
252 |
+
|
253 |
+
def shorten_chordarr_rests(arr, max_rests=8, sample_freq=SAMPLE_FREQ):
|
254 |
+
# max rests is in quarter notes
|
255 |
+
# max 2 bar pause
|
256 |
+
rest_count = 0
|
257 |
+
result = []
|
258 |
+
max_sample = max_rests*sample_freq
|
259 |
+
for timestep in arr:
|
260 |
+
if (timestep==0).all():
|
261 |
+
rest_count += 1
|
262 |
+
else:
|
263 |
+
if rest_count > max_sample:
|
264 |
+
# old_count = rest_count
|
265 |
+
rest_count = (rest_count % sample_freq) + max_sample
|
266 |
+
# print(f'Compressing rests: {old_count} -> {rest_count}')
|
267 |
+
for i in range(rest_count): result.append(np.zeros(timestep.shape))
|
268 |
+
rest_count = 0
|
269 |
+
result.append(timestep)
|
270 |
+
for i in range(rest_count): result.append(np.zeros(timestep.shape))
|
271 |
+
return np.array(result)
|
272 |
+
|
273 |
+
# sequence 2 sequence convenience functions
|
274 |
+
|
275 |
+
def stream2npenc_parts(stream, sort_pitch=True):
|
276 |
+
chordarr = stream2chordarr(stream)
|
277 |
+
_,num_parts,_ = chordarr.shape
|
278 |
+
parts = [part_enc(chordarr, i) for i in range(num_parts)]
|
279 |
+
return sorted(parts, key=avg_pitch, reverse=True) if sort_pitch else parts
|
280 |
+
|
281 |
+
def chordarr_combine_parts(parts):
|
282 |
+
max_ts = max([p.shape[0] for p in parts])
|
283 |
+
parts_padded = [pad_part_to(p, max_ts) for p in parts]
|
284 |
+
chordarr_comb = np.concatenate(parts_padded, axis=1)
|
285 |
+
return chordarr_comb
|
286 |
+
|
287 |
+
def pad_part_to(p, target_size):
|
288 |
+
pad_width = ((0,target_size-p.shape[0]),(0,0),(0,0))
|
289 |
+
return np.pad(p, pad_width, 'constant')
|
290 |
+
|
291 |
+
def part_enc(chordarr, part):
|
292 |
+
partarr = chordarr[:,part:part+1,:]
|
293 |
+
npenc = chordarr2npenc(partarr)
|
294 |
+
return npenc
|
295 |
+
|
296 |
+
def avg_tempo(t, sep_idx=VALTSEP):
|
297 |
+
avg = t[t[:, 0] == sep_idx][:, 1].sum()/t.shape[0]
|
298 |
+
avg = int(round(avg/SAMPLE_FREQ))
|
299 |
+
return 'mt'+str(min(avg, MTEMPO_SIZE-1))
|
300 |
+
|
301 |
+
def avg_pitch(t, sep_idx=VALTSEP):
|
302 |
+
return t[t[:, 0] > sep_idx][:, 0].mean()
|
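To illustrate the (note, duration) encoding documented above: each row is either a sounding note [midi_pitch, duration_in_steps] or a separator [VALTSEP, steps_to_advance], with SAMPLE_FREQ steps per quarter note. A hand-built example, decoded back to midi (the output filename is illustrative):

import numpy as np
from utils.musicautobot.numpy_encode import npenc2stream, VALTSEP, SAMPLE_FREQ

npenc = np.array([
    [60, SAMPLE_FREQ],       # C4 for 1 beat
    [64, SAMPLE_FREQ],       # E4, same timestep (no separator in between)
    [67, SAMPLE_FREQ],       # G4, same timestep -> a C major triad
    [VALTSEP, SAMPLE_FREQ],  # advance 1 beat
    [64, 2 * SAMPLE_FREQ],   # E4 for 2 beats
])
stream = npenc2stream(npenc, bpm=120)
stream.write('midi', fp='triad.mid')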
utils/musicautobot/utils/__init__.py
ADDED
File without changes

utils/musicautobot/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (176 Bytes)

utils/musicautobot/utils/__pycache__/attention_mask.cpython-310.pyc
ADDED
Binary file (1.3 kB)

utils/musicautobot/utils/__pycache__/file_processing.cpython-310.pyc
ADDED
Binary file (2.62 kB)

utils/musicautobot/utils/__pycache__/midifile.cpython-310.pyc
ADDED
Binary file (4.5 kB)

utils/musicautobot/utils/__pycache__/setup_musescore.cpython-310.pyc
ADDED
Binary file (1.79 kB)

utils/musicautobot/utils/__pycache__/top_k_top_p.cpython-310.pyc
ADDED
Binary file (1.24 kB)
utils/musicautobot/utils/attention_mask.py
ADDED
@@ -0,0 +1,21 @@
import numpy as np
import torch

def window_mask(x_len, device, m_len=0, size=(1,1)):
    win_size,k = size
    mem_mask = torch.zeros((x_len,m_len), device=device)
    tri_mask = torch.triu(torch.ones((x_len//win_size+1,x_len//win_size+1), device=device),diagonal=k)
    window_mask = tri_mask.repeat_interleave(win_size,dim=0).repeat_interleave(win_size,dim=1)[:x_len,:x_len]
    if x_len: window_mask[...,0] = 0 # Always allowing first index to see. Otherwise you'll get NaN loss
    mask = torch.cat((mem_mask, window_mask), dim=1)[None,None]
    return mask.bool() if hasattr(mask, 'bool') else mask.byte()

def rand_window_mask(x_len, m_len, device, max_size:int=None, p:float=0.2, is_eval:bool=False):
    if is_eval or np.random.rand() >= p or max_size is None:
        win_size,k = (1,1)
    else: win_size,k = (np.random.randint(0,max_size)+1, 0)
    return window_mask(x_len, device, m_len, size=(win_size,k))

def lm_mask(x_len, device):
    mask = torch.triu(torch.ones((x_len, x_len), device=device), diagonal=1)[None,None]
    return mask.bool() if hasattr(mask, 'bool') else mask.byte()
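A quick sketch of what these masks produce for a short sequence; entries equal to 1 are the positions an attention head is blocked from seeing:

import torch
from utils.musicautobot.utils.attention_mask import window_mask, lm_mask

x_len, m_len = 6, 3
causal = lm_mask(x_len, 'cpu')                       # standard upper-triangular LM mask
windowed = window_mask(x_len, 'cpu', m_len=m_len,    # memory stays visible,
                       size=(2, 0))                  # current input hidden in 2-step windows
print(causal.shape, windowed.shape)                  # (1, 1, 6, 6) and (1, 1, 6, 9)
print(windowed[0, 0].int())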
utils/musicautobot/utils/file_processing.py
ADDED
@@ -0,0 +1,52 @@
"Parallel processing for midi files"
import csv
from fastprogress.fastprogress import master_bar, progress_bar
from pathlib import Path
from pebble import ProcessPool
from concurrent.futures import TimeoutError
import numpy as np

# https://stackoverflow.com/questions/20991968/asynchronous-multiprocessing-with-a-worker-pool-in-python-how-to-keep-going-aft
def process_all(func, arr, timeout_func=None, total=None, max_workers=None, timeout=None):
    with ProcessPool() as pool:
        future = pool.map(func, arr, timeout=timeout)

        iterator = future.result()
        results = []
        for i in progress_bar(range(len(arr)), total=len(arr)):
            try:
                result = next(iterator)
                if result: results.append(result)
            except StopIteration:
                break
            except TimeoutError as error:
                if timeout_func: timeout_func(arr[i], error.args[1])
    return results

def process_file(file_path, tfm_func=None, src_path=None, dest_path=None):
    "Utility function that transforms midi file to numpy array."
    output_file = Path(str(file_path).replace(str(src_path), str(dest_path))).with_suffix('.npy')
    if output_file.exists(): return output_file
    output_file.parent.mkdir(parents=True, exist_ok=True)

    # Call tfm_func and save file
    npenc = tfm_func(file_path)
    if npenc is not None:
        np.save(output_file, npenc)
        return output_file

def arr2csv(arr, out_file):
    "Convert metadata array to csv"
    all_keys = {k for d in arr for k in d.keys()}
    arr = [format_values(x) for x in arr]
    with open(out_file, 'w') as f:
        dict_writer = csv.DictWriter(f, list(all_keys))
        dict_writer.writeheader()
        dict_writer.writerows(arr)

def format_values(d):
    "Format array values for csv encoding"
    def format_value(v):
        if isinstance(v, list): return ','.join(v)
        return v
    return {k:format_value(v) for k,v in d.items()}
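A sketch of how these helpers combine to preprocess a midi corpus into .npy files for the dataloader (directory names are illustrative):

from functools import partial
from pathlib import Path
from utils.musicautobot.numpy_encode import midi2npenc
from utils.musicautobot.utils.file_processing import process_all, process_file

src, dest = Path('data/midi'), Path('data/npenc')
midi_files = list(src.rglob('*.mid'))

tfm = partial(process_file, tfm_func=midi2npenc, src_path=src, dest_path=dest)
saved = process_all(tfm, midi_files, timeout=120,
                    timeout_func=lambda file, t: print('Timed out:', file, t))
print(len(saved), 'files encoded')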
utils/musicautobot/utils/lamb.py
ADDED
@@ -0,0 +1,106 @@
1 |
+
# SOURCE: https://github.com/cybertronai/pytorch-lamb/
|
2 |
+
|
3 |
+
import collections
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.optim import Optimizer
|
8 |
+
|
9 |
+
|
10 |
+
class Lamb(Optimizer):
|
11 |
+
r"""Implements Lamb algorithm.
|
12 |
+
|
13 |
+
It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`_.
|
14 |
+
|
15 |
+
Arguments:
|
16 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
17 |
+
parameter groups
|
18 |
+
lr (float, optional): learning rate (default: 1e-3)
|
19 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
20 |
+
running averages of gradient and its square (default: (0.9, 0.999))
|
21 |
+
eps (float, optional): term added to the denominator to improve
|
22 |
+
numerical stability (default: 1e-8)
|
23 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
24 |
+
adam (bool, optional): always use trust ratio = 1, which turns this into
|
25 |
+
Adam. Useful for comparison purposes.
|
26 |
+
|
27 |
+
.. _Reducing BERT Pre-Training Time from 3 Days to 76 Minutes:
|
28 |
+
https://arxiv.org/abs/1904.00962
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-4,
|
32 |
+
weight_decay=0, adam=False):
|
33 |
+
if not 0.0 <= lr:
|
34 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
35 |
+
if not 0.0 <= eps:
|
36 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
37 |
+
if not 0.0 <= betas[0] < 1.0:
|
38 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
39 |
+
if not 0.0 <= betas[1] < 1.0:
|
40 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
41 |
+
defaults = dict(lr=lr, betas=betas, eps=eps,
|
42 |
+
weight_decay=weight_decay)
|
43 |
+
self.adam = adam
|
44 |
+
super(Lamb, self).__init__(params, defaults)
|
45 |
+
|
46 |
+
def step(self, closure=None):
|
47 |
+
"""Performs a single optimization step.
|
48 |
+
|
49 |
+
Arguments:
|
50 |
+
closure (callable, optional): A closure that reevaluates the model
|
51 |
+
and returns the loss.
|
52 |
+
"""
|
53 |
+
loss = None
|
54 |
+
if closure is not None:
|
55 |
+
loss = closure()
|
56 |
+
|
57 |
+
for group in self.param_groups:
|
58 |
+
for p in group['params']:
|
59 |
+
if p.grad is None:
|
60 |
+
continue
|
61 |
+
grad = p.grad.data
|
62 |
+
if grad.is_sparse:
|
63 |
+
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
|
64 |
+
|
65 |
+
state = self.state[p]
|
66 |
+
|
67 |
+
# State initialization
|
68 |
+
if len(state) == 0:
|
69 |
+
state['step'] = 0
|
70 |
+
# Exponential moving average of gradient values
|
71 |
+
state['exp_avg'] = torch.zeros_like(p.data)
|
72 |
+
# Exponential moving average of squared gradient values
|
73 |
+
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
74 |
+
|
75 |
+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
76 |
+
beta1, beta2 = group['betas']
|
77 |
+
|
78 |
+
state['step'] += 1
|
79 |
+
|
80 |
+
if group['weight_decay'] != 0:
|
81 |
+
grad.add_(group['weight_decay'], p.data)
|
82 |
+
|
83 |
+
# Decay the first and second moment running average coefficient
|
84 |
+
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
85 |
+
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
86 |
+
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
87 |
+
|
88 |
+
bias_correction1 = 1 - beta1 ** state['step']
|
89 |
+
bias_correction2 = 1 - beta2 ** state['step']
|
90 |
+
# Apply bias to lr to avoid broadcast.
|
91 |
+
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
|
92 |
+
|
93 |
+
adam_step = exp_avg / denom
|
94 |
+
# L2 norm uses sum, but here since we're dividing, use mean to avoid overflow.
|
95 |
+
r1 = p.data.pow(2).mean().sqrt()
|
96 |
+
r2 = adam_step.pow(2).mean().sqrt()
|
97 |
+
r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10)
|
98 |
+
state['r1'] = r1
|
99 |
+
state['r2'] = r2
|
100 |
+
state['r'] = r
|
101 |
+
if self.adam:
|
102 |
+
r = 1
|
103 |
+
|
104 |
+
p.data.add_(-step_size * r, adam_step)
|
105 |
+
|
106 |
+
return loss
|
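Lamb is a drop-in torch optimizer; a minimal single-step sketch on a toy model:

import torch
import torch.nn as nn
from utils.musicautobot.utils.lamb import Lamb

model = nn.Linear(10, 2)
opt = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
opt.step()        # applies the layer-wise trust ratio (pass adam=True to fall back to plain Adam)
opt.zero_grad()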
utils/musicautobot/utils/midifile.py
ADDED
@@ -0,0 +1,107 @@
"Transform functions for raw midi files"
from enum import Enum
import music21

PIANO_TYPES = list(range(24)) + list(range(80, 96)) # Piano, Synths
PLUCK_TYPES = list(range(24, 40)) + list(range(104, 112)) # Guitar, Bass, Ethnic
BRIGHT_TYPES = list(range(40, 56)) + list(range(56, 80))

PIANO_RANGE = (21, 109) # https://en.wikipedia.org/wiki/Scientific_pitch_notation

class Track(Enum):
    PIANO = 0 # discrete instruments - keyboard, woodwinds
    PLUCK = 1 # continuous instruments with pitch bend: violin, trombone, synths
    BRIGHT = 2
    PERC = 3
    UNDEF = 4

type2inst = {
    # use print_music21_instruments() to see supported types
    Track.PIANO: 0, # Piano
    Track.PLUCK: 24, # Guitar
    Track.BRIGHT: 40, # Violin
    Track.PERC: 114, # Steel Drum
}

# INFO_TYPES = set(['TIME_SIGNATURE', 'KEY_SIGNATURE'])
INFO_TYPES = set(['TIME_SIGNATURE', 'KEY_SIGNATURE', 'SET_TEMPO'])

def file2mf(fp):
    mf = music21.midi.MidiFile()
    if isinstance(fp, bytes):
        mf.readstr(fp)
    else:
        mf.open(fp)
        mf.read()
        mf.close()
    return mf

def mf2stream(mf): return music21.midi.translate.midiFileToStream(mf)

def is_empty_midi(fp):
    if fp is None: return False
    mf = file2mf(fp)
    return not any([t.hasNotes() for t in mf.tracks])

def num_piano_tracks(fp):
    music_file = file2mf(fp)
    note_tracks = [t for t in music_file.tracks if t.hasNotes() and get_track_type(t) == Track.PIANO]
    return len(note_tracks)

def is_channel(t, c_val):
    return any([c == c_val for c in t.getChannels()])

def track_sort(t): # sort by 1. variation of pitch, 2. number of notes
    return len(unique_track_notes(t)), len(t.events)

def is_piano_note(pitch):
    return (pitch >= PIANO_RANGE[0]) and (pitch < PIANO_RANGE[1])

def unique_track_notes(t):
    return { e.pitch for e in t.events if e.pitch is not None }

def compress_midi_file(fp, cutoff=6, min_variation=3, supported_types=set([Track.PIANO, Track.PLUCK, Track.BRIGHT])):
    music_file = file2mf(fp)

    info_tracks = [t for t in music_file.tracks if not t.hasNotes()]
    note_tracks = [t for t in music_file.tracks if t.hasNotes()]

    if len(note_tracks) > cutoff:
        note_tracks = sorted(note_tracks, key=track_sort, reverse=True)

    supported_tracks = []
    for idx,t in enumerate(note_tracks):
        if len(supported_tracks) >= cutoff: break
        track_type = get_track_type(t)
        if track_type not in supported_types: continue
        pitch_set = unique_track_notes(t)
        if (len(pitch_set) < min_variation): continue # must have more than x unique notes
        if not all(map(is_piano_note, pitch_set)): continue # must not contain midi notes outside of piano range
        # if track_type == Track.UNDEF: print('Could not designate track:', fp, t)
        change_track_instrument(t, type2inst[track_type])
        supported_tracks.append(t)
    if not supported_tracks: return None
    music_file.tracks = info_tracks + supported_tracks
    return music_file

def get_track_type(t):
    if is_channel(t, 10): return Track.PERC
    i = get_track_instrument(t)
    if i in PIANO_TYPES: return Track.PIANO
    if i in PLUCK_TYPES: return Track.PLUCK
    if i in BRIGHT_TYPES: return Track.BRIGHT
    return Track.UNDEF

def get_track_instrument(t):
    for idx,e in enumerate(t.events):
        if e.type == 'PROGRAM_CHANGE': return e.data
    return None

def change_track_instrument(t, value):
    for idx,e in enumerate(t.events):
        if e.type == 'PROGRAM_CHANGE': e.data = value

def print_music21_instruments():
    for i in range(200):
        try: print(i, music21.instrument.instrumentFromMidiProgram(i))
        except: pass
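A sketch of using the track filtering above to shrink a dense midi file to at most 6 supported tracks before encoding (file names are illustrative):

from utils.musicautobot.utils.midifile import compress_midi_file, mf2stream

mf = compress_midi_file('dense_song.mid', cutoff=6)
if mf is None:
    print('No piano/pluck/string tracks worth keeping')
else:
    stream = mf2stream(mf)                    # back to a music21 stream
    stream.write('midi', fp='compressed.mid')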
utils/musicautobot/utils/setup_musescore.py
ADDED
@@ -0,0 +1,46 @@
1 |
+
def setup_musescore(musescore_path=None):
    if not is_ipython(): return

    import platform
    from music21 import environment
    from pathlib import Path

    system = platform.system()
    if system == 'Linux':
        import os
        os.environ['QT_QPA_PLATFORM']='offscreen' # https://musescore.org/en/node/29041

    existing_path = environment.get('musicxmlPath')
    if existing_path: return
    if musescore_path is None:
        if system == 'Darwin':
            app_paths = list(Path('/Applications').glob('MuseScore *.app'))
            if len(app_paths): musescore_path = app_paths[-1]/'Contents/MacOS/mscore'
        elif system == 'Linux':
            musescore_path = '/usr/bin/musescore'

    if musescore_path is None or not Path(musescore_path).exists():
        print('Warning: Could not find musescore installation. Please install musescore (see README) and/or update music21 environment paths')
    else:
        environment.set('musicxmlPath', musescore_path)
        environment.set('musescoreDirectPNGPath', musescore_path)

def is_ipython():
    try: get_ipython
    except: return False
    return True

def is_colab():
    try: import google.colab
    except: return False
    return True

def setup_fluidsynth():
    from pathlib import Path  # added: Path was only imported inside setup_musescore, so play_wav would otherwise raise NameError
    from midi2audio import FluidSynth
    from IPython.display import Audio

    def play_wav(stream):
        out_midi = stream.write('midi')
        out_wav = str(Path(out_midi).with_suffix('.wav'))
        FluidSynth("font.sf2").midi_to_audio(out_midi, out_wav)
        return Audio(out_wav)
    return play_wav  # added: expose the nested helper so callers can actually use it
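A minimal, hypothetical notebook sketch for these setup helpers (not part of the committed sources); it assumes MuseScore, fluidsynth, and a 'font.sf2' soundfont are installed, and relies on setup_fluidsynth returning the play_wav helper as reconstructed above.

# Hypothetical usage (illustration only), run inside IPython/Jupyter:
from utils.musicautobot.utils.setup_musescore import setup_musescore, setup_fluidsynth

setup_musescore()              # no-op outside IPython; warns if MuseScore cannot be found
play_wav = setup_fluidsynth()  # needs midi2audio, fluidsynth, and font.sf2 in the working directory
# play_wav(stream)             # `stream` would be a music21 stream produced elsewhere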
utils/musicautobot/utils/stacked_dataloader.py
ADDED
@@ -0,0 +1,70 @@
"Dataloader wrapper that can combine and handle multiple dataloaders for multitask training"
|
2 |
+
from fastai.callback import Callback
|
3 |
+
from typing import Callable
|
4 |
+
|
5 |
+
__all__ = ['StackedDataBunch']
|
6 |
+
|
7 |
+
# DataLoading
|
8 |
+
class StackedDataBunch():
|
9 |
+
def __init__(self, dbs, num_it=100):
|
10 |
+
self.dbs = dbs
|
11 |
+
self.train_dl = StackedDataloader([db.train_dl for db in self.dbs], num_it)
|
12 |
+
self.valid_dl = StackedDataloader([db.valid_dl for db in self.dbs], num_it)
|
13 |
+
self.train_ds = None
|
14 |
+
self.path = dbs[0].path
|
15 |
+
self.device = dbs[0].device
|
16 |
+
self.vocab = dbs[0].vocab
|
17 |
+
self.empty_val = False
|
18 |
+
|
19 |
+
def add_tfm(self,tfm:Callable)->None:
|
20 |
+
for dl in self.dbs: dl.add_tfm(tfm)
|
21 |
+
|
22 |
+
def remove_tfm(self,tfm:Callable)->None:
|
23 |
+
for dl in self.dbs: dl.remove_tfm(tfm)
|
24 |
+
|
25 |
+
# Helper functions
|
26 |
+
class StackedDataset(Callback):
|
27 |
+
def __init__(self, dss):
|
28 |
+
self.dss = dss
|
29 |
+
def __getattribute__(self, attr):
|
30 |
+
if attr == 'dss': return super().__getattribute__(attr)
|
31 |
+
def redirected(*args, **kwargs):
|
32 |
+
for ds in self.dss:
|
33 |
+
if hasattr(ds, attr): getattr(ds, attr)(*args, **kwargs)
|
34 |
+
return redirected
|
35 |
+
def __len__(self)->int: return sum([len(ds) for ds in self.dss])
|
36 |
+
def __repr__(self): return '\n'.join([self.__class__.__name__] + [repr(ds) for ds in self.dss])
|
37 |
+
|
38 |
+
class StackedDataloader():
|
39 |
+
def __init__(self, dls, num_it=100):
|
40 |
+
self.dls = dls
|
41 |
+
self.dataset = StackedDataset([dl.dataset for dl in dls if hasattr(dl, 'dataset')])
|
42 |
+
self.num_it = num_it
|
43 |
+
self.dl_idx = -1
|
44 |
+
|
45 |
+
def __len__(self)->int: return sum([len(dl) for dl in self.dls])
|
46 |
+
def __getattr__(self, attr):
|
47 |
+
def redirected(*args, **kwargs):
|
48 |
+
for dl in self.dls:
|
49 |
+
if hasattr(dl, attr):
|
50 |
+
getattr(dl, attr)(*args, **kwargs)
|
51 |
+
return redirected
|
52 |
+
|
53 |
+
def __iter__(self):
|
54 |
+
"Process and returns items from `DataLoader`."
|
55 |
+
iters = [iter(dl) for dl in self.dls]
|
56 |
+
self.dl_idx = -1
|
57 |
+
while len(iters):
|
58 |
+
self.dl_idx = (self.dl_idx+1) % len(iters)
|
59 |
+
for b in range(self.num_it):
|
60 |
+
try:
|
61 |
+
yield next(iters[self.dl_idx])
|
62 |
+
except StopIteration as e:
|
63 |
+
iters.remove(iters[self.dl_idx])
|
64 |
+
break
|
65 |
+
# raise StopIteration
|
66 |
+
|
67 |
+
def new(self, **kwargs):
|
68 |
+
"Create a new copy of `self` with `kwargs` replacing current values."
|
69 |
+
new_dls = [dl.new(**kwargs) for dl in self.dls]
|
70 |
+
return StackedDataloader(new_dls, self.num_it)
|
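A minimal, hypothetical multitask setup sketch (not part of the committed sources); mask_data and s2s_data are placeholder fastai v1 DataBunch objects assumed to be built elsewhere.

# Hypothetical usage (illustration only): alternate between two DataBunches,
# drawing num_it batches from one before switching to the other.
from utils.musicautobot.utils.stacked_dataloader import StackedDataBunch

data = StackedDataBunch([mask_data, s2s_data], num_it=100)  # mask_data / s2s_data are placeholders
# data.train_dl round-robins over the wrapped train dataloaders during training.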
utils/musicautobot/utils/top_k_top_p.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn.functional as F

__all__ = ['top_k_top_p']

# top_k + nucleus filter - https://twitter.com/thom_wolf/status/1124263861727760384?lang=en
# https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
def top_k_top_p(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
    logits = logits.clone()
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
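A minimal, hypothetical sampling sketch (not part of the committed sources) showing how filtered logits are typically turned into a next-token draw; the vocabulary size below is a placeholder.

# Hypothetical usage (illustration only): combine top-k and nucleus filtering, then sample.
import torch
import torch.nn.functional as F
from utils.musicautobot.utils.top_k_top_p import top_k_top_p

logits = torch.randn(312)                        # placeholder: vocabulary-sized logits for one step
filtered = top_k_top_p(logits, top_k=30, top_p=0.9)
probs = F.softmax(filtered, dim=-1)              # filtered-out positions get probability 0
next_idx = torch.multinomial(probs, 1).item()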
utils/musicautobot/vocab.py
ADDED
@@ -0,0 +1,93 @@
from fastai.basics import *
from .numpy_encode import *
from .music_transformer import transform

BOS = 'xxbos'
PAD = 'xxpad'
EOS = 'xxeos'
MASK = 'xxmask' # Used for BERT masked language modeling.
CSEQ = 'xxcseq' # Used for Seq2Seq translation - denotes start of chord sequence
MSEQ = 'xxmseq' # Used for Seq2Seq translation - denotes start of melody sequence

# Deprecated tokens. Kept for compatibility
S2SCLS = 'xxs2scls' # deprecated
NSCLS = 'xxnscls' # deprecated

SEP = 'xxsep' # Used to denote end of timestep (required for polyphony). separator idx = -1 (part of notes)

SPECIAL_TOKS = [BOS, PAD, EOS, S2SCLS, MASK, CSEQ, MSEQ, NSCLS, SEP] # Important: SEP token must be last

NOTE_TOKS = [f'n{i}' for i in range(NOTE_SIZE)]
DUR_TOKS = [f'd{i}' for i in range(DUR_SIZE)]
NOTE_START, NOTE_END = NOTE_TOKS[0], NOTE_TOKS[-1]
DUR_START, DUR_END = DUR_TOKS[0], DUR_TOKS[-1]

MTEMPO_SIZE = 10
MTEMPO_OFF = 'mt0'
MTEMPO_TOKS = [f'mt{i}' for i in range(MTEMPO_SIZE)]

# Vocab - token to index mapping
class MusicVocab():
    "Contain the correspondence between numbers and tokens and numericalize."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        self.stoi = {v:k for k,v in enumerate(self.itos)}

    def numericalize(self, t:Collection[str]) -> List[int]:
        "Convert a list of tokens `t` to their ids."
        return [self.stoi[w] for w in t]

    def textify(self, nums:Collection[int], sep=' ') -> List[str]:
        "Convert a list of `nums` to their tokens."
        items = [self.itos[i] for i in nums]
        return sep.join(items) if sep is not None else items

    def to_music_item(self, idxenc):
        return transform.MusicItem(idxenc, self)

    @property
    def mask_idx(self): return self.stoi[MASK]
    @property
    def pad_idx(self): return self.stoi[PAD]
    @property
    def bos_idx(self): return self.stoi[BOS]
    @property
    def sep_idx(self): return self.stoi[SEP]
    @property
    def npenc_range(self): return (self.stoi[SEP], self.stoi[DUR_END]+1)
    @property
    def note_range(self): return self.stoi[NOTE_START], self.stoi[NOTE_END]+1
    @property
    def dur_range(self): return self.stoi[DUR_START], self.stoi[DUR_END]+1

    def is_duration(self, idx):
        return idx >= self.dur_range[0] and idx < self.dur_range[1]
    def is_duration_or_pad(self, idx):
        return idx == self.pad_idx or self.is_duration(idx)

    def __getstate__(self):
        return {'itos':self.itos}

    def __setstate__(self, state:dict):
        self.itos = state['itos']
        self.stoi = {v:k for k,v in enumerate(self.itos)}

    def __len__(self): return len(self.itos)

    def save(self, path):
        "Save `self.itos` in `path`"
        pickle.dump(self.itos, open(path, 'wb'))

    @classmethod
    def create(cls) -> 'Vocab':
        "Create a vocabulary from a set of `tokens`."
        itos = SPECIAL_TOKS + NOTE_TOKS + DUR_TOKS + MTEMPO_TOKS
        if len(itos)%8 != 0:
            itos = itos + [f'dummy{i}' for i in range(len(itos)%8)]
        return cls(itos)

    @classmethod
    def load(cls, path):
        "Load the `Vocab` contained in `path`"
        itos = pickle.load(open(path, 'rb'))
        return cls(itos)
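A minimal, hypothetical round-trip sketch for the vocabulary (not part of the committed sources).

# Hypothetical usage (illustration only): build the vocab and map tokens to ids and back.
from utils.musicautobot.vocab import MusicVocab

vocab = MusicVocab.create()
ids = vocab.numericalize(['xxbos', 'n60', 'd4', 'xxsep'])  # tokens -> indices
print(vocab.textify(ids))                                  # indices -> 'xxbos n60 d4 xxsep'
print(len(vocab), vocab.note_range, vocab.dur_range)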