krystv commited on
Commit
9a75c40
1 Parent(s): 186a22d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -36
app.py CHANGED
@@ -2,8 +2,8 @@ import subprocess
2
  import sys
3
 
4
  subprocess.check_call([sys.executable,"-m","pip","install", 'torch'])
5
- subprocess.check_call([sys.executable,"-m","pip","install",'causal-conv1d'])
6
- subprocess.check_call([sys.executable, "-m", "pip", "install", 'numpy', 'miditok','mamba-ssm','gradio'])
7
  subprocess.check_call(["apt-get", "install", "timidity", "-y"])
8
 
9
  # !pip install pretty_midi midi2audio
@@ -25,31 +25,15 @@ subprocess.check_call(["apt-get", "install", "timidity", "-y"])
25
  # subprocess.check_call(['export', 'LIBRARY_PATH="/usr/local/cuda/lib64/stubs"'])
26
  import os
27
 
28
- os.environ['LC_ALL'] = "en_US.UTF-8"
29
- os.environ['LD_LIBRARY_PATH'] = "/usr/lib64-nvidia"
30
- os.environ['LIBRARY_PATH'] = "/usr/local/cuda/lib64/stubs"
31
-
32
 
33
 
34
  import gradio as gr
35
  import torch
36
- from mamba_ssm import Mamba
37
- from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
38
- from mamba_ssm.models.config_mamba import MambaConfig
39
  import numpy as np
40
 
41
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
42
- if torch.cuda.is_available():
43
- subprocess.check_call(['ldconfig', '/usr/lib64-nvidia'])
44
- # !ldconfig /usr/lib64-nvidia
45
-
46
- # !wget "https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/MIDI_Mamba-159M_1536VS.pt"
47
- # !wget "https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/tokenizer_1536mix_BPE.json"
48
- if os.path.isfile("MIDI_Mamba-159M_1536VS.pt") == False:
49
- subprocess.check_call(['wget', 'https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/MIDI_Mamba-159M_1536VS.pt'])
50
-
51
- if os.path.isfile("tokenizer_1536mix_BPE.json") == False:
52
- subprocess.check_call(['wget', 'https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/tokenizer_1536mix_BPE.json'])
53
 
54
 
55
 
@@ -61,13 +45,15 @@ mc.vocab_size = 1536
61
  from miditok import MIDILike,REMI,TokenizerConfig
62
  from pathlib import Path
63
  import torch
64
-
 
65
  tokenizer = REMI(params='tokenizer_1536mix_BPE.json')
66
 
67
 
68
 
69
- mf = MambaLMHeadModel(config=mc,device=device)
70
- mf.load_state_dict(torch.load("/content/MIDI_Mamba-159M_1536VS.pt",map_location=device))
 
71
 
72
 
73
 
@@ -86,19 +72,20 @@ custom_html = f"""
86
  """
87
 
88
 
89
- @spaces.GPU(duration=120)
90
  def generate(number,top_k_selector,top_p_selector, temperature_selector):
91
- input_ids = torch.tensor([[1,]]).to(device)
92
- out = mf.generate(
 
 
93
  input_ids=input_ids,
 
94
  max_length=int(number),
95
  temperature=temperature_selector,
96
  top_p=top_p_selector,
97
  top_k=top_k_selector,
98
 
99
  eos_token_id=2,)
100
- m = tokenizer.decode(np.array(out[0].to('cpu')))
101
- np.array(out.to('cpu')).shape
102
  m.dump_midi('output.mid')
103
  # !timidity output.mid -Ow -o - | ffmpeg -y -f wav -i - output.mp3
104
  timidity_cmd = ['timidity', 'output.mid', '-Ow', '-o', 'output.wav']
