feat: add lofi-bytes-api and gradio app
- app.py +157 -0
- model/loss.py +46 -0
- model/music_transformer.py +200 -0
- model/positional_encoding.py +23 -0
- model/rpr.py +464 -0
- processor.py +266 -0
- requirements.txt +6 -0
- uploaded_midis/am_i_blue_jazz.mid +0 -0
- uploaded_midis/ghibli_castle_in_the_sky.mid +0 -0
- utilities/argument_funcs.py +228 -0
- utilities/constants.py +28 -0
- utilities/device.py +73 -0
- utilities/lr_scheduling.py +65 -0
- utilities/run_model.py +95 -0
app.py
ADDED
@@ -0,0 +1,157 @@
import random

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from model.music_transformer import MusicTransformer
from processor import decode_midi, encode_midi
from utilities.constants import TOKEN_END, TOKEN_PAD, TORCH_LABEL_TYPE
from utilities.device import get_device, use_cuda

REPO_ID = "Launchpad/lofi-bytes"
FILENAME = "weights_maestro_finetuned.pickle"

SEQUENCE_START = 0
OUTPUT_PATH = "./output_midi"
RPR = True
# TARGET_SEQ_LENGTH = 1023
TARGET_SEQ_LENGTH = 512
NUM_PRIME = 65
MAX_SEQUENCE = 2048
N_LAYERS = 6
NUM_HEADS = 8
D_MODEL = 512
DIM_FEEDFORWARD = 1024
BEAM = 0
FORCE_CPU = False
ALLOWED_EXTENSIONS = {'mid'}
UPLOAD_FOLDER = './uploaded_midis'

generated_midi = None

use_cuda(True)

model = MusicTransformer(
    n_layers=N_LAYERS,
    num_heads=NUM_HEADS,
    d_model=D_MODEL,
    dim_feedforward=DIM_FEEDFORWARD,
    max_sequence=MAX_SEQUENCE,
    rpr=RPR
).to(get_device())

state_dict = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
    map_location=get_device()
)

model.load_state_dict(state_dict)

def generate(input_midi):

    raw_mid = encode_midi(input_midi)
    if(len(raw_mid) == 0):
        return

    primer, _ = process_midi(raw_mid, NUM_PRIME, random_seq=False)
    primer = torch.tensor(primer, dtype=TORCH_LABEL_TYPE, device=get_device())

    # saves a pretty_midi at file_path
    # decode_midi(primer[:NUM_PRIME].cpu().numpy(), file_path=f_path)
    decode_midi(primer[:NUM_PRIME].cpu().numpy())

    # GENERATION
    model.eval()
    with torch.set_grad_enabled(False):

        # NOTE: model.generate() returns a MIDI stored as an ARRAY given a primer
        beam_seq = model.generate(primer[:NUM_PRIME], TARGET_SEQ_LENGTH, beam=BEAM)

        file_path = "output.mid"

        # NOTE: function decode_midi() returns an actual MIDI of class pretty_midi.PrettyMIDI
        decoded_midi = decode_midi(beam_seq[0].cpu().numpy(), file_path=file_path)

        # THIS SHOULD BE EITHER decoded_midi OR beam_seq
        # TODO: decoded_midi is actual pretty_midi MIDI file, beam_seq is just an array representing a MIDI
        # decoded_midi stores more information about instruments and stuff
        return file_path

def process_midi(raw_mid, max_seq, random_seq):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Takes in pre-processed raw midi and returns the input and target. Can use a random sequence or
    go from the start based on random_seq.
    ----------
    """

    x = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
    tgt = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())

    raw_len = len(raw_mid)
    full_seq = max_seq + 1  # Performing seq2seq

    if(raw_len == 0):
        return x, tgt

    if(raw_len < full_seq):
        x[:raw_len] = raw_mid
        tgt[:raw_len-1] = raw_mid[1:]
        tgt[raw_len] = TOKEN_END
    else:
        # Randomly selecting a range
        if(random_seq):
            end_range = raw_len - full_seq
            start = random.randint(SEQUENCE_START, end_range)

        # Always taking from the start to as far as we can
        else:
            start = SEQUENCE_START

        end = start + full_seq

        data = raw_mid[start:end]

        x = data[:max_seq]
        tgt = data[1:full_seq]

    return x, tgt


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image(
                "https://www.ocf.berkeley.edu/~launchpad/media/uploads/project_logos/410912267_278779401866686_2517511436172822307_n_0iVwDxI.png",
                elem_id="logo-img",
                show_label=False,
                show_share_button=False,
                show_download_button=False,
                show_fullscreen_button=False,
            )

        with gr.Column(scale=3):
            gr.Markdown("""lofi-bytes is a [Launchpad](https://launchpad.studentorg.berkeley.edu/) project (Spring 2023) that generates lofi tracks from input MIDI samples using a MusicTransformer model.
                        <br/><br/>
                        **Model**: [lofi-bytes](https://huggingface.co/Launchpad/lofi-bytes)
                        <br/>
                        **Project Leader**: Alicia Wang
                        <br/>
                        **Members**: Alena Chao, Eric Liu, Zane Mogannam, Chloe Wong, Iris Zhou
                        <br/>
                        **Advisors**: Vincent Lim, Winston Liu
                        <br/>
                        """
            )
    gr.Interface(
        fn=generate,
        inputs=gr.File(),
        outputs=gr.File(),
        examples=["uploaded_midis/ghibli_castle_in_the_sky.mid", "uploaded_midis/am_i_blue_jazz.mid"]
    )

if __name__ == '__main__':
    demo.launch(share=True)
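As a usage sketch (not part of the diff): the generate() function above can also be exercised outside the Gradio UI, assuming the module-level weight download and the utilities behave as in app.py. The input path below is just one of the example MIDIs bundled with the Space.

# Minimal local smoke test for generate(); reuses the module-level model loaded above.
out_path = generate("uploaded_midis/ghibli_castle_in_the_sky.mid")  # path to a primer MIDI
print("generated MIDI written to:", out_path)                       # "output.mid" on success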
model/loss.py
ADDED
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

# Borrowed from https://github.com/jason9693/MusicTransformer-pytorch/blob/5f183374833ff6b7e17f3a24e3594dedd93a5fe5/custom/criterion.py#L28
class SmoothCrossEntropyLoss(_Loss):
    """
    https://arxiv.org/abs/1512.00567
    """
    __constants__ = ['label_smoothing', 'vocab_size', 'ignore_index', 'reduction']

    def __init__(self, label_smoothing, vocab_size, ignore_index=-100, reduction='mean', is_logits=True):
        assert 0.0 <= label_smoothing <= 1.0
        super().__init__(reduction=reduction)

        self.label_smoothing = label_smoothing
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.input_is_logits = is_logits

    def forward(self, input, target):
        """
        Args:
            input: [B * T, V]
            target: [B * T]
        Returns:
            cross entropy: [1]
        """
        mask = (target == self.ignore_index).unsqueeze(-1)
        q = F.one_hot(target.long(), self.vocab_size).type(torch.float32)
        u = 1.0 / self.vocab_size
        q_prime = (1.0 - self.label_smoothing) * q + self.label_smoothing * u
        q_prime = q_prime.masked_fill(mask, 0)

        ce = self.cross_entropy_with_logits(q_prime, input)
        if self.reduction == 'mean':
            lengths = torch.sum(target != self.ignore_index)
            return ce.sum() / lengths
        elif self.reduction == 'sum':
            return ce.sum()
        else:
            raise NotImplementedError

    def cross_entropy_with_logits(self, p, q):
        return -torch.sum(p * (q - q.logsumexp(dim=-1, keepdim=True)), dim=-1)
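A short, self-contained sketch of how this criterion is typically driven during training (not in the diff; the vocabulary size and pad id below are illustrative, the real values come from utilities.constants and would be passed as vocab_size and ignore_index):

import torch
from model.loss import SmoothCrossEntropyLoss

VOCAB = 390      # illustrative vocab size (placeholder for VOCAB_SIZE)
PAD = VOCAB - 1  # illustrative pad token used as ignore_index (placeholder for TOKEN_PAD)

criterion = SmoothCrossEntropyLoss(label_smoothing=0.1, vocab_size=VOCAB, ignore_index=PAD)

logits = torch.randn(8 * 16, VOCAB)                # [B * T, V] unnormalized scores
targets = torch.randint(0, VOCAB - 1, (8 * 16,))   # [B * T] token ids
loss = criterion(logits, targets)                  # scalar, averaged over non-pad positions
print(loss.item())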
model/music_transformer.py
ADDED
@@ -0,0 +1,200 @@
import torch
import torch.nn as nn
from torch.nn.modules.normalization import LayerNorm
import random

from utilities.constants import *
from utilities.device import get_device

from .positional_encoding import PositionalEncoding
from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR


# MusicTransformer
class MusicTransformer(nn.Module):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Music Transformer reproduction from https://arxiv.org/abs/1809.04281. Arguments allow for
    tweaking the transformer architecture (https://arxiv.org/abs/1706.03762) and the rpr argument
    toggles Relative Position Representations (RPR - https://arxiv.org/abs/1803.02155).

    Supports training and generation using Pytorch's nn.Transformer class with dummy decoder to
    make a decoder-only transformer architecture

    For RPR support, there is modified Pytorch 1.2.0 code in rpr.py. Modified source will be
    kept up to date with Pytorch revisions only as necessary.
    ----------
    """

    def __init__(self, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024,
                 dropout=0.1, max_sequence=2048, rpr=False):
        super(MusicTransformer, self).__init__()

        self.dummy = DummyDecoder()

        self.nlayers = n_layers
        self.nhead = num_heads
        self.d_model = d_model
        self.d_ff = dim_feedforward
        self.dropout = dropout
        self.max_seq = max_sequence
        self.rpr = rpr

        # Input embedding
        self.embedding = nn.Embedding(VOCAB_SIZE, self.d_model)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq)

        # Base transformer
        if(not self.rpr):
            # To make a decoder-only transformer we need to use masked encoder layers
            # Dummy decoder to essentially just return the encoder output
            self.transformer = nn.Transformer(
                d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
                num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ,
                dim_feedforward=self.d_ff, custom_decoder=self.dummy
            )
        # RPR Transformer
        else:
            encoder_norm = LayerNorm(self.d_model)
            encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout, er_len=self.max_seq)
            encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm)
            self.transformer = nn.Transformer(
                d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
                num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ,
                dim_feedforward=self.d_ff, custom_decoder=self.dummy, custom_encoder=encoder
            )

        # Final output is a softmaxed linear layer
        self.Wout = nn.Linear(self.d_model, VOCAB_SIZE)
        self.softmax = nn.Softmax(dim=-1)

    # forward
    def forward(self, x, mask=True):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Takes an input sequence and outputs predictions using a sequence to sequence method.

        A prediction at one index is the "next" prediction given all information seen previously.
        ----------
        """

        if(mask is True):
            mask = self.transformer.generate_square_subsequent_mask(x.shape[1]).to(get_device())
        else:
            mask = None

        x = self.embedding(x)

        # Input shape is (max_seq, batch_size, d_model)
        x = x.permute(1,0,2)

        x = self.positional_encoding(x)

        # Since there are no true decoder layers, the tgt is unused
        # Pytorch wants src and tgt to have some equal dims however
        x_out = self.transformer(src=x, tgt=x, src_mask=mask)

        # Back to (batch_size, max_seq, d_model)
        x_out = x_out.permute(1,0,2)

        y = self.Wout(x_out)
        # y = self.softmax(y)

        del mask

        # They are trained to predict the next note in sequence (we don't need the last one)
        return y

    # generate
    def generate(self, primer=None, target_seq_length=1024, beam=0, beam_chance=1.0):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Generates midi given a primer sample. Music can be generated using a probability distribution over
        the softmax probabilities (recommended) or by using a beam search.
        ----------
        """

        assert (not self.training), "Cannot generate while in training mode"

        print("Generating sequence of max length:", target_seq_length)

        gen_seq = torch.full((1,target_seq_length), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())

        num_primer = len(primer)
        gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(get_device())


        # print("primer:",primer)
        # print(gen_seq)
        cur_i = num_primer
        while(cur_i < target_seq_length):
            # gen_seq_batch = gen_seq.clone()
            y = self.softmax(self.forward(gen_seq[..., :cur_i]))[..., :TOKEN_END]
            token_probs = y[:, cur_i-1, :]

            if(beam == 0):
                beam_ran = 2.0
            else:
                beam_ran = random.uniform(0,1)

            if(beam_ran <= beam_chance):
                token_probs = token_probs.flatten()
                top_res, top_i = torch.topk(token_probs, beam)

                beam_rows = top_i // VOCAB_SIZE
                beam_cols = top_i % VOCAB_SIZE

                gen_seq = gen_seq[beam_rows, :]
                gen_seq[..., cur_i] = beam_cols

            else:
                distrib = torch.distributions.categorical.Categorical(probs=token_probs)
                next_token = distrib.sample()
                # print("next token:",next_token)
                gen_seq[:, cur_i] = next_token


                # Let the transformer decide to end if it wants to
                if(next_token == TOKEN_END):
                    print("Model called end of sequence at:", cur_i, "/", target_seq_length)
                    break

            cur_i += 1
            if(cur_i % 50 == 0):
                print(cur_i, "/", target_seq_length)

        return gen_seq[:, :cur_i]

# Used as a dummy to nn.Transformer
# DummyDecoder
class DummyDecoder(nn.Module):
    """
    ----------
    Author: Damon Gwinn
    ----------
    A dummy decoder that returns its input. Used to make the Pytorch transformer into a decoder-only
    architecture (stacked encoders with dummy decoder fits the bill)
    ----------
    """

    def __init__(self):
        super(DummyDecoder, self).__init__()

    def forward(self, tgt, memory, tgt_mask, memory_mask,tgt_key_padding_mask,memory_key_padding_mask, **kwargs):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Returns the input (memory)
        ----------
        """

        return memory
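A generation sketch (not in the diff), assuming VOCAB_SIZE, TOKEN_PAD, TOKEN_END and TORCH_LABEL_TYPE are defined in utilities.constants as the app expects; the small layer sizes and the random primer are only for illustration, a real primer comes from processor.encode_midi:

import torch
from model.music_transformer import MusicTransformer
from utilities.constants import TORCH_LABEL_TYPE
from utilities.device import get_device

model = MusicTransformer(n_layers=2, num_heads=4, d_model=128, dim_feedforward=256,
                         max_sequence=512, rpr=True).to(get_device())
model.eval()  # generate() asserts the model is not in training mode

# Random token ids stand in for an encoded MIDI primer.
primer = torch.randint(0, 100, (64,), dtype=TORCH_LABEL_TYPE, device=get_device())
with torch.no_grad():
    seq = model.generate(primer, target_seq_length=256, beam=0)  # beam=0 -> sample from softmax
print(seq.shape)  # (1, <=256) generated token ids, decodable with processor.decode_midi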
model/positional_encoding.py
ADDED
@@ -0,0 +1,23 @@
import torch
import torch.nn as nn
import math

# PositionalEncoding
# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
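A minimal shape check (not in the diff) showing the (seq_len, batch, d_model) convention this module expects, matching the permute in MusicTransformer.forward:

import torch
from model.positional_encoding import PositionalEncoding

pe = PositionalEncoding(d_model=512, dropout=0.1, max_len=2048)
x = torch.zeros(128, 4, 512)   # (seq_len, batch, d_model)
out = pe(x)                    # sinusoidal encoding added, then dropout
print(out.shape)               # torch.Size([128, 4, 512])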
model/rpr.py
ADDED
@@ -0,0 +1,464 @@
import warnings  # used by the legacy-module warning in MultiheadAttentionRPR.forward
import torch
import torch.nn as nn

from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch.nn.modules.transformer import _get_clones
from torch.nn.modules.linear import Linear
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.normalization import LayerNorm
from torch.nn.init import *

from torch.nn.functional import linear, softmax, dropout

# TransformerEncoderRPR
class TransformerEncoderRPR(Module):
    """
    ----------
    Author: Pytorch
    ----------
    For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
    https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoder

    No modification. Copied here to ensure continued compatibility with other edits.
    ----------
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoderRPR, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, **kwargs):

        output = src

        for i in range(self.num_layers):
            output = self.layers[i](output, src_mask=mask,
                                    src_key_padding_mask=src_key_padding_mask)

        if self.norm:
            output = self.norm(output)

        return output

# TransformerEncoderLayerRPR
class TransformerEncoderLayerRPR(Module):
    """
    ----------
    Author: Pytorch
    Modified: Damon Gwinn
    ----------
    For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
    https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer

    Modification to create and call custom MultiheadAttentionRPR
    ----------
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
        super(TransformerEncoderLayerRPR, self).__init__()
        self.self_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

# MultiheadAttentionRPR
class MultiheadAttentionRPR(Module):
    """
    ----------
    Author: Pytorch
    Modified: Damon Gwinn
    ----------
    For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
    https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/activation.html#MultiheadAttention

    Modification to add RPR embedding Er and call custom multi_head_attention_forward_rpr
    ----------
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, er_len=None):
        super(MultiheadAttentionRPR, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        # Adding RPR embedding matrix
        if(er_len is not None):
            self.Er = Parameter(torch.rand((er_len, self.head_dim), dtype=torch.float32))
        else:
            self.Er = None

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):

        if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
            # return F.multi_head_attention_forward(
            #     query, key, value, self.embed_dim, self.num_heads,
            #     self.in_proj_weight, self.in_proj_bias,
            #     self.bias_k, self.bias_v, self.add_zero_attn,
            #     self.dropout, self.out_proj.weight, self.out_proj.bias,
            #     training=self.training,
            #     key_padding_mask=key_padding_mask, need_weights=need_weights,
            #     attn_mask=attn_mask, use_separate_proj_weight=True,
            #     q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
            #     v_proj_weight=self.v_proj_weight)

            return multi_head_attention_forward_rpr(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight, rpr_mat=self.Er)
        else:
            if not hasattr(self, '_qkv_same_embed_dim'):
                warnings.warn('A new version of MultiheadAttention module has been implemented. \
                    Please re-train your model with the new module',
                    UserWarning)

            # return F.multi_head_attention_forward(
            #     query, key, value, self.embed_dim, self.num_heads,
            #     self.in_proj_weight, self.in_proj_bias,
            #     self.bias_k, self.bias_v, self.add_zero_attn,
            #     self.dropout, self.out_proj.weight, self.out_proj.bias,
            #     training=self.training,
            #     key_padding_mask=key_padding_mask, need_weights=need_weights,
            #     attn_mask=attn_mask)

            return multi_head_attention_forward_rpr(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, rpr_mat=self.Er)

# multi_head_attention_forward_rpr
def multi_head_attention_forward_rpr(query,                          # type: Tensor
                                     key,                            # type: Tensor
                                     value,                          # type: Tensor
                                     embed_dim_to_check,             # type: int
                                     num_heads,                      # type: int
                                     in_proj_weight,                 # type: Tensor
                                     in_proj_bias,                   # type: Tensor
                                     bias_k,                         # type: Optional[Tensor]
                                     bias_v,                         # type: Optional[Tensor]
                                     add_zero_attn,                  # type: bool
                                     dropout_p,                      # type: float
                                     out_proj_weight,                # type: Tensor
                                     out_proj_bias,                  # type: Tensor
                                     training=True,                  # type: bool
                                     key_padding_mask=None,          # type: Optional[Tensor]
                                     need_weights=True,              # type: bool
                                     attn_mask=None,                 # type: Optional[Tensor]
                                     use_separate_proj_weight=False, # type: bool
                                     q_proj_weight=None,             # type: Optional[Tensor]
                                     k_proj_weight=None,             # type: Optional[Tensor]
                                     v_proj_weight=None,             # type: Optional[Tensor]
                                     static_k=None,                  # type: Optional[Tensor]
                                     static_v=None,                  # type: Optional[Tensor]
                                     rpr_mat=None
                                     ):
    """
    ----------
    Author: Pytorch
    Modified: Damon Gwinn
    ----------
    For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
    https://pytorch.org/docs/1.2.0/_modules/torch/nn/functional.html

    Modification to take RPR embedding matrix and perform skew optimized RPR (https://arxiv.org/abs/1809.04281)
    ----------
    """

    # type: (...) -> Tuple[Tensor, Optional[Tensor]]

    qkv_same = torch.equal(query, key) and torch.equal(key, value)
    kv_same = torch.equal(key, value)

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    assert list(query.size()) == [tgt_len, bsz, embed_dim]
    assert key.size() == value.size()

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if use_separate_proj_weight is not True:
        if qkv_same:
            # self-attention
            q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif kv_same:
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:

                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask,
                                       torch.zeros((attn_mask.size(0), 1),
                                                   dtype=attn_mask.dtype,
                                                   device=attn_mask.device)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
                                                   dtype=key_padding_mask.dtype,
                                                   device=key_padding_mask.device)], dim=1)
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1),
                                                          dtype=attn_mask.dtype,
                                                          device=attn_mask.device)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
                                               dtype=key_padding_mask.dtype,
                                               device=key_padding_mask.device)], dim=1)

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    ######### ADDITION OF RPR ###########
    if(rpr_mat is not None):
        rpr_mat = _get_valid_embedding(rpr_mat, q.shape[1], k.shape[1])
        qe = torch.einsum("hld,md->hlm", q, rpr_mat)
        srel = _skew(qe)

        attn_output_weights += srel

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(
        attn_output_weights, dim=-1)

    attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None

def _get_valid_embedding(Er, len_q, len_k):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Gets valid embeddings based on max length of RPR attention
    ----------
    """

    len_e = Er.shape[0]
    start = max(0, len_e - len_q)
    return Er[start:, :]

def _skew(qe):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Performs the skew optimized RPR computation (https://arxiv.org/abs/1809.04281)
    ----------
    """

    sz = qe.shape[1]
    mask = (torch.triu(torch.ones(sz, sz).to(qe.device)) == 1).float().flip(0)

    qe = mask * qe
    qe = F.pad(qe, (1,0, 0,0, 0,0))
    qe = torch.reshape(qe, (qe.shape[0], qe.shape[2], qe.shape[1]))

    srel = qe[:, 1:, :]
    return srel
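As a dimension sketch (not in the diff): the skew trick above turns the product of the per-head queries with the relative-position embedding table into a relative-logit matrix that lines up with the (heads, len_q, len_k) attention scores. The sizes below are arbitrary and the helpers are called directly only to illustrate shapes:

import torch
from model.rpr import _get_valid_embedding, _skew

heads, seq_len, head_dim = 8, 64, 64
q = torch.randn(heads, seq_len, head_dim)    # per-head queries, (h, L, d)
Er = torch.rand(2048, head_dim)              # full RPR embedding table (er_len, d)

er = _get_valid_embedding(Er, q.shape[1], q.shape[1])  # keep the last L rows of Er
qe = torch.einsum("hld,md->hlm", q, er)                # (h, L, L) query-vs-relative-position logits
srel = _skew(qe)                                       # (h, L, L), added onto the attention logits
print(er.shape, qe.shape, srel.shape)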
processor.py
ADDED
@@ -0,0 +1,266 @@
import pretty_midi


RANGE_NOTE_ON = 128
RANGE_NOTE_OFF = 128
RANGE_VEL = 32
RANGE_TIME_SHIFT = 100

START_IDX = {
    'note_on': 0,
    'note_off': RANGE_NOTE_ON,
    'time_shift': RANGE_NOTE_ON + RANGE_NOTE_OFF,
    'velocity': RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT
}


class SustainAdapter:
    def __init__(self, time, type):
        self.start = time
        self.type = type


class SustainDownManager:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.managed_notes = []
        self._note_dict = {}  # key: pitch, value: note.start

    def add_managed_note(self, note: pretty_midi.Note):
        self.managed_notes.append(note)

    def transposition_notes(self):
        for note in reversed(self.managed_notes):
            try:
                note.end = self._note_dict[note.pitch]
            except KeyError:
                note.end = max(self.end, note.end)
            self._note_dict[note.pitch] = note.start


# Divided note by note_on, note_off
class SplitNote:
    def __init__(self, type, time, value, velocity):
        ## type: note_on, note_off
        self.type = type
        self.time = time
        self.velocity = velocity
        self.value = value

    def __repr__(self):
        return '<[SNote] time: {} type: {}, value: {}, velocity: {}>'\
            .format(self.time, self.type, self.value, self.velocity)


class Event:
    def __init__(self, event_type, value):
        self.type = event_type
        self.value = value

    def __repr__(self):
        return '<Event type: {}, value: {}>'.format(self.type, self.value)

    def to_int(self):
        return START_IDX[self.type] + self.value

    @staticmethod
    def from_int(int_value):
        info = Event._type_check(int_value)
        return Event(info['type'], info['value'])

    @staticmethod
    def _type_check(int_value):
        range_note_on = range(0, RANGE_NOTE_ON)
        range_note_off = range(RANGE_NOTE_ON, RANGE_NOTE_ON+RANGE_NOTE_OFF)
        range_time_shift = range(RANGE_NOTE_ON+RANGE_NOTE_OFF,RANGE_NOTE_ON+RANGE_NOTE_OFF+RANGE_TIME_SHIFT)

        valid_value = int_value

        if int_value in range_note_on:
            return {'type': 'note_on', 'value': valid_value}
        elif int_value in range_note_off:
            valid_value -= RANGE_NOTE_ON
            return {'type': 'note_off', 'value': valid_value}
        elif int_value in range_time_shift:
            valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF)
            return {'type': 'time_shift', 'value': valid_value}
        else:
            valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)
            return {'type': 'velocity', 'value': valid_value}


def _divide_note(notes):
    result_array = []
    notes.sort(key=lambda x: x.start)

    for note in notes:
        on = SplitNote('note_on', note.start, note.pitch, note.velocity)
        off = SplitNote('note_off', note.end, note.pitch, None)
        result_array += [on, off]
    return result_array


def _merge_note(snote_sequence):
    note_on_dict = {}
    result_array = []

    for snote in snote_sequence:
        # print(note_on_dict)
        if snote.type == 'note_on':
            note_on_dict[snote.value] = snote
        elif snote.type == 'note_off':
            try:
                on = note_on_dict[snote.value]
                off = snote
                if off.time - on.time == 0:
                    continue
                result = pretty_midi.Note(on.velocity, snote.value, on.time, off.time)
                result_array.append(result)
            except:
                print('info removed pitch: {}'.format(snote.value))
    return result_array


def _snote2events(snote: SplitNote, prev_vel: int):
    result = []
    if snote.velocity is not None:
        modified_velocity = snote.velocity // 4
        if prev_vel != modified_velocity:
            result.append(Event(event_type='velocity', value=modified_velocity))
    result.append(Event(event_type=snote.type, value=snote.value))
    return result


def _event_seq2snote_seq(event_sequence):
    timeline = 0
    velocity = 0
    snote_seq = []

    for event in event_sequence:
        if event.type == 'time_shift':
            timeline += ((event.value+1) / 100)
        if event.type == 'velocity':
            velocity = event.value * 4
        else:
            snote = SplitNote(event.type, timeline, event.value, velocity)
            snote_seq.append(snote)
    return snote_seq


def _make_time_sift_events(prev_time, post_time):
    time_interval = int(round((post_time - prev_time) * 100))
    results = []
    while time_interval >= RANGE_TIME_SHIFT:
        results.append(Event(event_type='time_shift', value=RANGE_TIME_SHIFT-1))
        time_interval -= RANGE_TIME_SHIFT
    if time_interval == 0:
        return results
    else:
        return results + [Event(event_type='time_shift', value=time_interval-1)]


def _control_preprocess(ctrl_changes):
    sustains = []

    manager = None
    for ctrl in ctrl_changes:
        if ctrl.value >= 64 and manager is None:
            # sustain down
            manager = SustainDownManager(start=ctrl.time, end=None)
        elif ctrl.value < 64 and manager is not None:
            # sustain up
            manager.end = ctrl.time
            sustains.append(manager)
            manager = None
        elif ctrl.value < 64 and len(sustains) > 0:
            sustains[-1].end = ctrl.time
    return sustains


def _note_preprocess(susteins, notes):
    note_stream = []

    if susteins:  # if the midi file has sustain controls
        for sustain in susteins:
            for note_idx, note in enumerate(notes):
                if note.start < sustain.start:
                    note_stream.append(note)
                elif note.start > sustain.end:
                    notes = notes[note_idx:]
                    sustain.transposition_notes()
                    break
                else:
                    sustain.add_managed_note(note)

        for sustain in susteins:
            note_stream += sustain.managed_notes

    else:  # else, just push everything into note stream
        for note_idx, note in enumerate(notes):
            note_stream.append(note)

    note_stream.sort(key= lambda x: x.start)
    return note_stream


def encode_midi(file_path):
    events = []
    notes = []
    mid = pretty_midi.PrettyMIDI(midi_file=file_path)

    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number is the number of sustain control. If you want to know about the number type of control,
        # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)

    dnotes = _divide_note(notes)

    # print(dnotes)
    dnotes.sort(key=lambda x: x.time)
    # print('sorted:')
    # print(dnotes)
    cur_time = 0
    cur_vel = 0
    for snote in dnotes:
        events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time)
        events += _snote2events(snote=snote, prev_vel=cur_vel)
        # events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time)

        cur_time = snote.time
        cur_vel = snote.velocity

    return [e.to_int() for e in events]


def decode_midi(idx_array, file_path=None):
    event_sequence = [Event.from_int(idx) for idx in idx_array]
    # print(event_sequence)
    snote_seq = _event_seq2snote_seq(event_sequence)
    note_seq = _merge_note(snote_seq)
    note_seq.sort(key=lambda x:x.start)

    mid = pretty_midi.PrettyMIDI()
    # if want to change instument, see https://www.midi.org/specifications/item/gm-level-1-sound-set
    instument = pretty_midi.Instrument(0, False, "Composed by Super Piano Music Transformer AI")
    instument.notes = note_seq

    mid.instruments.append(instument)
    if file_path is not None:
        mid.write(file_path)
    return mid


if __name__ == '__main__':
    encoded = encode_midi('bin/ADIG04.mid')
    print(encoded)
    decided = decode_midi(encoded,file_path='bin/test.mid')

    ins = pretty_midi.PrettyMIDI('bin/ADIG04.mid')
    print(ins)
    print(ins.instruments[0])
    for i in ins.instruments:
        print(i.control_changes)
        print(i.notes)
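A minimal round-trip sketch (not in the diff), using one of the example MIDIs bundled with the Space as the input path and an arbitrary output filename:

from processor import encode_midi, decode_midi

tokens = encode_midi("uploaded_midis/am_i_blue_jazz.mid")  # list of event ids in [0, 388)
print(len(tokens), min(tokens), max(tokens))

midi = decode_midi(tokens, file_path="roundtrip.mid")       # pretty_midi.PrettyMIDI, also written to disk
print(midi.instruments[0].notes[:3])                        # first few reconstructed notes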
requirements.txt
ADDED
@@ -0,0 +1,6 @@
gradio
huggingface_hub
pretty_midi
setuptools
spaces
torch
uploaded_midis/am_i_blue_jazz.mid
ADDED
Binary file (21.5 kB)
uploaded_midis/ghibli_castle_in_the_sky.mid
ADDED
Binary file (2.81 kB)
utilities/argument_funcs.py
ADDED
@@ -0,0 +1,228 @@
import argparse

from .constants import SEPERATOR

# parse_train_args
def parse_train_args():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Argparse arguments for training a model
    ----------
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("-input_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files")
    parser.add_argument("-output_dir", type=str, default="./saved_models", help="Folder to save model weights. Saves one every epoch")
    parser.add_argument("-weight_modulus", type=int, default=1, help="How often to save epoch weights (ex: value of 10 means save every 10 epochs)")
    parser.add_argument("-print_modulus", type=int, default=1, help="How often to print train results for a batch (batch loss, learn rate, etc.)")

    parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader")
    parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")
    parser.add_argument("--no_tensorboard", action="store_true", help="Turns off tensorboard result reporting")

    parser.add_argument("-continue_weights", type=str, default=None, help="Model weights to continue training based on")
    parser.add_argument("-continue_epoch", type=int, default=None, help="Epoch the continue_weights model was at")

    parser.add_argument("-lr", type=float, default=None, help="Constant learn rate. Leave as None for a custom scheduler.")
    parser.add_argument("-ce_smoothing", type=float, default=None, help="Smoothing parameter for smoothed cross entropy loss (defaults to no smoothing)")
    parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use")
    parser.add_argument("-epochs", type=int, default=100, help="Number of epochs to use")

    parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
    parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
    parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
    parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
    parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")

    parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")

    parser.add_argument("-dropout", type=float, default=0.1, help="Dropout rate")

    return parser.parse_args()

# print_train_args
def print_train_args(args):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Prints training arguments
    ----------
    """

    print(SEPERATOR)
    print("input_dir:", args.input_dir)
    print("output_dir:", args.output_dir)
    print("weight_modulus:", args.weight_modulus)
    print("print_modulus:", args.print_modulus)
    print("")
    print("n_workers:", args.n_workers)
    print("force_cpu:", args.force_cpu)
    print("tensorboard:", not args.no_tensorboard)
    print("")
    print("continue_weights:", args.continue_weights)
    print("continue_epoch:", args.continue_epoch)
    print("")
    print("lr:", args.lr)
    print("ce_smoothing:", args.ce_smoothing)
    print("batch_size:", args.batch_size)
    print("epochs:", args.epochs)
    print("")
    print("rpr:", args.rpr)
    print("max_sequence:", args.max_sequence)
    print("n_layers:", args.n_layers)
    print("num_heads:", args.num_heads)
    print("d_model:", args.d_model)
    print("")
    print("dim_feedforward:", args.dim_feedforward)
    print("dropout:", args.dropout)
    print(SEPERATOR)
    print("")

# parse_eval_args
def parse_eval_args():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Argparse arguments for evaluating a model
    ----------
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("-dataset_dir", type=str, default="./dataset/e_piano", help="Folder of preprocessed and pickled midi files")
    parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()")
    parser.add_argument("-n_workers", type=int, default=1, help="Number of threads for the dataloader")
    parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")

    parser.add_argument("-batch_size", type=int, default=2, help="Batch size to use")

    parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
    parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider in the model")
    parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
    parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
    parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")

    parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")

    return parser.parse_args()

# print_eval_args
def print_eval_args(args):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Prints evaluation arguments
    ----------
    """

    print(SEPERATOR)
    print("dataset_dir:", args.dataset_dir)
    print("model_weights:", args.model_weights)
    print("n_workers:", args.n_workers)
    print("force_cpu:", args.force_cpu)
    print("")
    print("batch_size:", args.batch_size)
    print("")
    print("rpr:", args.rpr)
    print("max_sequence:", args.max_sequence)
    print("n_layers:", args.n_layers)
    print("num_heads:", args.num_heads)
    print("d_model:", args.d_model)
    print("")
    print("dim_feedforward:", args.dim_feedforward)
    print(SEPERATOR)
    print("")

# parse_generate_args
def parse_generate_args():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Argparse arguments for generation
    ----------
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("-midi_root", type=str, default="./dataset/e_piano/", help="Midi file to prime the generator with")
    parser.add_argument("-output_dir", type=str, default="./gen", help="Folder to write generated midi to")
    parser.add_argument("-primer_file", type=str, default=None, help="File path or integer index to the evaluation dataset. Default is to select a random index.")
    parser.add_argument("--force_cpu", action="store_true", help="Forces model to run on a cpu even when gpu is available")

    parser.add_argument("-target_seq_length", type=int, default=1024, help="Target length you'd like the midi to be")
    parser.add_argument("-num_prime", type=int, default=256, help="Amount of messages to prime the generator with")
    parser.add_argument("-model_weights", type=str, default="./saved_models/model.pickle", help="Pickled model weights file saved with torch.save and model.state_dict()")
    parser.add_argument("-beam", type=int, default=0, help="Beam search k. 0 for random probability sample and 1 for greedy")

    parser.add_argument("--rpr", action="store_true", help="Use a modified Transformer for Relative Position Representations")
    parser.add_argument("-max_sequence", type=int, default=2048, help="Maximum midi sequence to consider")
    parser.add_argument("-n_layers", type=int, default=6, help="Number of decoder layers to use")
    parser.add_argument("-num_heads", type=int, default=8, help="Number of heads to use for multi-head attention")
    parser.add_argument("-d_model", type=int, default=512, help="Dimension of the model (output dim of embedding layers, etc.)")

    parser.add_argument("-dim_feedforward", type=int, default=1024, help="Dimension of the feedforward layer")

    return parser.parse_args()

# print_generate_args
def print_generate_args(args):
    """
    ----------
    Author: Damon Gwinn
|
179 |
+
----------
|
180 |
+
Prints generation arguments
|
181 |
+
----------
|
182 |
+
"""
|
183 |
+
|
184 |
+
print(SEPERATOR)
|
185 |
+
print("midi_root:", args.midi_root)
|
186 |
+
print("output_dir:", args.output_dir)
|
187 |
+
print("primer_file:", args.primer_file)
|
188 |
+
print("force_cpu:", args.force_cpu)
|
189 |
+
print("")
|
190 |
+
print("target_seq_length:", args.target_seq_length)
|
191 |
+
print("num_prime:", args.num_prime)
|
192 |
+
print("model_weights:", args.model_weights)
|
193 |
+
print("beam:", args.beam)
|
194 |
+
print("")
|
195 |
+
print("rpr:", args.rpr)
|
196 |
+
print("max_sequence:", args.max_sequence)
|
197 |
+
print("n_layers:", args.n_layers)
|
198 |
+
print("num_heads:", args.num_heads)
|
199 |
+
print("d_model:", args.d_model)
|
200 |
+
print("")
|
201 |
+
print("dim_feedforward:", args.dim_feedforward)
|
202 |
+
print(SEPERATOR)
|
203 |
+
print("")
|
204 |
+
|
205 |
+
# write_model_params
|
206 |
+
def write_model_params(args, output_file):
|
207 |
+
"""
|
208 |
+
----------
|
209 |
+
Author: Damon Gwinn
|
210 |
+
----------
|
211 |
+
Writes given training parameters to text file
|
212 |
+
----------
|
213 |
+
"""
|
214 |
+
|
215 |
+
o_stream = open(output_file, "w")
|
216 |
+
|
217 |
+
o_stream.write("rpr: " + str(args.rpr) + "\n")
|
218 |
+
o_stream.write("lr: " + str(args.lr) + "\n")
|
219 |
+
o_stream.write("ce_smoothing: " + str(args.ce_smoothing) + "\n")
|
220 |
+
o_stream.write("batch_size: " + str(args.batch_size) + "\n")
|
221 |
+
o_stream.write("max_sequence: " + str(args.max_sequence) + "\n")
|
222 |
+
o_stream.write("n_layers: " + str(args.n_layers) + "\n")
|
223 |
+
o_stream.write("num_heads: " + str(args.num_heads) + "\n")
|
224 |
+
o_stream.write("d_model: " + str(args.d_model) + "\n")
|
225 |
+
o_stream.write("dim_feedforward: " + str(args.dim_feedforward) + "\n")
|
226 |
+
o_stream.write("dropout: " + str(args.dropout) + "\n")
|
227 |
+
|
228 |
+
o_stream.close()
|
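A minimal usage sketch, not part of the commit above: how a training entry point could consume parse_train_args, print_train_args, and write_model_params. The script structure and the "model_params.txt" file name are illustrative assumptions.

import os

from utilities.argument_funcs import parse_train_args, print_train_args, write_model_params

def main():
    # Parse and echo the CLI arguments, then record the hyperparameters
    # alongside the training outputs for later reference.
    args = parse_train_args()
    print_train_args(args)

    os.makedirs(args.output_dir, exist_ok=True)
    write_model_params(args, os.path.join(args.output_dir, "model_params.txt"))

if __name__ == "__main__":
    main()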
utilities/constants.py
ADDED
@@ -0,0 +1,28 @@
import torch

from processor import RANGE_NOTE_ON, RANGE_NOTE_OFF, RANGE_VEL, RANGE_TIME_SHIFT

SEPERATOR = "========================="

# Taken from the paper
ADAM_BETA_1 = 0.9
ADAM_BETA_2 = 0.98
ADAM_EPSILON = 10e-9

LR_DEFAULT_START = 1.0
SCHEDULER_WARMUP_STEPS = 4000
# LABEL_SMOOTHING_E = 0.1

# DROPOUT_P = 0.1

TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT
TOKEN_PAD = TOKEN_END + 1

VOCAB_SIZE = TOKEN_PAD + 1

TORCH_FLOAT = torch.float32
TORCH_INT = torch.int32

TORCH_LABEL_TYPE = torch.long

PREPEND_ZEROS_WIDTH = 4
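To make the token layout concrete, a small worked example. The RANGE_* values below are assumptions based on the standard MIDI event encoding (the authoritative values live in processor.py); with them, the constants above resolve as follows:

# Assumed event ranges; check processor.py for the real definitions.
RANGE_NOTE_ON = 128      # note-on events, one per MIDI pitch
RANGE_NOTE_OFF = 128     # note-off events
RANGE_VEL = 32           # quantized velocity bins
RANGE_TIME_SHIFT = 100   # time-shift steps

TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT  # 388
TOKEN_PAD = TOKEN_END + 1                                                  # 389
VOCAB_SIZE = TOKEN_PAD + 1                                                 # 390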
utilities/device.py
ADDED
@@ -0,0 +1,73 @@
# For all things related to devices
#### ONLY USE PROVIDED FUNCTIONS, DO NOT USE GLOBAL CONSTANTS ####

import torch
import os

# change cuda devices to ones that are available after running nvidia-smi.
os.environ["CUDA_VISIBLE_DEVICES"] = '3,4,5'

TORCH_CPU_DEVICE = torch.device("cpu")

# If we have a GPU available, we'll set our device to GPU. We'll use this device variable later in our code.
if(torch.cuda.device_count() > 0):
    TORCH_CUDA_DEVICE = torch.device("cuda")
else:
    print("----- WARNING: CUDA devices not detected. This will cause the model to run very slow! -----")
    print("")
    TORCH_CUDA_DEVICE = None

USE_CUDA = True

# use_cuda
def use_cuda(cuda_bool):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Sets whether to use CUDA (if available), or use the CPU (not recommended)
    ----------
    """

    global USE_CUDA
    USE_CUDA = cuda_bool

# get_device
def get_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Grabs the default device. Default device is CUDA if available and use_cuda is not False, CPU otherwise.
    ----------
    """

    if((not USE_CUDA) or (TORCH_CUDA_DEVICE is None)):
        return TORCH_CPU_DEVICE
    else:
        return TORCH_CUDA_DEVICE

# cuda_device
def cuda_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Grabs the cuda device (may be None if CUDA is not available)
    ----------
    """

    return TORCH_CUDA_DEVICE

# cpu_device
def cpu_device():
    """
    ----------
    Author: Damon Gwinn
    ----------
    Grabs the cpu device
    ----------
    """

    return TORCH_CPU_DEVICE
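A short usage sketch, not part of the commit: forcing CPU execution and resolving the active device. The tensor here is only a stand-in to show where get_device() is used.

import torch

from utilities.device import use_cuda, get_device

use_cuda(False)                            # e.g. when a --force_cpu flag is set
x = torch.zeros(4, device=get_device())    # resolves to the CPU device here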
utilities/lr_scheduling.py
ADDED
@@ -0,0 +1,65 @@
#Library Imports
import math

#Using Adam optimizer with
#Beta_1=0.9, Beta_2=0.98, and Epsilon=10^-9

#Learning rate varies over course of training
#lrate = (1/sqrt(d_model)) * min((1/sqrt(step_num)), step_num * (1/(warmup_steps*sqrt(warmup_steps))))

# LrStepTracker
class LrStepTracker:
    """
    ----------
    Author: Ryan Marshall
    Modified: Damon Gwinn
    ----------
    Class for custom learn rate scheduler (to be used by torch.optim.lr_scheduler.LambdaLR).

    Learn rate for each step (batch) given the warmup steps is:
        lr = [ 1/sqrt(d_model) ] * min[ 1/sqrt(step) , step * (warmup_steps)^-1.5 ]

    This is from Attention is All you Need (https://arxiv.org/abs/1706.03762)
    ----------
    """

    def __init__(self, model_dim=512, warmup_steps=4000, init_steps=0):
        # Store Values
        self.warmup_steps = warmup_steps
        self.model_dim = model_dim
        self.init_steps = init_steps

        # Begin Calculations
        self.invsqrt_dim = (1 / math.sqrt(model_dim))
        self.invsqrt_warmup = (1 / (warmup_steps * math.sqrt(warmup_steps)))

    # step
    def step(self, step):
        """
        ----------
        Author: Ryan Marshall
        Modified: Damon Gwinn
        ----------
        Method to pass to LambdaLR. Increments the step and computes the new learn rate.
        ----------
        """

        step += self.init_steps
        if(step <= self.warmup_steps):
            return self.invsqrt_dim * self.invsqrt_warmup * step
        else:
            invsqrt_step = (1 / math.sqrt(step))
            return self.invsqrt_dim * invsqrt_step

# get_lr
def get_lr(optimizer):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Hack to get the current learn rate of the model
    ----------
    """

    for param_group in optimizer.param_groups:
        return param_group['lr']
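A sketch of how LrStepTracker plugs into torch.optim.lr_scheduler.LambdaLR, using the Adam hyperparameters from utilities/constants.py. The Linear module stands in for the MusicTransformer; it is an assumption made only to keep the snippet self-contained.

import torch

from utilities.constants import (LR_DEFAULT_START, ADAM_BETA_1, ADAM_BETA_2,
                                 ADAM_EPSILON, SCHEDULER_WARMUP_STEPS)
from utilities.lr_scheduling import LrStepTracker

model = torch.nn.Linear(512, 512)  # stand-in for the MusicTransformer
lr_tracker = LrStepTracker(model_dim=512, warmup_steps=SCHEDULER_WARMUP_STEPS)

# LR_DEFAULT_START is 1.0, so LambdaLR's multiplicative factor *is* the learn rate.
opt = torch.optim.Adam(model.parameters(), lr=LR_DEFAULT_START,
                       betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_tracker.step)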
utilities/run_model.py
ADDED
@@ -0,0 +1,95 @@
import torch
import time

from .constants import *
from utilities.device import get_device
from .lr_scheduling import get_lr

from dataset.e_piano import compute_epiano_accuracy


# train_epoch
def train_epoch(cur_epoch, model, dataloader, loss, opt, lr_scheduler=None, print_modulus=1):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Trains a single model epoch
    ----------
    """

    out = -1
    model.train()
    for batch_num, batch in enumerate(dataloader):
        time_before = time.time()

        opt.zero_grad()

        x = batch[0].to(get_device())
        tgt = batch[1].to(get_device())

        y = model(x)

        y = y.reshape(y.shape[0] * y.shape[1], -1)
        tgt = tgt.flatten()

        out = loss.forward(y, tgt)

        out.backward()
        opt.step()

        if(lr_scheduler is not None):
            lr_scheduler.step()

        time_after = time.time()
        time_took = time_after - time_before

        if((batch_num+1) % print_modulus == 0):
            print(SEPERATOR)
            print("Epoch", cur_epoch, " Batch", batch_num+1, "/", len(dataloader))
            print("LR:", get_lr(opt))
            print("Train loss:", float(out))
            print("")
            print("Time (s):", time_took)
            print(SEPERATOR)
            print("")

    return

# eval_model
def eval_model(model, dataloader, loss):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Evaluates the model and returns the average loss and accuracy
    ----------
    """

    model.eval()

    avg_acc = -1
    avg_loss = -1
    with torch.set_grad_enabled(False):
        n_test = len(dataloader)
        sum_loss = 0.0
        sum_acc = 0.0
        for batch in dataloader:
            x = batch[0].to(get_device())
            tgt = batch[1].to(get_device())

            y = model(x)

            sum_acc += float(compute_epiano_accuracy(y, tgt))

            y = y.reshape(y.shape[0] * y.shape[1], -1)
            tgt = tgt.flatten()

            out = loss.forward(y, tgt)

            sum_loss += float(out)

        avg_loss = sum_loss / n_test
        avg_acc = sum_acc / n_test

    return avg_loss, avg_acc
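For orientation, a hypothetical outer loop showing how train_epoch and eval_model are meant to be driven; the model, dataloaders, loss, optimizer, and scheduler are assumed to be constructed elsewhere.

def run_training(model, train_loader, test_loader, loss_func, opt, lr_scheduler, epochs):
    # Train for a fixed number of epochs, evaluating after each one.
    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        train_epoch(epoch, model, train_loader, loss_func, opt, lr_scheduler, print_modulus=50)
        eval_loss, eval_acc = eval_model(model, test_loader, loss_func)
        print("Epoch", epoch, "| eval loss:", eval_loss, "| eval acc:", eval_acc)
        best_acc = max(best_acc, eval_acc)
    return best_acc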