krystv commited on
Commit
1da781e
1 Parent(s): 8448f93

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+
4
+ subprocess.check_call([sys.executable,"-m","pip","install",'causal-conv1d'])
5
+ subprocess.check_call([sys.executable, "-m", "pip", "install", 'miditok','mamba-ssm','gradio'])
6
+ subprocess.check_call(["apt-get", "install", "timidity", "-y"])
7
+
8
+ # !pip install pretty_midi midi2audio
9
+ # !pip install miditok
10
+ # !apt-get install fluidsynth
11
+ # !apt install timidity -y
12
+ # !pip install causal-conv1d>=1.1.0
13
+ # !pip install mamba-ssm
14
+ # !pip install gradio
15
+
16
+
17
+
18
+ # !export LC_ALL="en_US.UTF-8"
19
+ # !export LD_LIBRARY_PATH="/usr/lib64-nvidia"
20
+ # !export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
21
+
22
+ # subprocess.check_call(['export', 'LC_ALL="en_US.UTF-8"'])
23
+ # subprocess.check_call(['export', 'LD_LIBRARY_PATH="/usr/lib64-nvidia"'])
24
+ # subprocess.check_call(['export', 'LIBRARY_PATH="/usr/local/cuda/lib64/stubs"'])
25
+ import os
26
+
27
+ os.environ['LC_ALL'] = "en_US.UTF-8"
28
+ os.environ['LD_LIBRARY_PATH'] = "/usr/lib64-nvidia"
29
+ os.environ['LIBRARY_PATH'] = "/usr/local/cuda/lib64/stubs"
30
+
31
+
32
+
33
+ import gradio as gr
34
+ import torch
35
+ from mamba_ssm import Mamba
36
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
37
+ from mamba_ssm.models.config_mamba import MambaConfig
38
+ import numpy as np
39
+
40
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
41
+ if torch.cuda.is_available():
42
+ subprocess.check_call(['ldconfig', '/usr/lib64-nvidia'])
43
+ # !ldconfig /usr/lib64-nvidia
44
+
45
+ # !wget "https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/MIDI_Mamba-159M_1536VS.pt"
46
+ # !wget "https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/tokenizer_1536mix_BPE.json"
47
+ if os.path.isfile("MIDI_Mamba-159M_1536VS.pt") == False:
48
+ subprocess.check_call(['wget', 'https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/MIDI_Mamba-159M_1536VS.pt'])
49
+
50
+ if os.path.isfile("tokenizer_1536mix_BPE.json") == False:
51
+ subprocess.check_call(['wget', 'https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/tokenizer_1536mix_BPE.json'])
52
+
53
+
54
+
55
+ mc = MambaConfig()
56
+ mc.d_model = 768
57
+ mc.n_layer = 42
58
+ mc.vocab_size = 1536
59
+
60
+ from miditok import MIDILike,REMI,TokenizerConfig
61
+ from pathlib import Path
62
+ import torch
63
+
64
+ tokenizer = REMI(params='tokenizer_1536mix_BPE.json')
65
+
66
+
67
+
68
+ mf = MambaLMHeadModel(config=mc,device=device)
69
+ mf.load_state_dict(torch.load("/content/MIDI_Mamba-159M_1536VS.pt",map_location=device))
70
+
71
+
72
+
73
+ twitter_follow_link = "https://twitter.com/iamhemantindia"
74
+ instagram_follow_link = "https://instagram.com/iamhemantindia"
75
+
76
+ custom_html = f"""
77
+ <div style='text-align: center;'>
78
+ <a href="{twitter_follow_link}" target="_blank" style="margin-right: 5px;">
79
+ <img src="https://img.icons8.com/fluent/24/000000/twitter.png" alt="Follow on Twitter"/>
80
+ </a>
81
+ <a href="{instagram_follow_link}" target="_blank">
82
+ <img src="https://img.icons8.com/fluent/24/000000/instagram-new.png" alt="Follow on Instagram"/>
83
+ </a>
84
+ </div>
85
+ """
86
+
87
+
88
+ @spaces.GPU(duration=120)
89
+ def generate(number,top_k_selector,top_p_selector, temperature_selector):
90
+ input_ids = torch.tensor([[1,]]).to(device)
91
+ out = mf.generate(
92
+ input_ids=input_ids,
93
+ max_length=int(number),
94
+ temperature=temperature_selector,
95
+ top_p=top_p_selector,
96
+ top_k=top_k_selector,
97
+
98
+ eos_token_id=2,)
99
+ m = tokenizer.decode(np.array(out[0].to('cpu')))
100
+ np.array(out.to('cpu')).shape
101
+ m.dump_midi('output.mid')
102
+ # !timidity output.mid -Ow -o - | ffmpeg -y -f wav -i - output.mp3
103
+ timidity_cmd = ['timidity', 'output.mid', '-Ow', '-o', 'output.wav']
104
+ subprocess.check_call(timidity_cmd)
105
+
106
+ # Then convert the WAV to MP3 using ffmpeg
107
+ ffmpeg_cmd = ['ffmpeg', '-y', '-f', 'wav', '-i', 'output.wav', 'output.mp3']
108
+ subprocess.check_call(ffmpeg_cmd)
109
+
110
+ return "output.mp3"
111
+
112
+
113
+ # text_box = gr.Textbox(label="Enter Text")
114
+
115
+
116
+ def generate_and_save(number,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid):
117
+ output_audio = generate(number,top_k_selector,top_p_selector, temperature_selector)
118
+ return gr.Audio(output_audio,autoplay=True),gr.File(label="Download MIDI",value="output.mid"),generate_button
119
+
120
+
121
+
122
+
123
+
124
+
125
+ # iface = gr.Interface(fn=generate_and_save,
126
+ # inputs=[number_selector,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid],
127
+ # outputs=[output_box,download_midi_button],
128
+ # title="MIDI Mamba-159M",submit_btn=False,
129
+ # clear_btn=False,
130
+ # description="MIDI Mamba is a Mamba based model trained on MIDI data collected from open internet to train music model.",
131
+ # allow_flagging=False,)
132
+
133
+ with gr.Blocks() as b1:
134
+ gr.Markdown("<h1 style='text-align: center;'>MIDI Mamba-159M <h1/> ")
135
+ gr.Markdown("<h3 style='text-align: center;'>MIDI Mamba is a Mamba based model trained on MIDI data collected from open internet to train music model. <br> by Hemant Kumar<h3/>")
136
+ with gr.Row():
137
+ with gr.Column():
138
+ number_selector = gr.Number(label="Select Length of output",value=512)
139
+ top_p_selector = gr.Slider(label="Select Top P", minimum=0, maximum=1.0, step=0.05, value=0.9)
140
+ temperature_selector = gr.Slider(label="Select Temperature", minimum=0, maximum=1.0, step=0.1, value=0.9)
141
+ top_k_selector = gr.Slider(label="Select Top K", minimum=1, maximum=1536, step=1, value=30)
142
+ generate_button = gr.Button(value="Generate",variant="primary")
143
+ custom_html_wid = gr.HTML(custom_html)
144
+ with gr.Column():
145
+ output_box = gr.Audio("output.mp3",autoplay=True,)
146
+ download_midi_button = gr.File(label="Download MIDI")
147
+ generate_button.click(generate_and_save,inputs=[number_selector,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid],outputs=[output_box,download_midi_button,generate_button])
148
+
149
+
150
+
151
+
152
+ b1.launch(share=True)