patchbanks committed on
Commit
508b4d3
·
verified ·
1 Parent(s): 2728cc4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -47
app.py CHANGED
@@ -10,6 +10,7 @@ from contextlib import nullcontext
10
  from model import GPTConfig, GPT
11
  from pedalboard import Pedalboard, Reverb, Compressor, Gain, Limiter
12
  from pedalboard.io import AudioFile
 
13
 
14
  in_space = os.getenv("SYSTEM") == "spaces"
15
 
@@ -22,7 +23,7 @@ ckpt_load = 'model.pt'
22
 
23
  start = "000000000000\n"
24
  num_samples = 1
25
- max_new_tokens = 564
26
 
27
  seed = random.randint(1, 100000)
28
  torch.manual_seed(seed)
@@ -58,9 +59,9 @@ model.to(device)
58
  if compile:
59
  model = torch.compile(model)
60
 
61
- tokenizer = re.compile(r'000000000000|\d{1}|\n')
62
 
63
- meta_path = os.path.join('', 'meta.pkl')
64
  with open(meta_path, 'rb') as f:
65
  meta = pickle.load(f)
66
  stoi = meta.get('stoi', None)
@@ -131,7 +132,6 @@ def generate_midi(temperature, top_k):
131
  return midi_events
132
 
133
 
134
-
135
  def write_midi(midi_events, bpm):
136
  midi_data = pretty_midi.PrettyMIDI(initial_tempo=bpm, resolution=96)
137
  midi_data.time_signature_changes.append(pretty_midi.containers.TimeSignature(4, 4, 0))
@@ -152,19 +152,21 @@ def write_midi(midi_events, bpm):
152
  print(f"Generated: {midi_path}")
153
 
154
 
155
- def render_wav(midi_file):
156
-
157
  sf2_dir = 'sf2_kits'
158
  audio_format = 's16'
159
  sample_rate = '44100'
160
  gain = '2.0'
161
 
162
- sf2_files = [f for f in os.listdir(sf2_dir) if f.endswith('.sf2')]
163
- if not sf2_files:
164
- raise ValueError("No SoundFont (.sf2) file found in directory.")
 
 
 
 
165
 
166
- sf2_file = os.path.join(sf2_dir, random.choice(sf2_files))
167
- print(sf2_file)
168
  output_wav = os.path.join(temp_dir, 'output.wav')
169
 
170
  with open(os.devnull, 'w') as devnull:
@@ -177,23 +179,7 @@ def render_wav(midi_file):
177
  return output_wav
178
 
179
 
180
- def render_sfx(wav_raw, settings):
181
- wav_fx = wav_raw
182
-
183
- for setting in settings:
184
- board = setting['board']
185
-
186
- with AudioFile(wav_raw) as f:
187
- with AudioFile(wav_fx, 'w', f.samplerate, f.num_channels) as o:
188
- while f.tell() < f.frames:
189
- chunk = f.read(int(f.samplerate))
190
- effected = board(chunk, f.samplerate, reset=False)
191
- o.write(effected)
192
-
193
- return wav_fx
194
-
195
-
196
- def generate_and_return_files(bpm, temperature, top_k):
197
  midi_events = generate_midi(temperature, top_k)
198
  if not midi_events:
199
  return "Error generating MIDI.", None, None
@@ -201,7 +187,7 @@ def generate_and_return_files(bpm, temperature, top_k):
201
  write_midi(midi_events, bpm)
202
 
203
  midi_file = os.path.join(temp_dir, 'output.mid')
204
- wav_raw = render_wav(midi_file)
205
  wav_fx = os.path.join(temp_dir, 'output_fx.wav')
206
 
207
  sfx_settings = [
@@ -226,22 +212,45 @@ def generate_and_return_files(bpm, temperature, top_k):
226
  return midi_file, wav_fx
227
 
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
- iface = gr.Interface(
231
- fn=generate_and_return_files,
232
- inputs=[
233
- gr.Slider(minimum=50, maximum=200, step=1, value=87, label="bpm"),
234
- gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="temperature"),
235
- gr.Slider(minimum=4, maximum=128, step=1, value=16, label="top_k")
236
- ],
237
- outputs=[
238
- gr.File(label="MIDI File"),
239
- gr.Audio(label="Generated Audio", type="filepath")
240
- ],
241
- title="<h1 style='font-weight: bold; text-align: center;'>nanoMPC - AI Midi Drum Sequencer</h1>",
242
- description="<p style='text-align:center;'>nanoMPC is a tiny transformer model that generates MIDI drum beats inspired by Lo-Fi, Boom Bap and other styles of Hip Hop.</p>",
243
- theme="soft",
244
- allow_flagging="never",
245
- )
246
-
247
- iface.launch()
 
