zhangysk paralym commited on
Commit
0dda770
1 Parent(s): dabc87b

Update README.md (#1)

Browse files

- Update README.md (a547e24017989149bc01a523147a6c10cc67cfe7)


Co-authored-by: paralym <paralym@users.noreply.huggingface.co>

Files changed (1) hide show
  1. README.md +130 -0
README.md CHANGED
@@ -1,3 +1,133 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ language:
4
+ - en
5
+ pipeline_tag: text-generation
6
+ tags:
7
+ - music
8
+ - art
9
  ---
10
+ ## SMuPT: Symbolic Music Generative Pre-trained Transformer
11
+
12
+ SMuPT is a series of pre-trained models for symbolic music generation. It was trained on a large-scale dataset of symbolic music, including millions of monophonic and polyphonic pieces from different genres and styles. The models are trained with the LLama2 architecture, and can be further used for downstream music generation tasks such as melody generation, accompaniment generation, and multi-track music generation.
13
+
14
+ - 09/01/2024: a series of pre-trained SMuPT models are released, with parameters ranging from 110M to 1.3B.
15
+
16
+ ## Model architecture
17
+
18
+ The details of model architecture of SMuPT-v0 are listed below:
19
+
20
+ | Name | Parameters | Training Data(Music Pieces) | Seq Length | Hidden Size | Layers | Heads |
21
+ | :--- | :---: | :---: | :---: | :---: | :---: | :---: |
22
+ | SMuPT-v0-8192-110M | 110M | 7M x 5.8 epochs | 8192 | 768 | 12 | 12 |
23
+ | SMuPT-v0-8192-345M | 345M | 7M x 4 epochs | 8192 | 1024 | 24 | 16 |
24
+ | SMuPT-v0-8192-770M | 770M | 7M x 3 epochs | 8192 | 1280 | 36 | 20 |
25
+ | SMuPT-v0-8192-1.3B | 1.3B | 7M x 2.2 epochs | 8192 | 1536 | 48 | 24 |
26
+
27
+ ## Model Usage
28
+
29
+ There are several ways to use our pre-trained SMuPT models, we now the usage based on [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main). Huggingface format will be supported soon.
30
+
31
+ Before starting, make sure you have setup the relevant environment and codebase.
32
+
33
+ ```shell
34
+ # pull Megatron-LM codebase
35
+ mkdir -p /path/to/workspace && cd /path/to/workspace
36
+ git clone https://github.com/NVIDIA/Megatron-LM.git
37
+
38
+ # download the pre-trained SMuPT models checkpoint and vocab files from Huggingface page
39
+ mkdir -p /models/SMuPT_v0_8192_1.3B && cd /models/SMuPT_v0_8192_1.3B
40
+ wget -O model_optim_rng.pt https://huggingface.co/m-a-p/SMuPT_v0_8192_1.3B/resolve/main/model_optim_rng.pt?download=true
41
+ wget -O newline.vocab https://huggingface.co/m-a-p/SMuPT_v0_8192_1.3B/resolve/main/newline.vocab?download=true
42
+ wget -O newline.txt https://huggingface.co/m-a-p/SMuPT_v0_8192_1.3B/resolve/main/newline.txt?download=true
43
+ ```
44
+
45
+ We recommend using the latest version of [NGC's PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) for SMuPT inference. See more details in [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main)
46
+
47
+ ```shell
48
+ # pull the latest NGC's PyTorch container, mount the workspace directory and enter the container
49
+ docker run --gpus all -it --name megatron --shm-size=16g -v $PWD:/workspace -p 5000:5000 nvcr.io/nvidia/pytorch:23.11-py3 /bin/bash
50
+ ```
51
+
52
+ Once you enter the container, you can start a REST server for inference.
53
+
54
+ <details>
55
+ <summary>Click to expand the example script</summary>
56
+
57
+ #!/bin/bash
58
+ # This example will start serving the 1.3B model.
59
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
60
+
61
+ DISTRIBUTED_ARGS="--nproc_per_node 1 \
62
+ --nnodes 1 \
63
+ --node_rank 0 \
64
+ --master_addr localhost \
65
+ --master_port 6000"
66
+
67
+ CHECKPOINT=/path/to/model/checkpoint/folder
68
+ VOCAB_FILE=/path/to/vocab/file
69
+ MERGE_FILE=/path/to/merge/file
70
+
71
+ MODEL_SIZE="1.3B"
72
+ if [[ ${MODEL_SIZE} == "110M" ]]; then HIDDEN_SIZE=768; NUM_HEAD=12; NUM_QUERY_GROUP=12; NUM_LAYERS=12; FFN_HIDDEN_SIZE=3072; NORM_EPS=1e-5;
73
+ elif [[ ${MODEL_SIZE} == "345M" ]]; then HIDDEN_SIZE=1024; NUM_HEAD=16; NUM_QUERY_GROUP=16; NUM_LAYERS=24; FFN_HIDDEN_SIZE=4096; NORM_EPS=1e-5;
74
+ elif [[ ${MODEL_SIZE} == "770M" ]]; then HIDDEN_SIZE=1280; NUM_HEAD=20; NUM_QUERY_GROUP=20; NUM_LAYERS=36; FFN_HIDDEN_SIZE=5120; NORM_EPS=1e-5;
75
+ elif [[ ${MODEL_SIZE} == "1.3B" ]]; then HIDDEN_SIZE=1536; NUM_HEAD=24; NUM_QUERY_GROUP=24; NUM_LAYERS=48; FFN_HIDDEN_SIZE=6144; NORM_EPS=1e-5;
76
+ else echo "invalid MODEL_SIZE: ${MODEL_SIZE}"; exit 1
77
+ fi
78
+ MAX_SEQ_LEN=8192
79
+ MAX_POSITION_EMBEDDINGS=8192
80
+
81
+ pip install flask-restful
82
+
83
+ torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
84
+ --tensor-model-parallel-size 1 \
85
+ --pipeline-model-parallel-size 1 \
86
+ --num-layers ${NUM_LAYERS} \
87
+ --hidden-size ${HIDDEN_SIZE} \
88
+ --ffn-hidden-size ${FFN_HIDDEN_SIZE} \
89
+ --load ${CHECKPOINT} \
90
+ --group-query-attention \
91
+ --num-query-groups ${NUM_QUERY_GROUP} \
92
+ --position-embedding-type rope \
93
+ --num-attention-heads ${NUM_HEAD} \
94
+ --max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \
95
+ --tokenizer-type GPT2BPETokenizer \
96
+ --normalization RMSNorm \
97
+ --norm-epsilon ${NORM_EPS} \
98
+ --make-vocab-size-divisible-by 1 \
99
+ --swiglu \
100
+ --use-flash-attn \
101
+ --bf16 \
102
+ --micro-batch-size 1 \
103
+ --disable-bias-linear \
104
+ --no-bias-gelu-fusion \
105
+ --untie-embeddings-and-output-weights \
106
+ --seq-length ${MAX_SEQ_LEN} \
107
+ --vocab-file $VOCAB_FILE \
108
+ --merge-file $MERGE_FILE \
109
+ --attention-dropout 0.0 \
110
+ --hidden-dropout 0.0 \
111
+ --weight-decay 1e-1 \
112
+ --clip-grad 1.0 \
113
+ --adam-beta1 0.9 \
114
+ --adam-beta2 0.95 \
115
+ --adam-eps 1e-8 \
116
+ --seed 42
117
+
118
+ </details>
119
+
120
+
121
+ Use CURL to query the server directly, note that the newline token `\n` is represented by `<n>` in the vocabulary, so we need to replace the newline token with `<n>` in both the prompt and the generated tokens.
122
+
123
+ ```shell
124
+ curl 'http://localhost:6000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":["X:1<n>L:1/8<n>Q:1/8=200<n>M:4/4<n>K:Gmin<n>|:\"Gm\" BGdB"], "tokens_to_generate":4096}'
125
+ ```
126
+ Output:
127
+ ```shell
128
+ X:1<n>L:1/8<n>Q:1/8=200<n>M:4/4<n>K:Gmin<n>|:\"Gm\" BGdB fdBG |\"F\" AFcF dFcF |\"Gm\" BGdG gFBF |\"F\" AFAG AF F2 |\"Gm\" BGBd fffd |\"F\" cdcB cdeg | <n>\"Gm\" fdcB\"Eb\" AFcA |1 BGFG\"F\" AFGc :|2 BGFG\"F\" AF F2 ||<eos>
129
+ ```
130
+
131
+ Once you encode the generated tokens into audio, you will hear the following music.
132
+
133
+ <audio controls src="https://cdn-uploads.huggingface.co/production/uploads/640701cb4dc5f2846c91d4eb/gnBULaFjcUyXYzzIwXLZq.mpga"></audio>