amosyou committed
Commit 7116323 · Parent: 6274dcb

feat: add lofi-bytes-api and gradio app

app.py ADDED
@@ -0,0 +1,157 @@
+ import random
+
+ import gradio as gr
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ from model.music_transformer import MusicTransformer
+ from processor import decode_midi, encode_midi
+ from utilities.constants import TOKEN_END, TOKEN_PAD, TORCH_LABEL_TYPE
+ from utilities.device import get_device, use_cuda
+
+ REPO_ID = "Launchpad/lofi-bytes"
+ FILENAME = "weights_maestro_finetuned.pickle"
+
+ SEQUENCE_START = 0
+ OUTPUT_PATH = "./output_midi"
+ RPR = True
+ # TARGET_SEQ_LENGTH = 1023
+ TARGET_SEQ_LENGTH = 512
+ NUM_PRIME = 65
+ MAX_SEQUENCE = 2048
+ N_LAYERS = 6
+ NUM_HEADS = 8
+ D_MODEL = 512
+ DIM_FEEDFORWARD = 1024
+ BEAM = 0
+ FORCE_CPU = False
+ ALLOWED_EXTENSIONS = {'mid'}
+ UPLOAD_FOLDER = './uploaded_midis'
+
+ generated_midi = None
+
+ use_cuda(True)
+
+ model = MusicTransformer(
+     n_layers=N_LAYERS,
+     num_heads=NUM_HEADS,
+     d_model=D_MODEL,
+     dim_feedforward=DIM_FEEDFORWARD,
+     max_sequence=MAX_SEQUENCE,
+     rpr=RPR
+ ).to(get_device())
+
+ state_dict = torch.load(
+     hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
+     map_location=get_device()
+ )
+
+ model.load_state_dict(state_dict)
+
+
+ def generate(input_midi):
+     raw_mid = encode_midi(input_midi)
+     if(len(raw_mid) == 0):
+         return
+
+     primer, _ = process_midi(raw_mid, NUM_PRIME, random_seq=False)
+     primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())
+
+     # saves a pretty_midi at file_path
+     # decode_midi(primer[:NUM_PRIME].cpu().numpy(), file_path=f_path)
+     decode_midi(primer[:NUM_PRIME].cpu().numpy())
+
+     # GENERATION
+     model.eval()
+     with torch.set_grad_enabled(False):
+         # NOTE: model.generate() returns the generated MIDI as an ARRAY of tokens, given a primer
+         beam_seq = model.generate(primer[:NUM_PRIME], TARGET_SEQ_LENGTH, beam=BEAM)
+
+         file_path = "output.mid"
+
+         # NOTE: decode_midi() returns an actual MIDI of class pretty_midi.PrettyMIDI
+         decoded_midi = decode_midi(beam_seq[0].cpu().numpy(), file_path=file_path)
+
+     # decoded_midi is the pretty_midi object (with instrument information); beam_seq is just the
+     # token array representing the MIDI. The Gradio interface only needs the written file path.
+     return file_path
+
+
+ def process_midi(raw_mid, max_seq, random_seq):
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Takes in pre-processed raw midi and returns the input and target. Can use a random sequence or
+     go from the start based on random_seq.
+     ----------
+     """
+
+     x   = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
+     tgt = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
+
+     raw_len  = len(raw_mid)
+     full_seq = max_seq + 1  # Performing seq2seq
+
+     if(raw_len == 0):
+         return x, tgt
+
+     if(raw_len < full_seq):
+         x[:raw_len]     = raw_mid
+         tgt[:raw_len-1] = raw_mid[1:]
+         tgt[raw_len-1]  = TOKEN_END  # target for the final real token is the end-of-sequence token
+     else:
+         # Randomly selecting a range
+         if(random_seq):
+             end_range = raw_len - full_seq
+             start = random.randint(SEQUENCE_START, end_range)
+
+         # Always taking from the start to as far as we can
+         else:
+             start = SEQUENCE_START
+
+         end = start + full_seq
+
+         data = raw_mid[start:end]
+
+         x = data[:max_seq]
+         tgt = data[1:full_seq]
+
+     return x, tgt
+
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Image(
+                 "https://www.ocf.berkeley.edu/~launchpad/media/uploads/project_logos/410912267_278779401866686_2517511436172822307_n_0iVwDxI.png",
+                 elem_id="logo-img",
+                 show_label=False,
+                 show_share_button=False,
+                 show_download_button=False,
+                 show_fullscreen_button=False,
+             )
+
+         with gr.Column(scale=3):
+             gr.Markdown("""lofi-bytes is a [Launchpad](https://launchpad.studentorg.berkeley.edu/) project (Spring 2023) that generates lofi tracks from input MIDI samples using a MusicTransformer model.
+             <br/><br/>
+             **Model**: [lofi-bytes](https://huggingface.co/Launchpad/lofi-bytes)
+             <br/>
+             **Project Leader**: Alicia Wang
+             <br/>
+             **Members**: Alena Chao, Eric Liu, Zane Mogannam, Chloe Wong, Iris Zhou
+             <br/>
+             **Advisors**: Vincent Lim, Winston Liu
+             <br/>
+             """
+             )
+             gr.Interface(
+                 fn=generate,
+                 inputs=gr.File(),
+                 outputs=gr.File(),
+                 examples=["uploaded_midis/ghibli_castle_in_the_sky.mid", "uploaded_midis/am_i_blue_jazz.mid"]
+             )
+
+ if __name__ == '__main__':
+     demo.launch(share=True)
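For local testing, the handler can also be called directly rather than through the UI. A minimal sketch, assuming a recent Gradio where gr.File passes a plain file path, and using one of the example MIDIs bundled in this commit:

# Sketch only: importing app runs the module-level weight download and model setup.
from app import generate

out_path = generate("uploaded_midis/ghibli_castle_in_the_sky.mid")
print(out_path)  # "output.mid", written to the working directory by decode_midi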
model/loss.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.modules.loss import _Loss
+
+ # Borrowed from https://github.com/jason9693/MusicTransformer-pytorch/blob/5f183374833ff6b7e17f3a24e3594dedd93a5fe5/custom/criterion.py#L28
+ class SmoothCrossEntropyLoss(_Loss):
+     """
+     https://arxiv.org/abs/1512.00567
+     """
+     __constants__ = ['label_smoothing', 'vocab_size', 'ignore_index', 'reduction']
+
+     def __init__(self, label_smoothing, vocab_size, ignore_index=-100, reduction='mean', is_logits=True):
+         assert 0.0 <= label_smoothing <= 1.0
+         super().__init__(reduction=reduction)
+
+         self.label_smoothing = label_smoothing
+         self.vocab_size = vocab_size
+         self.ignore_index = ignore_index
+         self.input_is_logits = is_logits
+
+     def forward(self, input, target):
+         """
+         Args:
+             input: [B * T, V]
+             target: [B * T]
+         Returns:
+             cross entropy: [1]
+         """
+         mask = (target == self.ignore_index).unsqueeze(-1)
+         q = F.one_hot(target.long(), self.vocab_size).type(torch.float32)
+         u = 1.0 / self.vocab_size
+         q_prime = (1.0 - self.label_smoothing) * q + self.label_smoothing * u
+         q_prime = q_prime.masked_fill(mask, 0)
+
+         ce = self.cross_entropy_with_logits(q_prime, input)
+         if self.reduction == 'mean':
+             lengths = torch.sum(target != self.ignore_index)
+             return ce.sum() / lengths
+         elif self.reduction == 'sum':
+             return ce.sum()
+         else:
+             raise NotImplementedError
+
+     def cross_entropy_with_logits(self, p, q):
+         return -torch.sum(p * (q - q.logsumexp(dim=-1, keepdim=True)), dim=-1)
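For reference, a minimal sketch of how this criterion would be constructed for this project's vocabulary, using the constants from utilities/constants.py; the 0.1 smoothing value is an assumption, not something set anywhere in this commit:

import torch

from model.loss import SmoothCrossEntropyLoss
from utilities.constants import TOKEN_END, TOKEN_PAD, VOCAB_SIZE

# ignore_index=TOKEN_PAD keeps padded positions out of the mean.
criterion = SmoothCrossEntropyLoss(label_smoothing=0.1, vocab_size=VOCAB_SIZE, ignore_index=TOKEN_PAD)

logits = torch.randn(8, VOCAB_SIZE)            # [B * T, V]
targets = torch.randint(0, TOKEN_END, (8,))    # [B * T], real event tokens only
print(criterion(logits, targets))              # scalar mean cross entropy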
model/music_transformer.py ADDED
@@ -0,0 +1,200 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn.modules.normalization import LayerNorm
+ import random
+
+ from utilities.constants import *
+ from utilities.device import get_device
+
+ from .positional_encoding import PositionalEncoding
+ from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR
+
+
+ # MusicTransformer
+ class MusicTransformer(nn.Module):
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Music Transformer reproduction from https://arxiv.org/abs/1809.04281. Arguments allow for
+     tweaking the transformer architecture (https://arxiv.org/abs/1706.03762) and the rpr argument
+     toggles Relative Position Representations (RPR - https://arxiv.org/abs/1803.02155).
+
+     Supports training and generation using Pytorch's nn.Transformer class with a dummy decoder to
+     make a decoder-only transformer architecture.
+
+     For RPR support, there is modified Pytorch 1.2.0 code in rpr.py. Modified source will be
+     kept up to date with Pytorch revisions only as necessary.
+     ----------
+     """
+
+     def __init__(self, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024,
+                  dropout=0.1, max_sequence=2048, rpr=False):
+         super(MusicTransformer, self).__init__()
+
+         self.dummy = DummyDecoder()
+
+         self.nlayers = n_layers
+         self.nhead = num_heads
+         self.d_model = d_model
+         self.d_ff = dim_feedforward
+         self.dropout = dropout
+         self.max_seq = max_sequence
+         self.rpr = rpr
+
+         # Input embedding
+         self.embedding = nn.Embedding(VOCAB_SIZE, self.d_model)
+
+         # Positional encoding
+         self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq)
+
+         # Base transformer
+         if(not self.rpr):
+             # To make a decoder-only transformer we need to use masked encoder layers
+             # Dummy decoder to essentially just return the encoder output
+             self.transformer = nn.Transformer(
+                 d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
+                 num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ,
+                 dim_feedforward=self.d_ff, custom_decoder=self.dummy
+             )
+         # RPR Transformer
+         else:
+             encoder_norm = LayerNorm(self.d_model)
+             encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout, er_len=self.max_seq)
+             encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm)
+             self.transformer = nn.Transformer(
+                 d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
+                 num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ,
+                 dim_feedforward=self.d_ff, custom_decoder=self.dummy, custom_encoder=encoder
+             )
+
+         # Final output is a softmaxed linear layer
+         self.Wout = nn.Linear(self.d_model, VOCAB_SIZE)
+         self.softmax = nn.Softmax(dim=-1)
+
+     # forward
+     def forward(self, x, mask=True):
+         """
+         ----------
+         Author: Damon Gwinn
+         ----------
+         Takes an input sequence and outputs predictions using a sequence to sequence method.
+
+         A prediction at one index is the "next" prediction given all information seen previously.
+         ----------
+         """
+
+         if(mask is True):
+             mask = self.transformer.generate_square_subsequent_mask(x.shape[1]).to(get_device())
+         else:
+             mask = None
+
+         x = self.embedding(x)
+
+         # Input shape is (max_seq, batch_size, d_model)
+         x = x.permute(1,0,2)
+
+         x = self.positional_encoding(x)
+
+         # Since there are no true decoder layers, the tgt is unused
+         # Pytorch wants src and tgt to have some equal dims however
+         x_out = self.transformer(src=x, tgt=x, src_mask=mask)
+
+         # Back to (batch_size, max_seq, d_model)
+         x_out = x_out.permute(1,0,2)
+
+         y = self.Wout(x_out)
+         # y = self.softmax(y)
+
+         del mask
+
+         # They are trained to predict the next note in sequence (we don't need the last one)
+         return y
+
+     # generate
+     def generate(self, primer=None, target_seq_length=1024, beam=0, beam_chance=1.0):
+         """
+         ----------
+         Author: Damon Gwinn
+         ----------
+         Generates midi given a primer sample. Music can be generated using a probability distribution over
+         the softmax probabilities (recommended) or by using a beam search.
+         ----------
+         """
+
+         assert (not self.training), "Cannot generate while in training mode"
+
+         print("Generating sequence of max length:", target_seq_length)
+
+         gen_seq = torch.full((1,target_seq_length), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
+
+         num_primer = len(primer)
+         gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(get_device())
+
+         # print("primer:",primer)
+         # print(gen_seq)
+         cur_i = num_primer
+         while(cur_i < target_seq_length):
+             # gen_seq_batch = gen_seq.clone()
+             y = self.softmax(self.forward(gen_seq[..., :cur_i]))[..., :TOKEN_END]
+             token_probs = y[:, cur_i-1, :]
+
+             if(beam == 0):
+                 beam_ran = 2.0
+             else:
+                 beam_ran = random.uniform(0,1)
+
+             if(beam_ran <= beam_chance):
+                 token_probs = token_probs.flatten()
+                 top_res, top_i = torch.topk(token_probs, beam)
+
+                 beam_rows = top_i // VOCAB_SIZE
+                 beam_cols = top_i % VOCAB_SIZE
+
+                 gen_seq = gen_seq[beam_rows, :]
+                 gen_seq[..., cur_i] = beam_cols
+
+             else:
+                 distrib = torch.distributions.categorical.Categorical(probs=token_probs)
+                 next_token = distrib.sample()
+                 # print("next token:",next_token)
+                 gen_seq[:, cur_i] = next_token
+
+                 # Let the transformer decide to end if it wants to
+                 if(next_token == TOKEN_END):
+                     print("Model called end of sequence at:", cur_i, "/", target_seq_length)
+                     break
+
+             cur_i += 1
+             if(cur_i % 50 == 0):
+                 print(cur_i, "/", target_seq_length)
+
+         return gen_seq[:, :cur_i]
+
+
+ # Used as a dummy to nn.Transformer
+ # DummyDecoder
+ class DummyDecoder(nn.Module):
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     A dummy decoder that returns its input. Used to make the Pytorch transformer into a decoder-only
+     architecture (stacked encoders with dummy decoder fits the bill)
+     ----------
+     """
+
+     def __init__(self):
+         super(DummyDecoder, self).__init__()
+
+     def forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, **kwargs):
+         """
+         ----------
+         Author: Damon Gwinn
+         ----------
+         Returns the input (memory)
+         ----------
+         """
+
+         return memory
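A quick shape check of the decoder-only setup is sketched below; it uses a small untrained instance, so the output is noise, and the RPR path would additionally exercise model/rpr.py:

import torch

from model.music_transformer import MusicTransformer
from utilities.constants import VOCAB_SIZE
from utilities.device import get_device

mt = MusicTransformer(n_layers=2, num_heads=4, d_model=64, dim_feedforward=128,
                      max_sequence=256, rpr=False).to(get_device())
tokens = torch.randint(0, VOCAB_SIZE, (1, 32), device=get_device())  # (batch, seq)
logits = mt(tokens)  # (batch, seq, VOCAB_SIZE); causal mask applied inside forward()
mt.eval()
sample = mt.generate(tokens[0, :16], target_seq_length=64, beam=0)  # (1, 64) token ids
print(logits.shape, sample.shape)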
model/positional_encoding.py ADDED
@@ -0,0 +1,23 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+ # PositionalEncoding
+ # Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
+ class PositionalEncoding(nn.Module):
+
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         x = x + self.pe[:x.size(0), :]
+         return self.dropout(x)
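As in the PyTorch tutorial this is taken from, the buffer implements the sinusoidal encoding PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)): div_term is exp(-(2i/d_model) * ln 10000), i.e. the reciprocal of the 10000^(2i/d_model) factor, and the registered pe buffer has shape (max_len, 1, d_model) so it broadcasts over the batch dimension of the (seq, batch, d_model) activations used by MusicTransformer.forward.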
model/rpr.py ADDED
@@ -0,0 +1,464 @@
1
+ import torch
2
+ import torch.nn as nn
+ import warnings  # used by the compatibility warning in MultiheadAttentionRPR.forward
3
+
4
+ from torch.nn import functional as F
5
+ from torch.nn.parameter import Parameter
6
+ from torch.nn import Module
7
+ from torch.nn.modules.transformer import _get_clones
8
+ from torch.nn.modules.linear import Linear
9
+ from torch.nn.modules.dropout import Dropout
10
+ from torch.nn.modules.normalization import LayerNorm
11
+ from torch.nn.init import *
12
+
13
+ from torch.nn.functional import linear, softmax, dropout
14
+
15
+ # TransformerEncoderRPR
16
+ class TransformerEncoderRPR(Module):
17
+ """
18
+ ----------
19
+ Author: Pytorch
20
+ ----------
21
+ For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
22
+ https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoder
23
+
24
+ No modification. Copied here to ensure continued compatibility with other edits.
25
+ ----------
26
+ """
27
+
28
+ def __init__(self, encoder_layer, num_layers, norm=None):
29
+ super(TransformerEncoderRPR, self).__init__()
30
+ self.layers = _get_clones(encoder_layer, num_layers)
31
+ self.num_layers = num_layers
32
+ self.norm = norm
33
+
34
+ def forward(self, src, mask=None, src_key_padding_mask=None, **kwargs):
35
+
36
+ output = src
37
+
38
+ for i in range(self.num_layers):
39
+ output = self.layers[i](output, src_mask=mask,
40
+ src_key_padding_mask=src_key_padding_mask)
41
+
42
+ if self.norm:
43
+ output = self.norm(output)
44
+
45
+ return output
46
+
47
+ # TransformerEncoderLayerRPR
48
+ class TransformerEncoderLayerRPR(Module):
49
+ """
50
+ ----------
51
+ Author: Pytorch
52
+ Modified: Damon Gwinn
53
+ ----------
54
+ For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
55
+ https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
56
+
57
+ Modification to create and call custom MultiheadAttentionRPR
58
+ ----------
59
+ """
60
+
61
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
62
+ super(TransformerEncoderLayerRPR, self).__init__()
63
+ self.self_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len)
64
+ # Implementation of Feedforward model
65
+ self.linear1 = Linear(d_model, dim_feedforward)
66
+ self.dropout = Dropout(dropout)
67
+ self.linear2 = Linear(dim_feedforward, d_model)
68
+
69
+ self.norm1 = LayerNorm(d_model)
70
+ self.norm2 = LayerNorm(d_model)
71
+ self.dropout1 = Dropout(dropout)
72
+ self.dropout2 = Dropout(dropout)
73
+
74
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
75
+ src2 = self.self_attn(src, src, src, attn_mask=src_mask,
76
+ key_padding_mask=src_key_padding_mask)[0]
77
+ src = src + self.dropout1(src2)
78
+ src = self.norm1(src)
79
+ src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
80
+ src = src + self.dropout2(src2)
81
+ src = self.norm2(src)
82
+ return src
83
+
84
+ # MultiheadAttentionRPR
85
+ class MultiheadAttentionRPR(Module):
86
+ """
87
+ ----------
88
+ Author: Pytorch
89
+ Modified: Damon Gwinn
90
+ ----------
91
+ For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
92
+ https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/activation.html#MultiheadAttention
93
+
94
+ Modification to add RPR embedding Er and call custom multi_head_attention_forward_rpr
95
+ ----------
96
+ """
97
+
98
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, er_len=None):
99
+ super(MultiheadAttentionRPR, self).__init__()
100
+ self.embed_dim = embed_dim
101
+ self.kdim = kdim if kdim is not None else embed_dim
102
+ self.vdim = vdim if vdim is not None else embed_dim
103
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
104
+
105
+ self.num_heads = num_heads
106
+ self.dropout = dropout
107
+ self.head_dim = embed_dim // num_heads
108
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
109
+
110
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
111
+
112
+ if self._qkv_same_embed_dim is False:
113
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
114
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
115
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
116
+
117
+ if bias:
118
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
119
+ else:
120
+ self.register_parameter('in_proj_bias', None)
121
+ self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
122
+
123
+ if add_bias_kv:
124
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
125
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
126
+ else:
127
+ self.bias_k = self.bias_v = None
128
+
129
+ self.add_zero_attn = add_zero_attn
130
+
131
+ # Adding RPR embedding matrix
132
+ if(er_len is not None):
133
+ self.Er = Parameter(torch.rand((er_len, self.head_dim), dtype=torch.float32))
134
+ else:
135
+ self.Er = None
136
+
137
+ self._reset_parameters()
138
+
139
+ def _reset_parameters(self):
140
+ if self._qkv_same_embed_dim:
141
+ xavier_uniform_(self.in_proj_weight)
142
+ else:
143
+ xavier_uniform_(self.q_proj_weight)
144
+ xavier_uniform_(self.k_proj_weight)
145
+ xavier_uniform_(self.v_proj_weight)
146
+
147
+ if self.in_proj_bias is not None:
148
+ constant_(self.in_proj_bias, 0.)
149
+ constant_(self.out_proj.bias, 0.)
150
+ if self.bias_k is not None:
151
+ xavier_normal_(self.bias_k)
152
+ if self.bias_v is not None:
153
+ xavier_normal_(self.bias_v)
154
+
155
+ def forward(self, query, key, value, key_padding_mask=None,
156
+ need_weights=True, attn_mask=None):
157
+
158
+ if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
159
+ # return F.multi_head_attention_forward(
160
+ # query, key, value, self.embed_dim, self.num_heads,
161
+ # self.in_proj_weight, self.in_proj_bias,
162
+ # self.bias_k, self.bias_v, self.add_zero_attn,
163
+ # self.dropout, self.out_proj.weight, self.out_proj.bias,
164
+ # training=self.training,
165
+ # key_padding_mask=key_padding_mask, need_weights=need_weights,
166
+ # attn_mask=attn_mask, use_separate_proj_weight=True,
167
+ # q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
168
+ # v_proj_weight=self.v_proj_weight)
169
+
170
+ return multi_head_attention_forward_rpr(
171
+ query, key, value, self.embed_dim, self.num_heads,
172
+ self.in_proj_weight, self.in_proj_bias,
173
+ self.bias_k, self.bias_v, self.add_zero_attn,
174
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
175
+ training=self.training,
176
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
177
+ attn_mask=attn_mask, use_separate_proj_weight=True,
178
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
179
+ v_proj_weight=self.v_proj_weight, rpr_mat=self.Er)
180
+ else:
181
+ if not hasattr(self, '_qkv_same_embed_dim'):
182
+ warnings.warn('A new version of MultiheadAttention module has been implemented. \
183
+ Please re-train your model with the new module',
184
+ UserWarning)
185
+
186
+ # return F.multi_head_attention_forward(
187
+ # query, key, value, self.embed_dim, self.num_heads,
188
+ # self.in_proj_weight, self.in_proj_bias,
189
+ # self.bias_k, self.bias_v, self.add_zero_attn,
190
+ # self.dropout, self.out_proj.weight, self.out_proj.bias,
191
+ # training=self.training,
192
+ # key_padding_mask=key_padding_mask, need_weights=need_weights,
193
+ # attn_mask=attn_mask)
194
+
195
+ return multi_head_attention_forward_rpr(
196
+ query, key, value, self.embed_dim, self.num_heads,
197
+ self.in_proj_weight, self.in_proj_bias,
198
+ self.bias_k, self.bias_v, self.add_zero_attn,
199
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
200
+ training=self.training,
201
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
202
+ attn_mask=attn_mask, rpr_mat=self.Er)
203
+
204
+ # multi_head_attention_forward_rpr
205
+ def multi_head_attention_forward_rpr(query, # type: Tensor
206
+ key, # type: Tensor
207
+ value, # type: Tensor
208
+ embed_dim_to_check, # type: int
209
+ num_heads, # type: int
210
+ in_proj_weight, # type: Tensor
211
+ in_proj_bias, # type: Tensor
212
+ bias_k, # type: Optional[Tensor]
213
+ bias_v, # type: Optional[Tensor]
214
+ add_zero_attn, # type: bool
215
+ dropout_p, # type: float
216
+ out_proj_weight, # type: Tensor
217
+ out_proj_bias, # type: Tensor
218
+ training=True, # type: bool
219
+ key_padding_mask=None, # type: Optional[Tensor]
220
+ need_weights=True, # type: bool
221
+ attn_mask=None, # type: Optional[Tensor]
222
+ use_separate_proj_weight=False, # type: bool
223
+ q_proj_weight=None, # type: Optional[Tensor]
224
+ k_proj_weight=None, # type: Optional[Tensor]
225
+ v_proj_weight=None, # type: Optional[Tensor]
226
+ static_k=None, # type: Optional[Tensor]
227
+ static_v=None, # type: Optional[Tensor]
228
+ rpr_mat=None
229
+ ):
230
+ """
231
+ ----------
232
+ Author: Pytorch
233
+ Modified: Damon Gwinn
234
+ ----------
235
+ For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
236
+ https://pytorch.org/docs/1.2.0/_modules/torch/nn/functional.html
237
+
238
+ Modification to take RPR embedding matrix and perform skew optimized RPR (https://arxiv.org/abs/1809.04281)
239
+ ----------
240
+ """
241
+
242
+ # type: (...) -> Tuple[Tensor, Optional[Tensor]]
243
+
244
+ qkv_same = torch.equal(query, key) and torch.equal(key, value)
245
+ kv_same = torch.equal(key, value)
246
+
247
+ tgt_len, bsz, embed_dim = query.size()
248
+ assert embed_dim == embed_dim_to_check
249
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
250
+ assert key.size() == value.size()
251
+
252
+ head_dim = embed_dim // num_heads
253
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
254
+ scaling = float(head_dim) ** -0.5
255
+
256
+ if use_separate_proj_weight is not True:
257
+ if qkv_same:
258
+ # self-attention
259
+ q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
260
+
261
+ elif kv_same:
262
+ # encoder-decoder attention
263
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
264
+ _b = in_proj_bias
265
+ _start = 0
266
+ _end = embed_dim
267
+ _w = in_proj_weight[_start:_end, :]
268
+ if _b is not None:
269
+ _b = _b[_start:_end]
270
+ q = linear(query, _w, _b)
271
+
272
+ if key is None:
273
+ assert value is None
274
+ k = None
275
+ v = None
276
+ else:
277
+
278
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
279
+ _b = in_proj_bias
280
+ _start = embed_dim
281
+ _end = None
282
+ _w = in_proj_weight[_start:, :]
283
+ if _b is not None:
284
+ _b = _b[_start:]
285
+ k, v = linear(key, _w, _b).chunk(2, dim=-1)
286
+
287
+ else:
288
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
289
+ _b = in_proj_bias
290
+ _start = 0
291
+ _end = embed_dim
292
+ _w = in_proj_weight[_start:_end, :]
293
+ if _b is not None:
294
+ _b = _b[_start:_end]
295
+ q = linear(query, _w, _b)
296
+
297
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
298
+ _b = in_proj_bias
299
+ _start = embed_dim
300
+ _end = embed_dim * 2
301
+ _w = in_proj_weight[_start:_end, :]
302
+ if _b is not None:
303
+ _b = _b[_start:_end]
304
+ k = linear(key, _w, _b)
305
+
306
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
307
+ _b = in_proj_bias
308
+ _start = embed_dim * 2
309
+ _end = None
310
+ _w = in_proj_weight[_start:, :]
311
+ if _b is not None:
312
+ _b = _b[_start:]
313
+ v = linear(value, _w, _b)
314
+ else:
315
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
316
+ len1, len2 = q_proj_weight_non_opt.size()
317
+ assert len1 == embed_dim and len2 == query.size(-1)
318
+
319
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
320
+ len1, len2 = k_proj_weight_non_opt.size()
321
+ assert len1 == embed_dim and len2 == key.size(-1)
322
+
323
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
324
+ len1, len2 = v_proj_weight_non_opt.size()
325
+ assert len1 == embed_dim and len2 == value.size(-1)
326
+
327
+ if in_proj_bias is not None:
328
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
329
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
330
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
331
+ else:
332
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias)
333
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias)
334
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias)
335
+ q = q * scaling
336
+
337
+ if bias_k is not None and bias_v is not None:
338
+ if static_k is None and static_v is None:
339
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
340
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
341
+ if attn_mask is not None:
342
+ attn_mask = torch.cat([attn_mask,
343
+ torch.zeros((attn_mask.size(0), 1),
344
+ dtype=attn_mask.dtype,
345
+ device=attn_mask.device)], dim=1)
346
+ if key_padding_mask is not None:
347
+ key_padding_mask = torch.cat(
348
+ [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
349
+ dtype=key_padding_mask.dtype,
350
+ device=key_padding_mask.device)], dim=1)
351
+ else:
352
+ assert static_k is None, "bias cannot be added to static key."
353
+ assert static_v is None, "bias cannot be added to static value."
354
+ else:
355
+ assert bias_k is None
356
+ assert bias_v is None
357
+
358
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
359
+ if k is not None:
360
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
361
+ if v is not None:
362
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
363
+
364
+ if static_k is not None:
365
+ assert static_k.size(0) == bsz * num_heads
366
+ assert static_k.size(2) == head_dim
367
+ k = static_k
368
+
369
+ if static_v is not None:
370
+ assert static_v.size(0) == bsz * num_heads
371
+ assert static_v.size(2) == head_dim
372
+ v = static_v
373
+
374
+ src_len = k.size(1)
375
+
376
+ if key_padding_mask is not None:
377
+ assert key_padding_mask.size(0) == bsz
378
+ assert key_padding_mask.size(1) == src_len
379
+
380
+ if add_zero_attn:
381
+ src_len += 1
382
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
383
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
384
+ if attn_mask is not None:
385
+ attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1),
386
+ dtype=attn_mask.dtype,
387
+ device=attn_mask.device)], dim=1)
388
+ if key_padding_mask is not None:
389
+ key_padding_mask = torch.cat(
390
+ [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
391
+ dtype=key_padding_mask.dtype,
392
+ device=key_padding_mask.device)], dim=1)
393
+
394
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
395
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
396
+
397
+ ######### ADDITION OF RPR ###########
398
+ if(rpr_mat is not None):
399
+ rpr_mat = _get_valid_embedding(rpr_mat, q.shape[1], k.shape[1])
400
+ qe = torch.einsum("hld,md->hlm", q, rpr_mat)
401
+ srel = _skew(qe)
402
+
403
+ attn_output_weights += srel
404
+
405
+ if attn_mask is not None:
406
+ attn_mask = attn_mask.unsqueeze(0)
407
+ attn_output_weights += attn_mask
408
+
409
+ if key_padding_mask is not None:
410
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
411
+ attn_output_weights = attn_output_weights.masked_fill(
412
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
413
+ float('-inf'),
414
+ )
415
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
416
+
417
+ attn_output_weights = softmax(
418
+ attn_output_weights, dim=-1)
419
+
420
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
421
+
422
+ attn_output = torch.bmm(attn_output_weights, v)
423
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
424
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
425
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
426
+
427
+ if need_weights:
428
+ # average attention weights over heads
429
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
430
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
431
+ else:
432
+ return attn_output, None
433
+
434
+ def _get_valid_embedding(Er, len_q, len_k):
435
+ """
436
+ ----------
437
+ Author: Damon Gwinn
438
+ ----------
439
+ Gets valid embeddings based on max length of RPR attention
440
+ ----------
441
+ """
442
+
443
+ len_e = Er.shape[0]
444
+ start = max(0, len_e - len_q)
445
+ return Er[start:, :]
446
+
447
+ def _skew(qe):
448
+ """
449
+ ----------
450
+ Author: Damon Gwinn
451
+ ----------
452
+ Performs the skew optimized RPR computation (https://arxiv.org/abs/1809.04281)
453
+ ----------
454
+ """
455
+
456
+ sz = qe.shape[1]
457
+ mask = (torch.triu(torch.ones(sz, sz).to(qe.device)) == 1).float().flip(0)
458
+
459
+ qe = mask * qe
460
+ qe = F.pad(qe, (1,0, 0,0, 0,0))
461
+ qe = torch.reshape(qe, (qe.shape[0], qe.shape[2], qe.shape[1]))
462
+
463
+ srel = qe[:, 1:, :]
464
+ return srel
processor.py ADDED
@@ -0,0 +1,266 @@
+ import pretty_midi
+
+
+ RANGE_NOTE_ON = 128
+ RANGE_NOTE_OFF = 128
+ RANGE_VEL = 32
+ RANGE_TIME_SHIFT = 100
+
+ START_IDX = {
+     'note_on': 0,
+     'note_off': RANGE_NOTE_ON,
+     'time_shift': RANGE_NOTE_ON + RANGE_NOTE_OFF,
+     'velocity': RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT
+ }
+
+
+ class SustainAdapter:
+     def __init__(self, time, type):
+         self.start = time
+         self.type = type
+
+
+ class SustainDownManager:
+     def __init__(self, start, end):
+         self.start = start
+         self.end = end
+         self.managed_notes = []
+         self._note_dict = {}  # key: pitch, value: note.start
+
+     def add_managed_note(self, note: pretty_midi.Note):
+         self.managed_notes.append(note)
+
+     def transposition_notes(self):
+         for note in reversed(self.managed_notes):
+             try:
+                 note.end = self._note_dict[note.pitch]
+             except KeyError:
+                 note.end = max(self.end, note.end)
+             self._note_dict[note.pitch] = note.start
+
+
+ # Divides a note into note_on and note_off events
+ class SplitNote:
+     def __init__(self, type, time, value, velocity):
+         ## type: note_on, note_off
+         self.type = type
+         self.time = time
+         self.velocity = velocity
+         self.value = value
+
+     def __repr__(self):
+         return '<[SNote] time: {} type: {}, value: {}, velocity: {}>'\
+             .format(self.time, self.type, self.value, self.velocity)
+
+
+ class Event:
+     def __init__(self, event_type, value):
+         self.type = event_type
+         self.value = value
+
+     def __repr__(self):
+         return '<Event type: {}, value: {}>'.format(self.type, self.value)
+
+     def to_int(self):
+         return START_IDX[self.type] + self.value
+
+     @staticmethod
+     def from_int(int_value):
+         info = Event._type_check(int_value)
+         return Event(info['type'], info['value'])
+
+     @staticmethod
+     def _type_check(int_value):
+         range_note_on = range(0, RANGE_NOTE_ON)
+         range_note_off = range(RANGE_NOTE_ON, RANGE_NOTE_ON+RANGE_NOTE_OFF)
+         range_time_shift = range(RANGE_NOTE_ON+RANGE_NOTE_OFF, RANGE_NOTE_ON+RANGE_NOTE_OFF+RANGE_TIME_SHIFT)
+
+         valid_value = int_value
+
+         if int_value in range_note_on:
+             return {'type': 'note_on', 'value': valid_value}
+         elif int_value in range_note_off:
+             valid_value -= RANGE_NOTE_ON
+             return {'type': 'note_off', 'value': valid_value}
+         elif int_value in range_time_shift:
+             valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF)
+             return {'type': 'time_shift', 'value': valid_value}
+         else:
+             valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)
+             return {'type': 'velocity', 'value': valid_value}
+
+
+ def _divide_note(notes):
+     result_array = []
+     notes.sort(key=lambda x: x.start)
+
+     for note in notes:
+         on = SplitNote('note_on', note.start, note.pitch, note.velocity)
+         off = SplitNote('note_off', note.end, note.pitch, None)
+         result_array += [on, off]
+     return result_array
+
+
+ def _merge_note(snote_sequence):
+     note_on_dict = {}
+     result_array = []
+
+     for snote in snote_sequence:
+         if snote.type == 'note_on':
+             note_on_dict[snote.value] = snote
+         elif snote.type == 'note_off':
+             try:
+                 on = note_on_dict[snote.value]
+                 off = snote
+                 if off.time - on.time == 0:
+                     continue
+                 result = pretty_midi.Note(on.velocity, snote.value, on.time, off.time)
+                 result_array.append(result)
+             except KeyError:
+                 # note_off without a matching note_on
+                 print('info removed pitch: {}'.format(snote.value))
+     return result_array
+
+
+ def _snote2events(snote: SplitNote, prev_vel: int):
+     result = []
+     if snote.velocity is not None:
+         modified_velocity = snote.velocity // 4
+         if prev_vel != modified_velocity:
+             result.append(Event(event_type='velocity', value=modified_velocity))
+     result.append(Event(event_type=snote.type, value=snote.value))
+     return result
+
+
+ def _event_seq2snote_seq(event_sequence):
+     timeline = 0
+     velocity = 0
+     snote_seq = []
+
+     for event in event_sequence:
+         if event.type == 'time_shift':
+             timeline += ((event.value+1) / 100)
+         if event.type == 'velocity':
+             velocity = event.value * 4
+         else:
+             snote = SplitNote(event.type, timeline, event.value, velocity)
+             snote_seq.append(snote)
+     return snote_seq
+
+
+ def _make_time_sift_events(prev_time, post_time):
+     time_interval = int(round((post_time - prev_time) * 100))
+     results = []
+     while time_interval >= RANGE_TIME_SHIFT:
+         results.append(Event(event_type='time_shift', value=RANGE_TIME_SHIFT-1))
+         time_interval -= RANGE_TIME_SHIFT
+     if time_interval == 0:
+         return results
+     else:
+         return results + [Event(event_type='time_shift', value=time_interval-1)]
+
+
+ def _control_preprocess(ctrl_changes):
+     sustains = []
+
+     manager = None
+     for ctrl in ctrl_changes:
+         if ctrl.value >= 64 and manager is None:
+             # sustain down
+             manager = SustainDownManager(start=ctrl.time, end=None)
+         elif ctrl.value < 64 and manager is not None:
+             # sustain up
+             manager.end = ctrl.time
+             sustains.append(manager)
+             manager = None
+         elif ctrl.value < 64 and len(sustains) > 0:
+             sustains[-1].end = ctrl.time
+     return sustains
+
+
+ def _note_preprocess(susteins, notes):
+     note_stream = []
+
+     if susteins:  # if the midi file has sustain controls
+         for sustain in susteins:
+             for note_idx, note in enumerate(notes):
+                 if note.start < sustain.start:
+                     note_stream.append(note)
+                 elif note.start > sustain.end:
+                     notes = notes[note_idx:]
+                     sustain.transposition_notes()
+                     break
+                 else:
+                     sustain.add_managed_note(note)
+
+         for sustain in susteins:
+             note_stream += sustain.managed_notes
+
+     else:  # else, just push everything into the note stream
+         for note_idx, note in enumerate(notes):
+             note_stream.append(note)
+
+     note_stream.sort(key=lambda x: x.start)
+     return note_stream
+
+
+ def encode_midi(file_path):
+     events = []
+     notes = []
+     mid = pretty_midi.PrettyMIDI(midi_file=file_path)
+
+     for inst in mid.instruments:
+         inst_notes = inst.notes
+         # ctrl.number 64 is the sustain pedal. For the control-change numbering,
+         # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
+         ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
+         notes += _note_preprocess(ctrls, inst_notes)
+
+     dnotes = _divide_note(notes)
+
+     dnotes.sort(key=lambda x: x.time)
+     cur_time = 0
+     cur_vel = 0
+     for snote in dnotes:
+         events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time)
+         events += _snote2events(snote=snote, prev_vel=cur_vel)
+
+         cur_time = snote.time
+         cur_vel = snote.velocity
+
+     return [e.to_int() for e in events]
+
+
+ def decode_midi(idx_array, file_path=None):
+     event_sequence = [Event.from_int(idx) for idx in idx_array]
+     snote_seq = _event_seq2snote_seq(event_sequence)
+     note_seq = _merge_note(snote_seq)
+     note_seq.sort(key=lambda x: x.start)
+
+     mid = pretty_midi.PrettyMIDI()
+     # to change the instrument, see https://www.midi.org/specifications/item/gm-level-1-sound-set
+     instrument = pretty_midi.Instrument(0, False, "Composed by Super Piano Music Transformer AI")
+     instrument.notes = note_seq
+
+     mid.instruments.append(instrument)
+     if file_path is not None:
+         mid.write(file_path)
+     return mid
+
+
+ if __name__ == '__main__':
+     encoded = encode_midi('bin/ADIG04.mid')
+     print(encoded)
+     decoded = decode_midi(encoded, file_path='bin/test.mid')
+
+     ins = pretty_midi.PrettyMIDI('bin/ADIG04.mid')
+     print(ins)
+     print(ins.instruments[0])
+     for i in ins.instruments:
+         print(i.control_changes)
+         print(i.notes)
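A worked example of the event vocabulary: a single note with pitch 60, velocity 80, start 0.0 s and end 0.5 s is split into a note_on and a note_off, and encode_midi emits a velocity event of 80 // 4 = 20 (index 356 + 20 = 376), note_on 60 (index 60), a 0.5 s time shift encoded as value 49 (index 256 + 49 = 305), and note_off 60 (index 128 + 60 = 188), i.e. the token sequence [376, 60, 305, 188]. decode_midi maps those integers back through Event.from_int and _merge_note to reconstruct the note.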
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio
+ huggingface_hub
+ pretty_midi
+ setuptools
+ spaces
+ torch
uploaded_midis/am_i_blue_jazz.mid ADDED
Binary file (21.5 kB).
uploaded_midis/ghibli_castle_in_the_sky.mid ADDED
Binary file (2.81 kB).
utilities/argument_funcs.py ADDED
@@ -0,0 +1,228 @@
1
+ import argparse
2
+
3
+ from .constants import SEPERATOR
4
+
5
+ # parse_train_args
6
+ def parse_train_args():
7
+ """
8
+ ----------
9
+ Author: Damon Gwinn
10
+ ----------
11
+ Argparse arguments for training a model
12
+ ----------
13
+ """
14
+
15
+ parser = argparse.ArgumentParser()
16
+
17
+ parser.add_argument("-input_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files")
18
+ parser.add_argument("-output_dir", type=str, default="./saved_models", help="Folder to save model weights. Saves one every epoch")
19
+ parser.add_argument("-weight_modulus", type=int, default=1, help="How often to save epoch weights (ex: value of 10 means save every 10 epochs)")
20
+ parser.add_argument("-print_modulus", type=int, default=1, help="How often to print train results for a batch (batch loss, learn rate, etc.)")
21
+
22
+ parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader")
23
+ parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
24
+ parser.add_argument("--no_tensorboard", action="store_true", help="Turns off tensorboard result reporting")
25
+
26
+ parser.add_argument("-continue_weights", type=str, default=None, help="Model weights to continue training based on")
27
+ parser.add_argument("-continue_epoch", type=int, default=None, help="Epoch the continue_weights model was at")
28
+
29
+ parser.add_argument("-lr", type=float, default=None, help="Constant learn rate. Leave as None for a custom scheduler.")
30
+ parser.add_argument("-ce_smoothing", type=float, default=None, help="Smoothing parameter for smoothed cross entropy loss (defaults to no smoothing)")
31
+ parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use")
32
+ parser.add_argument("-epochs", type=int, default=100, help="Number of epochs to use")
33
+
34
+ parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
35
+ parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
36
+ parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
37
+ parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
38
+ parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
39
+
40
+ parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")
41
+
42
+ parser.add_argument("-dropout", type=float, default=0.1, help="Dropout rate")
43
+
44
+ return parser.parse_args()
45
+
46
+ # print_train_args
47
+ def print_train_args(args):
48
+ """
49
+ ----------
50
+ Author: Damon Gwinn
51
+ ----------
52
+ Prints training arguments
53
+ ----------
54
+ """
55
+
56
+ print(SEPERATOR)
57
+ print("input_dir:", args.input_dir)
58
+ print("output_dir:", args.output_dir)
59
+ print("weight_modulus:", args.weight_modulus)
60
+ print("print_modulus:", args.print_modulus)
61
+ print("")
62
+ print("n_workers:", args.n_workers)
63
+ print("force_cpu:", args.force_cpu)
64
+ print("tensorboard:", not args.no_tensorboard)
65
+ print("")
66
+ print("continue_weights:", args.continue_weights)
67
+ print("continue_epoch:", args.continue_epoch)
68
+ print("")
69
+ print("lr:", args.lr)
70
+ print("ce_smoothing:", args.ce_smoothing)
71
+ print("batch_size:", args.batch_size)
72
+ print("epochs:", args.epochs)
73
+ print("")
74
+ print("rpr:", args.rpr)
75
+ print("max_sequence:", args.max_sequence)
76
+ print("n_layers:", args.n_layers)
77
+ print("num_heads:", args.num_heads)
78
+ print("d_model:", args.d_model)
79
+ print("")
80
+ print("dim_feedforward:", args.dim_feedforward)
81
+ print("dropout:", args.dropout)
82
+ print(SEPERATOR)
83
+ print("")
84
+
85
+ # parse_eval_args
86
+ def parse_eval_args():
87
+ """
88
+ ----------
89
+ Author: Damon Gwinn
90
+ ----------
91
+ Argparse arguments for evaluating a model
92
+ ----------
93
+ """
94
+
95
+ parser = argparse.ArgumentParser()
96
+
97
+ parser.add_argument("-dataset_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files")
98
+ parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()")
99
+ parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader")
100
+ parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
101
+
102
+ parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use")
103
+
104
+ parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
105
+ parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider in the model")
106
+ parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
107
+ parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
108
+ parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
109
+
110
+ parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")
111
+
112
+ return parser.parse_args()
113
+
114
+ # print_eval_args
115
+ def print_eval_args(args):
116
+ """
117
+ ----------
118
+ Author: Damon Gwinn
119
+ ----------
120
+ Prints evaluation arguments
121
+ ----------
122
+ """
123
+
124
+ print(SEPERATOR)
125
+ print("dataset_dir:", args.dataset_dir)
126
+ print("model_weights:", args.model_weights)
127
+ print("n_workers:", args.n_workers)
128
+ print("force_cpu:", args.force_cpu)
129
+ print("")
130
+ print("batch_size:", args.batch_size)
131
+ print("")
132
+ print("rpr:", args.rpr)
133
+ print("max_sequence:", args.max_sequence)
134
+ print("n_layers:", args.n_layers)
135
+ print("num_heads:", args.num_heads)
136
+ print("d_model:", args.d_model)
137
+ print("")
138
+ print("dim_feedforward:", args.dim_feedforward)
139
+ print(SEPERATOR)
140
+ print("")
141
+
142
+ # parse_generate_args
143
+ def parse_generate_args():
144
+ """
145
+ ----------
146
+ Author: Damon Gwinn
147
+ ----------
148
+ Argparse arguments for generation
149
+ ----------
150
+ """
151
+
152
+ parser = argparse.ArgumentParser()
153
+
154
+ parser.add_argument("-midi_root", type=str, default="./dataset/e_piano/", help="Midi file to prime the generator with")
155
+ parser.add_argument("-output_dir", type=str, default="./gen", help="Folder to write generated midi to")
156
+ parser.add_argument("-primer_file", type=str, default=None, help="File path or integer index to the evaluation dataset. Default is to select a random index.")
157
+ parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
158
+
159
+ parser.add_argument("-target_seq_length", type=int, default=1024, help="Target length you'd like the midi to be")
160
+ parser.add_argument("-num_prime", type=int, default=256, help="Amount of messages to prime the generator with")
161
+ parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()")
162
+ parser.add_argument("-beam", type=int, default=0, help="Beam search k. 0 for random probability sample and 1 for greedy")
163
+
164
+ parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
165
+ parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
166
+ parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
167
+ parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
168
+ parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")
169
+
170
+ parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")
171
+
172
+ return parser.parse_args()
173
+
174
+ # print_generate_args
175
+ def print_generate_args(args):
176
+ """
177
+ ----------
178
+ Author: Damon Gwinn
179
+ ----------
180
+ Prints generation arguments
181
+ ----------
182
+ """
183
+
184
+ print(SEPERATOR)
185
+ print("midi_root:", args.midi_root)
186
+ print("output_dir:", args.output_dir)
187
+ print("primer_file:", args.primer_file)
188
+ print("force_cpu:", args.force_cpu)
189
+ print("")
190
+ print("target_seq_length:", args.target_seq_length)
191
+ print("num_prime:", args.num_prime)
192
+ print("model_weights:", args.model_weights)
193
+ print("beam:", args.beam)
194
+ print("")
195
+ print("rpr:", args.rpr)
196
+ print("max_sequence:", args.max_sequence)
197
+ print("n_layers:", args.n_layers)
198
+ print("num_heads:", args.num_heads)
199
+ print("d_model:", args.d_model)
200
+ print("")
201
+ print("dim_feedforward:", args.dim_feedforward)
202
+ print(SEPERATOR)
203
+ print("")
204
+
205
+ # write_model_params
206
+ def write_model_params(args, output_file):
207
+ """
208
+ ----------
209
+ Author: Damon Gwinn
210
+ ----------
211
+ Writes given training parameters to text file
212
+ ----------
213
+ """
214
+
215
+ o_stream = open(output_file, "w")
216
+
217
+ o_stream.write("rpr: " + str(args.rpr) + "\n")
218
+ o_stream.write("lr: " + str(args.lr) + "\n")
219
+ o_stream.write("ce_smoothing: " + str(args.ce_smoothing) + "\n")
220
+ o_stream.write("batch_size: " + str(args.batch_size) + "\n")
221
+ o_stream.write("max_sequence: " + str(args.max_sequence) + "\n")
222
+ o_stream.write("n_layers: " + str(args.n_layers) + "\n")
223
+ o_stream.write("num_heads: " + str(args.num_heads) + "\n")
224
+ o_stream.write("d_model: " + str(args.d_model) + "\n")
225
+ o_stream.write("dim_feedforward: " + str(args.dim_feedforward) + "\n")
226
+ o_stream.write("dropout: " + str(args.dropout) + "\n")
227
+
228
+ o_stream.close()
utilities/constants.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+
+ from processor import RANGE_NOTE_ON, RANGE_NOTE_OFF, RANGE_VEL, RANGE_TIME_SHIFT
+
+ SEPERATOR = "========================="
+
+ # Taken from the paper
+ ADAM_BETA_1 = 0.9
+ ADAM_BETA_2 = 0.98
+ ADAM_EPSILON = 10e-9
+
+ LR_DEFAULT_START = 1.0
+ SCHEDULER_WARMUP_STEPS = 4000
+ # LABEL_SMOOTHING_E = 0.1
+
+ # DROPOUT_P = 0.1
+
+ TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT
+ TOKEN_PAD = TOKEN_END + 1
+
+ VOCAB_SIZE = TOKEN_PAD + 1
+
+ TORCH_FLOAT = torch.float32
+ TORCH_INT = torch.int32
+
+ TORCH_LABEL_TYPE = torch.long
+
+ PREPEND_ZEROS_WIDTH = 4
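With the ranges defined in processor.py (128 note-on, 128 note-off, 32 velocity and 100 time-shift events), TOKEN_END = 128 + 128 + 32 + 100 = 388, TOKEN_PAD = 389, and VOCAB_SIZE = 390, which is the embedding and output dimension used by MusicTransformer.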
utilities/device.py ADDED
@@ -0,0 +1,73 @@
+ # For all things related to devices
+ #### ONLY USE PROVIDED FUNCTIONS, DO NOT USE GLOBAL CONSTANTS ####
+
+ import torch
+ import os
+
+ # change cuda devices to ones that are available after running nvidia-smi.
+ os.environ["CUDA_VISIBLE_DEVICES"] = '3,4,5'
+
+ TORCH_CPU_DEVICE = torch.device("cpu")
+
+ # If we have a GPU available, we'll set our device to GPU. We'll use this device variable later in our code.
+ if(torch.cuda.device_count() > 0):
+     TORCH_CUDA_DEVICE = torch.device("cuda")
+ else:
+     print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
+     print("")
+     TORCH_CUDA_DEVICE = None
+
+ USE_CUDA = True
+
+ # use_cuda
+ def use_cuda(cuda_bool):
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Sets whether to use CUDA (if available), or use the CPU (not recommended)
+     ----------
+     """
+
+     global USE_CUDA
+     USE_CUDA = cuda_bool
+
+ # get_device
+ def get_device():
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Grabs the default device. Default device is CUDA if available and use_cuda is not False, CPU otherwise.
+     ----------
+     """
+
+     if((not USE_CUDA) or (TORCH_CUDA_DEVICE is None)):
+         return TORCH_CPU_DEVICE
+     else:
+         return TORCH_CUDA_DEVICE
+
+ # cuda_device
+ def cuda_device():
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Grabs the cuda device (may be None if CUDA is not available)
+     ----------
+     """
+
+     return TORCH_CUDA_DEVICE
+
+ # cpu_device
+ def cpu_device():
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Grabs the cpu device
+     ----------
+     """
+
+     return TORCH_CPU_DEVICE
utilities/lr_scheduling.py ADDED
@@ -0,0 +1,65 @@
+ # Library Imports
+ import math
+
+ # Using Adam optimizer with
+ # Beta_1=0.9, Beta_2=0.98, and Epsilon=10^-9
+
+ # Learning rate varies over the course of training:
+ # lrate = (1/sqrt(d_model)) * min(1/sqrt(step_num), step_num/(warmup_steps*sqrt(warmup_steps)))
+
+ # LrStepTracker
+ class LrStepTracker:
+     """
+     ----------
+     Author: Ryan Marshall
+     Modified: Damon Gwinn
+     ----------
+     Class for custom learn rate scheduler (to be used by torch.optim.lr_scheduler.LambdaLR).
+
+     Learn rate for each step (batch) given the warmup steps is:
+         lr = [ 1/sqrt(d_model) ] * min[ 1/sqrt(step) , step * (warmup_steps)^-1.5 ]
+
+     This is from Attention is All you Need (https://arxiv.org/abs/1706.03762)
+     ----------
+     """
+
+     def __init__(self, model_dim=512, warmup_steps=4000, init_steps=0):
+         # Store Values
+         self.warmup_steps = warmup_steps
+         self.model_dim = model_dim
+         self.init_steps = init_steps
+
+         # Begin Calculations
+         self.invsqrt_dim = (1 / math.sqrt(model_dim))
+         self.invsqrt_warmup = (1 / (warmup_steps * math.sqrt(warmup_steps)))
+
+     # step
+     def step(self, step):
+         """
+         ----------
+         Author: Ryan Marshall
+         Modified: Damon Gwinn
+         ----------
+         Method to pass to LambdaLR. Increments the step and computes the new learn rate.
+         ----------
+         """
+
+         step += self.init_steps
+         if(step <= self.warmup_steps):
+             return self.invsqrt_dim * self.invsqrt_warmup * step
+         else:
+             invsqrt_step = (1 / math.sqrt(step))
+             return self.invsqrt_dim * invsqrt_step
+
+ # get_lr
+ def get_lr(optimizer):
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Hack to get the current learn rate of the model
+     ----------
+     """
+
+     for param_group in optimizer.param_groups:
+         return param_group['lr']
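A minimal sketch of wiring the tracker into an optimizer with the Adam hyperparameters from utilities/constants.py; `model` is assumed to be a constructed MusicTransformer (no training script is included in this commit):

import torch
from torch.optim.lr_scheduler import LambdaLR

from utilities.constants import ADAM_BETA_1, ADAM_BETA_2, ADAM_EPSILON, LR_DEFAULT_START, SCHEDULER_WARMUP_STEPS
from utilities.lr_scheduling import LrStepTracker, get_lr

opt = torch.optim.Adam(model.parameters(), lr=LR_DEFAULT_START,
                       betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)
lr_tracker = LrStepTracker(model_dim=512, warmup_steps=SCHEDULER_WARMUP_STEPS, init_steps=0)
lr_scheduler = LambdaLR(opt, lr_tracker.step)  # scales the base lr of 1.0 by the tracker's output
print(get_lr(opt))  # current learning rate after LambdaLR's initial step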
utilities/run_model.py ADDED
@@ -0,0 +1,95 @@
+ import torch
+ import time
+
+ from .constants import *
+ from utilities.device import get_device
+ from .lr_scheduling import get_lr
+
+ from dataset.e_piano import compute_epiano_accuracy
+
+
+ # train_epoch
+ def train_epoch(cur_epoch, model, dataloader, loss, opt, lr_scheduler=None, print_modulus=1):
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Trains a single model epoch
+     ----------
+     """
+
+     out = -1
+     model.train()
+     for batch_num, batch in enumerate(dataloader):
+         time_before = time.time()
+
+         opt.zero_grad()
+
+         x = batch[0].to(get_device())
+         tgt = batch[1].to(get_device())
+
+         y = model(x)
+
+         y = y.reshape(y.shape[0] * y.shape[1], -1)
+         tgt = tgt.flatten()
+
+         out = loss.forward(y, tgt)
+
+         out.backward()
+         opt.step()
+
+         if(lr_scheduler is not None):
+             lr_scheduler.step()
+
+         time_after = time.time()
+         time_took = time_after - time_before
+
+         if((batch_num+1) % print_modulus == 0):
+             print(SEPERATOR)
+             print("Epoch", cur_epoch, " Batch", batch_num+1, "/", len(dataloader))
+             print("LR:", get_lr(opt))
+             print("Train loss:", float(out))
+             print("")
+             print("Time (s):", time_took)
+             print(SEPERATOR)
+             print("")
+
+     return
+
+ # eval_model
+ def eval_model(model, dataloader, loss):
+     """
+     ----------
+     Author: Damon Gwinn
+     ----------
+     Evaluates the model and returns the average loss and accuracy
+     ----------
+     """
+
+     model.eval()
+
+     avg_acc = -1
+     avg_loss = -1
+     with torch.set_grad_enabled(False):
+         n_test = len(dataloader)
+         sum_loss = 0.0
+         sum_acc = 0.0
+         for batch in dataloader:
+             x = batch[0].to(get_device())
+             tgt = batch[1].to(get_device())
+
+             y = model(x)
+
+             sum_acc += float(compute_epiano_accuracy(y, tgt))
+
+             y = y.reshape(y.shape[0] * y.shape[1], -1)
+             tgt = tgt.flatten()
+
+             out = loss.forward(y, tgt)
+
+             sum_loss += float(out)
+
+         avg_loss = sum_loss / n_test
+         avg_acc = sum_acc / n_test
+
+     return avg_loss, avg_acc