paralym commited on
Commit
0391870
1 Parent(s): 80091df

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +272 -0
README.md ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ pipeline_tag: text-generation
6
+ tags:
7
+ - music
8
+ - art
9
+ ---
10
+
11
+ <div align="center">
12
+ <img src="Yi_logo.svg" width="150px" style="display: inline-block;">
13
+ <img src="m-a-p.png" width="150px" style="display: inline-block;">
14
+ </div>
15
+
16
+ ## MuPT: Symbolic Music Generative Pre-trained Transformer
17
+
18
+ MuPT is a series of pre-trained models for symbolic music generation. It was trained on a large-scale dataset of symbolic music, including millions of monophonic and polyphonic pieces from different genres and styles. The models are trained with the LLama2 architecture, and can be further used for downstream music generation tasks such as melody generation, accompaniment generation, and multi-track music generation.
19
+
20
+ - 09/01/2024: a series of pre-trained MuPT models are released, with parameters ranging from 110M to 1.3B.
21
+
22
+ ## Model architecture
23
+
24
+ The details of model architecture of MuPT-v1 are listed below:
25
+
26
+ | Name | Parameters | Training Data(Music Pieces) | Seq Length | Hidden Size | Layers | Heads |
27
+ | :--- | :---: | :---: | :---: | :---: | :---: | :---: |
28
+ | MuPT-v1-8192-110M | 110M | 7M x 8 epochs | 8192 | 768 | 12 | 12 |
29
+ | MuPT-v1-8192-345M | 345M | 7M x 6 epochs | 8192 | 1024 | 24 | 16 |
30
+ | MuPT-v1-8192-770M | 770M | 7M x 5 epochs | 8192 | 1280 | 36 | 20 |
31
+ | MuPT-v1-8192-1.3B | 1.3B | 7M x 8 epochs | 8192 | 1536 | 48 | 24 |
32
+
33
+ ## Model Usage
34
+
35
+ #### Huggingface
36
+
37
+ ##### Inference
38
+
39
+ ```python
40
+ from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
41
+
42
+ tokenizer = AutoTokenizer.from_pretrained("m-a-p/MuPT_v1_8192_1.3B",
43
+ trust_remote_code=True,
44
+ use_fast=False)
45
+ model = AutoModelForCausalLM.from_pretrained("m-a-p/MuPT_v1_8192_1.3B").eval().half().cuda()
46
+
47
+ prefix = "X:1<n>L:1/8<n>Q:1/8=200<n>M:4/4<n>K:Gmin<n>|:\"Gm\" BGdB" # replace "\n" with "<n>" for all the MuPT-8192 models, but not for MuPT-4096 models
48
+ inputs = tokenizer(prefix, return_tensors="pt").to(model.device)
49
+
50
+ max_length = 256
51
+ outputs = model.generate(
52
+ inputs.input_ids,
53
+ max_length=max_length
54
+ )
55
+ outputs = tokenizer.decode(outputs[0])
56
+ print(outputs)
57
+ ```
58
+
59
+ ##### Post-processing
60
+
61
+ Since we merged multiple tracks into one track during training, we need to separate the outputs into standard ABC notation sequences. The post-processing code is as follows:
62
+
63
+ ```python
64
+ import re
65
+
66
+ SEPARATORS = ['|', '|]', '||', '[|', '|:', ':|', '::']
67
+ SEP_DICT = {}
68
+ for i, sep in enumerate(SEPARATORS, start=1):
69
+ # E.g. ' | ': ' <1>'
70
+ SEP_DICT[' '+sep+' '] = f' <{i}>'
71
+ NEWSEP = '<|>'
72
+
73
+ def sep2tok(row):
74
+ for sep, tok in SEP_DICT.items():
75
+ row = row.replace(sep, tok+'<=> ')
76
+ return row
77
+
78
+ def tok2sep(bar):
79
+ for sep, tok in SEP_DICT.items():
80
+ bar = bar.replace(tok, sep)
81
+ return bar
82
+
83
+
84
+ def spacing(row):
85
+
86
+ for sep in SEPARATORS:
87
+
88
+ def subfunc(match):
89
+ symbol = [':', '|', ']']
90
+ if match.group(1) is None:
91
+ return f' {sep}'
92
+ elif match.group(1) in symbol:
93
+ return f' {sep}{match.group(1)}'
94
+ else:
95
+ return ' '+sep+' '+match.group(1)
96
+
97
+ pattern = r' ' + re.escape(sep) + r'(.{1})'
98
+ row = re.sub(pattern, subfunc, row)
99
+ row = row.replace('\n'+sep+'"', '\n '+sep+' "') # B \n|"A -> B \n | "A
100
+ row = row.replace(' '+sep+'\n', ' '+sep+' \n') # B |\n -> B | \n
101
+ return row
102
+
103
+ def decode(piece):
104
+ dec_piece = ''
105
+ idx = piece.find(' '+NEWSEP+' ')
106
+ heads = piece[:idx]
107
+ scores = piece[idx:]
108
+ scores_lst = re.split(' <\|>', scores)
109
+
110
+ all_bar_lst = []
111
+ for bar in scores_lst:
112
+ if bar == '':
113
+ continue
114
+ bar = sep2tok(bar)
115
+ bar_lst = re.split('<=>', bar)
116
+ bar_lst = list(map(tok2sep, bar_lst))
117
+ if len(all_bar_lst) == 0:
118
+ all_bar_lst = [[] for _ in range(len(bar_lst))]
119
+ for i in range(len(bar_lst)):
120
+ all_bar_lst[i].append(bar_lst[i])
121
+
122
+ if len(all_bar_lst) > 1:
123
+ # There might be the bar number like %30 at the end
124
+ # which need to be specially handled.
125
+ if len(all_bar_lst[0]) > len(all_bar_lst[1]):
126
+ last_bar_lst = all_bar_lst[0][-1].split()
127
+ all_bar_lst[0].pop()
128
+ for i in range(len(all_bar_lst)):
129
+ all_bar_lst[i].append(last_bar_lst[i])
130
+ # Add the remaining symbols to the last row.
131
+ if i == len(all_bar_lst) - 1:
132
+ for j in range(i+1, len(last_bar_lst)):
133
+ all_bar_lst[i][-1] += ' ' + last_bar_lst[j]
134
+ # Ensure the lengths are consistent.
135
+ length = len(all_bar_lst[0])
136
+ for lst in all_bar_lst[1:]:
137
+ # assert len(lst) == length
138
+ pass
139
+
140
+ dec_piece += heads
141
+ for i in range(len(all_bar_lst)):
142
+ if len(all_bar_lst) > 1:
143
+ dec_piece += f'V:{i+1}\n'
144
+ dec_piece += ''.join(all_bar_lst[i])
145
+ dec_piece += '\n'
146
+ # Remove redundant spaces.
147
+ dec_piece = re.sub(' {2,}', ' ', dec_piece)
148
+
149
+ return dec_piece
150
+ ```
151
+
152
+ Processed Output:
153
+ ```shell
154
+ X:1
155
+ L:1/8
156
+ Q:1/8=200
157
+ M:4/4<n>K:Gmin
158
+ |:\"Gm\" BGdB fdBG |\"F\" AFcF dFcF |\"Gm\" BGdG gFBF |\"F\" AFAG AF F2 |\"Gm\" BGBd fffd |\"F\" cdcB cdeg |
159
+ \"Gm\" fdcB\"Eb\" AFcA |1 BGFG\"F\" AFGc :|2 BGFG\"F\" AF F2 ||
160
+ ```
161
+
162
+ Once you encode the post-processed ABC notation into audio, you will hear the following music.
163
+
164
+ <audio controls src="https://cdn-uploads.huggingface.co/production/uploads/640701cb4dc5f2846c91d4eb/gnBULaFjcUyXYzzIwXLZq.mpga"></audio>
165
+
166
+ #### Megatron-LM
167
+
168
+ We now the provide usage based on [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main).
169
+
170
+ Before starting, make sure you have setup the relevant environment and codebase.
171
+
172
+ ```shell
173
+ # pull Megatron-LM codebase
174
+ mkdir -p /path/to/workspace && cd /path/to/workspace
175
+ git clone https://github.com/NVIDIA/Megatron-LM.git
176
+ # download the pre-trained MuPT models checkpoint and vocab files from Huggingface page
177
+ mkdir -p /models/MuPT_v0_8192_1.3B && cd /models/MuPT_v0_8192_1.3B
178
+ wget -O model_optim_rng.pt https://huggingface.co/m-a-p/MuPT_v0_8192_1.3B/resolve/main/model_optim_rng.pt?download=true
179
+ wget -O newline.vocab https://huggingface.co/m-a-p/MuPT_v0_8192_1.3B/resolve/main/newline.vocab?download=true
180
+ wget -O newline.txt https://huggingface.co/m-a-p/MuPT_v0_8192_1.3B/resolve/main/newline.txt?download=true
181
+ ```
182
+
183
+ We recommend using the latest version of [NGC's PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) for MuPT inference. See more details in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main)
184
+
185
+ ```shell
186
+ # pull the latest NGC's PyTorch container, mount the workspace directory and enter the container
187
+ docker run --gpus all -it --name megatron --shm-size=16g -v $PWD:/workspace -p 5000:5000 nvcr.io/nvidia/pytorch:23.11-py3 /bin/bash
188
+ ```
189
+
190
+ Once you enter the container, you can start a REST server for inference.
191
+
192
+ <details>
193
+ <summary>Click to expand the example script</summary>
194
+
195
+ #!/bin/bash
196
+ # This example will start serving the 1.3B model.
197
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
198
+
199
+ DISTRIBUTED_ARGS="--nproc_per_node 1 \
200
+ --nnodes 1 \
201
+ --node_rank 0 \
202
+ --master_addr localhost \
203
+ --master_port 6000"
204
+
205
+ CHECKPOINT=/path/to/model/checkpoint/folder
206
+ VOCAB_FILE=/path/to/vocab/file
207
+ MERGE_FILE=/path/to/merge/file
208
+
209
+ MODEL_SIZE="1.3B"
210
+ if [[ ${MODEL_SIZE} == "110M" ]]; then HIDDEN_SIZE=768; NUM_HEAD=12; NUM_QUERY_GROUP=12; NUM_LAYERS=12; FFN_HIDDEN_SIZE=3072; NORM_EPS=1e-5;
211
+ elif [[ ${MODEL_SIZE} == "345M" ]]; then HIDDEN_SIZE=1024; NUM_HEAD=16; NUM_QUERY_GROUP=16; NUM_LAYERS=24; FFN_HIDDEN_SIZE=4096; NORM_EPS=1e-5;
212
+ elif [[ ${MODEL_SIZE} == "770M" ]]; then HIDDEN_SIZE=1280; NUM_HEAD=20; NUM_QUERY_GROUP=20; NUM_LAYERS=36; FFN_HIDDEN_SIZE=5120; NORM_EPS=1e-5;
213
+ elif [[ ${MODEL_SIZE} == "1.3B" ]]; then HIDDEN_SIZE=1536; NUM_HEAD=24; NUM_QUERY_GROUP=24; NUM_LAYERS=48; FFN_HIDDEN_SIZE=6144; NORM_EPS=1e-5;
214
+ else echo "invalid MODEL_SIZE: ${MODEL_SIZE}"; exit 1
215
+ fi
216
+ MAX_SEQ_LEN=8192
217
+ MAX_POSITION_EMBEDDINGS=8192
218
+
219
+ pip install flask-restful
220
+
221
+ torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
222
+ --tensor-model-parallel-size 1 \
223
+ --pipeline-model-parallel-size 1 \
224
+ --num-layers ${NUM_LAYERS} \
225
+ --hidden-size ${HIDDEN_SIZE} \
226
+ --ffn-hidden-size ${FFN_HIDDEN_SIZE} \
227
+ --load ${CHECKPOINT} \
228
+ --group-query-attention \
229
+ --num-query-groups ${NUM_QUERY_GROUP} \
230
+ --position-embedding-type rope \
231
+ --num-attention-heads ${NUM_HEAD} \
232
+ --max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \
233
+ --tokenizer-type GPT2BPETokenizer \
234
+ --normalization RMSNorm \
235
+ --norm-epsilon ${NORM_EPS} \
236
+ --make-vocab-size-divisible-by 1 \
237
+ --swiglu \
238
+ --use-flash-attn \
239
+ --bf16 \
240
+ --micro-batch-size 1 \
241
+ --disable-bias-linear \
242
+ --no-bias-gelu-fusion \
243
+ --untie-embeddings-and-output-weights \
244
+ --seq-length ${MAX_SEQ_LEN} \
245
+ --vocab-file $VOCAB_FILE \
246
+ --merge-file $MERGE_FILE \
247
+ --attention-dropout 0.0 \
248
+ --hidden-dropout 0.0 \
249
+ --weight-decay 1e-1 \
250
+ --clip-grad 1.0 \
251
+ --adam-beta1 0.9 \
252
+ --adam-beta2 0.95 \
253
+ --adam-eps 1e-8 \
254
+ --seed 42
255
+
256
+ </details>
257
+
258
+
259
+ Use CURL to query the server directly, note that the newline token `\n` is represented by `<n>` in the vocabulary, so we need to replace the newline token with `<n>` in both the prompt and the generated tokens.
260
+
261
+ ```shell
262
+ curl 'http://localhost:6000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":["X:1<n>L:1/8<n>Q:1/8=200<n>M:4/4<n>K:Gmin<n>|:\"Gm\" BGdB"], "tokens_to_generate":4096}'
263
+ ```
264
+ Processed Output:
265
+ ```shell
266
+ X:1
267
+ L:1/8
268
+ Q:1/8=200
269
+ M:4/4<n>K:Gmin
270
+ |:\"Gm\" BGdB fdBG |\"F\" AFcF dFcF |\"Gm\" BGdG gFBF |\"F\" AFAG AF F2 |\"Gm\" BGBd fffd |\"F\" cdcB cdeg |
271
+ \"Gm\" fdcB\"Eb\" AFcA |1 BGFG\"F\" AFGc :|2 BGFG\"F\" AF F2 ||
272
+ ```