10
  from model import GPTConfig, GPT
11
  from pedalboard import Pedalboard, Reverb, Compressor, Gain, Limiter
12
  from pedalboard.io import AudioFile
13
+ import gradio as gr
14
 
15
  in_space = os.getenv("SYSTEM") == "spaces"
16
 
 
23
 
24
  start = "000000000000\n"
25
  num_samples = 1
26
+ max_new_tokens = 384
27
 
28
  seed = random.randint(1, 100000)
29
  torch.manual_seed(seed)
 
59
  if compile:
60
  model = torch.compile(model)
61
 
62
+ tokenizer = re.compile(r'000000000000|\d{2}|\n')
63
 
64
+ meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
65
  with open(meta_path, 'rb') as f:
66
  meta = pickle.load(f)
67
  stoi = meta.get('stoi', None)
 
132
  return midi_events
133
 
134
 
 
135
  def write_midi(midi_events, bpm):
136
  midi_data = pretty_midi.PrettyMIDI(initial_tempo=bpm, resolution=96)
137
  midi_data.time_signature_changes.append(pretty_midi.containers.TimeSignature(4, 4, 0))
 
152
  print(f"Generated: {midi_path}")
153
 
154
 
155
+ def render_wav(midi_file, uploaded_sf2=None):
 
156
  sf2_dir = 'sf2_kits'
157
  audio_format = 's16'
158
  sample_rate = '44100'
159
  gain = '2.0'
160
 
161
+ if uploaded_sf2:
162
+ sf2_file = uploaded_sf2
163
+ else:
164
+ sf2_files = [f for f in os.listdir(sf2_dir) if f.endswith('.sf2')]
165
+ if not sf2_files:
166
+ raise ValueError("No SoundFont (.sf2) file found in directory.")
167
+ sf2_file = os.path.join(sf2_dir, random.choice(sf2_files))
168
 
169
+ print(f"Using SoundFont: {sf2_file}")
 
170
  output_wav = os.path.join(temp_dir, 'output.wav')
171
 
172
  with open(os.devnull, 'w') as devnull:
 
179
  return output_wav
180
 
181
 
182
+ def generate_and_return_files(bpm, temperature, top_k, uploaded_sf2=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  midi_events = generate_midi(temperature, top_k)
184
  if not midi_events:
185
  return "Error generating MIDI.", None, None
 
187
  write_midi(midi_events, bpm)
188
 
189
  midi_file = os.path.join(temp_dir, 'output.mid')
190
+ wav_raw = render_wav(midi_file, uploaded_sf2)
191
  wav_fx = os.path.join(temp_dir, 'output_fx.wav')
192
 
193
  sfx_settings = [
 
212
  return midi_file, wav_fx
213
 
214
 
215
+ custom_css = """
216
+ #generate-btn {
217
+ background-color: #6366f1 !important;
218
+ color: white !important;
219
+ border: none !important;
220
+ font-size: 16px;
221
+ padding: 10px 20px;
222
+ border-radius: 5px;
223
+ cursor: pointer;
224
+ }
225
+ #generate-btn:hover {
226
+ background-color: #4f51c5 !important;
227
+ }
228
+ """
229
+
230
+ with gr.Blocks(css=custom_css, theme="soft") as iface:
231
+ gr.Markdown("<h1 style='font-weight: bold; text-align: center;'>nanoMPC - AI Midi Drum Sequencer</h1>")
232
+ gr.Markdown("<p style='text-align:center;'>nanoMPC is a tiny transformer model that generates MIDI drum beats inspired by Lo-Fi, Boom Bap and other styles of Hip Hop.</p>")
233
+
234
+ with gr.Row():
235
+ with gr.Column(scale=1):
236
+ bpm = gr.Slider(minimum=50, maximum=200, step=1, value=90, label="BPM")
237
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature")
238
+ top_k = gr.Slider(minimum=4, maximum=256, step=1, value=128, label="Top-k")
239
+ soundfont = gr.File(label="Optional: Upload SoundFont (preset=0, bank=0)")
240
+
241
+ with gr.Column(scale=1):
242
+ midi_file = gr.File(label="MIDI File Output")
243
+ audio_file = gr.Audio(label="Generated Audio Output", type="filepath")
244
+ generate_button = gr.Button("Generate", elem_id="generate-btn")
245
+
246
+ generate_button.click(
247
+ fn=generate_and_return_files,
248
+ inputs=[bpm, temperature, top_k, soundfont],
249
+ outputs=[midi_file, audio_file]
250
+ )
251
+
252
+ iface.launch(share=True)
253
+
254
+
255
+
256