File size: 5,077 Bytes
601c87a
 
 
0089eca
 
 
186a22d
9a75c40
 
601c87a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d28b9a
601c87a
9a75c40
601c87a
 
 
 
 
 
 
9a75c40
601c87a
 
 
 
 
 
 
 
 
 
 
9a75c40
 
601c87a
 
 
 
9a75c40
 
 
601c87a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1da781e
 
 
601c87a
9a75c40
 
 
 
601c87a
9a75c40
601c87a
 
 
 
 
 
9a75c40
601c87a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a75c40
1da781e
601c87a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import subprocess
import sys

# --- Environment bootstrap -------------------------------------------------
# TiMidity++ renders MIDI files to audio; the generation pipeline further
# down shells out to it, so it must be installed before the app starts.
subprocess.check_call(["apt-get", "update"])
subprocess.check_call(["apt-get", "install", "timidity", "-y"])

# Python dependencies for the model and the web UI, installed into the
# interpreter that is running this script.
subprocess.check_call([sys.executable, "-m", "pip", "install", 'torch'])
# subprocess.check_call([sys.executable,"-m","pip","install",'causal-conv1d'])
subprocess.check_call([sys.executable, "-m", "pip", "install", 'numpy', 'miditok', 'transformers', 'gradio'])
import os
import subprocess
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from miditok import MIDILike, REMI, TokenizerConfig
from transformers import MambaConfig, MambaForCausalLM

# Run on the first GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 159M-parameter Mamba causal LM trained on MIDI token sequences.
model = MambaForCausalLM.from_pretrained("krystv/MIDI_Mamba-159M")
# .to('cpu') is a no-op, so no explicit CUDA guard is needed here.
model = model.to(device)

# Reference configuration mirroring the checkpoint's dimensions.
# NOTE(review): mc is never consumed later in this file — kept only for
# backward compatibility; confirm whether it can be removed.
mc = MambaConfig()
mc.d_model = 768
mc.n_layer = 42
mc.vocab_size = 1536

# Fetch the BPE tokenizer definition once; skip the download when the file
# is already cached next to the script.
if not os.path.isfile("tokenizer_1536mix_BPE.json"):
    subprocess.check_call(['wget', 'https://huggingface.co/krystv/MIDI_Mamba-159M/resolve/main/tokenizer_1536mix_BPE.json'])
tokenizer = REMI(params='tokenizer_1536mix_BPE.json')

# Social-media follow links rendered as an icon bar below the Generate button.
twitter_follow_link = "https://twitter.com/iamhemantindia"
instagram_follow_link = "https://instagram.com/iamhemantindia"

# Centered icon bar; {0} and {1} are filled with the two links above.
custom_html = """

<div style='text-align: center;'>

    <a href="{0}" target="_blank" style="margin-right: 5px;">

        <img src="https://img.icons8.com/fluent/24/000000/twitter.png" alt="Follow on Twitter"/>

    </a>

    <a href="{1}" target="_blank">

        <img src="https://img.icons8.com/fluent/24/000000/instagram-new.png" alt="Follow on Instagram"/>

    </a>

</div>

""".format(twitter_follow_link, instagram_follow_link)


def generate(number, top_k_selector, top_p_selector, temperature_selector):
    """Sample a token sequence from the model and render it to an MP3.

    Args:
        number: requested maximum sequence length (cast to int).
        top_k_selector: top-k sampling cutoff, forwarded to ``model.generate``.
        top_p_selector: nucleus-sampling threshold, forwarded as ``top_p``.
        temperature_selector: sampling temperature.

    Returns:
        The path of the rendered audio file, always ``"output.mp3"``.
    """
    # Seed generation with a single BOS token (id 1).
    prompt = torch.tensor([[1, ]])
    if torch.cuda.is_available():
        prompt = prompt.to(device)

    sampled = model.generate(
        input_ids=prompt,
        do_sample=True,
        max_length=int(number),
        temperature=temperature_selector,
        top_p=top_p_selector,
        top_k=top_k_selector,
        eos_token_id=2,  # stop when the EOS token (id 2) is produced
    )

    # Detokenize to a MIDI object and write it to disk.
    midi = tokenizer.decode(np.array(sampled[0].cpu()))
    midi.dump_midi('output.mid')

    # Render MIDI -> WAV with TiMidity, then transcode WAV -> MP3 with ffmpeg.
    subprocess.check_call(['timidity', 'output.mid', '-Ow', '-o', 'output.wav'])
    subprocess.check_call(['ffmpeg', '-y', '-f', 'wav', '-i', 'output.wav', 'output.mp3'])

    return "output.mp3"


# text_box = gr.Textbox(label="Enter Text")


def generate_and_save(number, top_k_selector, top_p_selector, temperature_selector, generate_button, custom_html_wid):
    """Gradio click handler: run generation and refresh the output widgets.

    Returns a tuple of (autoplaying Audio widget for the MP3, File widget
    exposing the generated ``output.mid`` for download, and the generate
    button passed through unchanged).
    """
    mp3_path = generate(number, top_k_selector, top_p_selector, temperature_selector)
    audio_widget = gr.Audio(mp3_path, autoplay=True)
    midi_widget = gr.File(label="Download MIDI", value="output.mid")
    return audio_widget, midi_widget, generate_button







# Two-column Gradio UI: sampling controls on the left, audio playback and
# MIDI download on the right. Widget creation order determines the layout,
# so the statement order below is significant.
with gr.Blocks() as b1:
    # Page header and model description.
    gr.Markdown("<h1 style='text-align: center;'>MIDI Mamba-159M <h1/> ")
    gr.Markdown("<h3 style='text-align: center;'>MIDI Mamba is a Mamba based model trained on MIDI data collected from open internet to train music model. Current accuracy is performance is decreased due to model conversion make sure to checkout the <a href='https://colab.research.google.com/github/HemantKArya/MIDI_Mamba/blob/main/MIDI_Mamba.ipynb'>colab notebook<a/> <br> by Hemant Kumar<h3/>")
    with gr.Row():
        with gr.Column():
            # Left column: generation parameters fed into generate_and_save.
            number_selector = gr.Number(label="Select Length of output",value=512)
            top_p_selector = gr.Slider(label="Select Top P", minimum=0, maximum=1.0, step=0.05, value=0.9)
            temperature_selector = gr.Slider(label="Select Temperature", minimum=0, maximum=1.0, step=0.1, value=0.9)
            # Top-k upper bound of 1536 matches the model's vocabulary size.
            top_k_selector = gr.Slider(label="Select Top K", minimum=1, maximum=1536, step=1, value=30)
            generate_button = gr.Button(value="Generate",variant="primary")
            custom_html_wid = gr.HTML(custom_html)
        with gr.Column():
            # Right column: playback widget (autoplays the latest MP3) and a
            # download slot for the generated MIDI file.
            output_box = gr.Audio("output.mp3",autoplay=True,)
            download_midi_button = gr.File(label="Download MIDI")
            # Wire the button: generate_and_save returns (audio, midi file,
            # button) matching the outputs list below.
            generate_button.click(generate_and_save,inputs=[number_selector,top_k_selector,top_p_selector, temperature_selector,generate_button,custom_html_wid],outputs=[output_box,download_midi_button,generate_button])




# share=True asks Gradio to expose a temporary public URL for the app.
b1.launch(share=True)