@@ -123,17 +110,10 @@ def generate_and_save(number,top_k_selector,top_p_selector, temperature_selector
123
 
124
 
125
 
126
- # iface = gr.Interface(fn=generate_and_save,
127
- # inputs=[number_selector,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid],
128
- # outputs=[output_box,download_midi_button],
129
- # title="MIDI Mamba-159M",submit_btn=False,
130
- # clear_btn=False,
131
- # description="MIDI Mamba is a Mamba based model trained on MIDI data collected from open internet to train music model.",
132
- # allow_flagging=False,)
133
 
134
  with gr.Blocks() as b1:
135
  gr.Markdown("<h1 style='text-align: center;'>MIDI Mamba-159M <h1/> ")
136
- gr.Markdown("<h3 style='text-align: center;'>MIDI Mamba is a Mamba based model trained on MIDI data collected from open internet to train music model. <br> by Hemant Kumar<h3/>")
137
  with gr.Row():
138
  with gr.Column():
139
  number_selector = gr.Number(label="Select Length of output",value=512)
 
2
  import sys
3
 
4
  subprocess.check_call([sys.executable,"-m","pip","install", 'torch'])
5
+ # subprocess.check_call([sys.executable,"-m","pip","install",'causal-conv1d'])
6
+ subprocess.check_call([sys.executable, "-m", "pip", "install", 'numpy', 'miditok','transformers','gradio'])
7
  subprocess.check_call(["apt-get", "install", "timidity", "-y"])
8
 
9
  # !pip install pretty_midi midi2audio
 
25
  # subprocess.check_call(['export', 'LIBRARY_PATH="/usr/local/cuda/lib64/stubs"'])
26
  import os
27
 
28
+ from transformers import MambaConfig, MambaForCausalLM
 
 
 
29
 
30
 
31
  import gradio as gr
32
  import torch
 
 
 
33
  import numpy as np
34
 
35
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
36
+ model = MambaForCausalLM.from_pretrained("krystv/MIDI_Mamba-159M")
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
 
 
45
  from miditok import MIDILike,REMI,TokenizerConfig
46
  from pathlib import Path
47
  import torch
48
+ if os.path.isfile("tokenizer_1536mix_BPE.json") == False:
49
+ subprocess.check_call(['wget', 'https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/tokenizer_1536mix_BPE.json'])
50
  tokenizer = REMI(params='tokenizer_1536mix_BPE.json')
51
 
52
 
53
 
54
+ if torch.cuda.is_available():
55
+ model= model.to(device)
56
+
57
 
58
 
59
 
 
72
  """
73
 
74
 
 
75
  def generate(number,top_k_selector,top_p_selector, temperature_selector):
76
+ input_ids = torch.tensor([[1,]])
77
+ if torch.cuda.is_available():
78
+ input_ids = input_ids.to(device)
79
+ out = model.generate(
80
  input_ids=input_ids,
81
+ do_sample=True,
82
  max_length=int(number),
83
  temperature=temperature_selector,
84
  top_p=top_p_selector,
85
  top_k=top_k_selector,
86
 
87
  eos_token_id=2,)
88
+ m = tokenizer.decode(np.array(out[0].cpu()))
 
89
  m.dump_midi('output.mid')
90
  # !timidity output.mid -Ow -o - | ffmpeg -y -f wav -i - output.mp3
91
  timidity_cmd = ['timidity', 'output.mid', '-Ow', '-o', 'output.wav']
 
110
 
111
 
112
 
 
 
 
 
 
 
 
113
 
114
  with gr.Blocks() as b1:
115
  gr.Markdown("<h1 style='text-align: center;'>MIDI Mamba-159M <h1/> ")
116
+ gr.Markdown("<h3 style='text-align: center;'>MIDI Mamba is a Mamba based model trained on MIDI data collected from open internet to train music model. Current accuracy is performance is decreased due to model conversion make sure to checkout the <a href='https://colab.research.google.com/github/HemantKArya/MIDI_Mamba/blob/main/MIDI_Mamba.ipynb'>colab notebook<a/> <br> by Hemant Kumar<h3/>")
117
  with gr.Row():
118
  with gr.Column():
119
  number_selector = gr.Number(label="Select Length of output",value=512)