import os
import pickle
import torch
import random
import subprocess
import re
import pretty_midi
import gradio as gr
from contextlib import nullcontext
from model import GPTConfig, GPT
from pedalboard import Pedalboard, Reverb, Compressor, Gain, Limiter
from pedalboard.io import AudioFile
import spaces

in_space = os.getenv("SYSTEM") == "spaces"

temp_dir = 'temp'
os.makedirs(temp_dir, exist_ok=True)

init_from = 'resume'
out_dir = 'checkpoints'
ckpt_load = 'model.pt'
start = "000000000000\n"
num_samples = 1
max_new_tokens = 768
seed = random.randint(1, 100000)
torch.manual_seed(seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
compile = False
exec(open('configurator.py').read())  # allow command-line / config-file overrides

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

if init_from == 'resume':
    ckpt_path = os.path.join(out_dir, ckpt_load)
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))

model.eval()
model.to(device)
if compile:
    model = torch.compile(model)

# A token is either the bar separator, a two-digit value, or a newline.
tokenizer = re.compile(r'000000000000|\d{2}|\n')

meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
stoi = meta.get('stoi', None)
itos = meta.get('itos', None)

def encode(text):
    matches = tokenizer.findall(text)
    return [stoi[c] for c in matches]

def decode(encoded):
    return ''.join([itos[i] for i in encoded])

def clear_midi(directory):
    for file in os.listdir(directory):
        if file.endswith('.mid'):
            os.remove(os.path.join(directory, file))

clear_midi(temp_dir)

@spaces.GPU(duration=10)
def generate_midi(temperature, top_k):
    start_ids = encode(start)
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
    midi_events = []
    seq_count = 0
    with torch.no_grad(), ctx:
        for _ in range(num_samples):
            sequence = []
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            tkn_seq = decode(y[0].tolist())
            lines = tkn_seq.splitlines()
            for event in lines:
                if event.startswith(start.strip()):
                    if sequence:
                        midi_events.append(sequence)
                        sequence = []
                    seq_count += 1
                elif event.strip() == "":
                    continue
                else:
                    try:
                        p = int(event[0:2])   # pitch
                        v = int(event[2:4])   # velocity
                        s = int(event[4:8])   # start tick
                        e = int(event[8:12])  # end tick
                    except ValueError:
                        p, v, s, e = 0, 0, 0, 0
                    sequence.append({'file_name': f'nanompc_{seq_count:02d}',
                                     'pitch': p, 'velocity': v, 'start': s, 'end': e})
            if sequence:
                midi_events.append(sequence)

    # Keep only events that fall inside four 4/4 bars at 96 PPQ (1536 ticks).
    round_bars = []
    for sequence in midi_events:
        filtered_sequence = []
        for event in sequence:
            if event['start'] < 1536 and event['end'] <= 1536:
                filtered_sequence.append(event)
        if filtered_sequence:
            round_bars.append(filtered_sequence)
    midi_events = round_bars

    # Drop near-duplicate notes: same pitch starting within 12 ticks of an earlier note.
    for track in midi_events:
        track.sort(key=lambda x: x['start'])
        unique_notes = []
        for note in track:
            if not any(abs(note['start'] - n['start']) < 12 and note['pitch'] == n['pitch']
                       for n in unique_notes):
                unique_notes.append(note)
        track[:] = unique_notes

    return midi_events
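
# Event encoding, inferred from the parser above: each generated line is a
# 12-digit string "PPVVSSSSEEEE" (2-digit pitch, 2-digit velocity, 4-digit
# start tick, 4-digit end tick), and the all-zero line marks a new sequence.
# For example (hypothetical token):
#
#   token = "380912001248"
#   pitch, velocity = int(token[0:2]), int(token[2:4])  # 38, 9
#   start, end = int(token[4:8]), int(token[8:12])      # 1200, 1248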
def write_single_midi(midi_events, bpm):
    midi_data = pretty_midi.PrettyMIDI(initial_tempo=bpm, resolution=96)
    midi_data.time_signature_changes.append(pretty_midi.containers.TimeSignature(4, 4, 0))
    instrument = pretty_midi.Instrument(0)
    midi_data.instruments.append(instrument)
    for event in midi_events[0]:
        pitch = event['pitch']
        velocity = event['velocity']
        start = midi_data.tick_to_time(event['start'])
        end = midi_data.tick_to_time(event['end'])
        note = pretty_midi.Note(pitch=pitch, velocity=velocity, start=start, end=end)
        instrument.notes.append(note)
    midi_path = os.path.join(temp_dir, 'output.mid')
    midi_data.write(midi_path)
    print(f"Generated: {midi_path}")

def render_wav(midi_file, uploaded_sf2=None, output_level='2.0'):
    sf2_dir = 'sf2_kits'
    audio_format = 's16'
    sample_rate = '44100'
    gain = str(output_level)
    if uploaded_sf2:
        sf2_file = uploaded_sf2
    else:
        sf2_files = [f for f in os.listdir(sf2_dir) if f.endswith('.sf2')]
        if not sf2_files:
            raise ValueError("No SoundFont (.sf2) file found in directory.")
        sf2_file = os.path.join(sf2_dir, random.choice(sf2_files))
    output_wav = os.path.join(temp_dir, 'output.wav')
    command = [
        'fluidsynth', '-ni', sf2_file, midi_file,
        '-F', output_wav,
        '-r', sample_rate,
        '-o', f'audio.file.format={audio_format}',
        '-g', gain
    ]
    subprocess.call(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return output_wav

def generate_and_return_files(bpm, temperature, top_k, uploaded_sf2=None, output_level='2.0'):
    midi_events = generate_midi(temperature, top_k)
    if not midi_events:
        return None, None
    write_single_midi(midi_events, bpm)
    midi_file = os.path.join(temp_dir, 'output.mid')
    wav_raw = render_wav(midi_file, uploaded_sf2, output_level)
    wav_fx = os.path.join(temp_dir, 'output_fx.wav')
    sfx_settings = [
        {
            'board': Pedalboard([
                Reverb(room_size=0.01, wet_level=random.uniform(0.005, 0.01), dry_level=0.75, width=1.0),
                Compressor(threshold_db=-3.0, ratio=8.0, attack_ms=0.0, release_ms=300.0),
            ])
        }
    ]
    for setting in sfx_settings:
        board = setting['board']
        with AudioFile(wav_raw) as f:
            with AudioFile(wav_fx, 'w', f.samplerate, f.num_channels) as o:
                # Stream the rendered WAV through the effect chain in one-second chunks.
                while f.tell() < f.frames:
                    chunk = f.read(int(f.samplerate))
                    effected = board(chunk, f.samplerate, reset=False)
                    o.write(effected)
    return midi_file, wav_fx
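
# Minimal headless sketch (assumes checkpoints/model.pt and at least one .sf2
# in sf2_kits/, as configured above):
#
#   midi_path, wav_path = generate_and_return_files(bpm=100, temperature=1.0, top_k=8)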
custom_css = """
#container {
    max-width: 1200px !important;
    margin: 0 auto !important;
}
#generate-btn {
    font-size: 18px;
    color: white;
    padding: 10px 20px;
    border: none;
    border-radius: 5px;
    cursor: pointer;
    background: linear-gradient(90deg, hsla(268, 90%, 70%, 1) 0%, hsla(260, 72%, 74%, 1) 50%, hsla(247, 73%, 65%, 1) 100%);
    transition: background 1s ease;
}
#generate-btn:hover {
    color: white;
    background: linear-gradient(90deg, hsla(268, 90%, 62%, 1) 0%, hsla(260, 70%, 70%, 1) 50%, hsla(247, 73%, 55%, 1) 100%);
}
#container .prose {
    text-align: center !important;
}
#container h1 {
    font-weight: bold;
    font-size: 40px;
    margin: 0px;
}
#container p {
    font-size: 18px;
    text-align: center;
}
"""

with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Default(
        font=[gr.themes.GoogleFont("Roboto"), "sans-serif"],
        primary_hue="violet",
        secondary_hue="violet"
    )
) as iface:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            "# Neural Breaks\n"
            "Neural Breaks is a generative MIDI model trained on dynamic transcriptions of funk and soul drum breaks."
        )
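        # Temperature rescales the logits before sampling; Top-k restricts sampling
        # to the k most likely tokens (both are passed through to model.generate).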
") bpm = gr.Slider(minimum=50, maximum=200, step=1, value=100, label="BPM") temperature = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature") top_k = gr.Slider(minimum=4, maximum=16, step=1, value=8, label="Top-k") output_level = gr.Slider(minimum=0, maximum=3, step=0.10, value=2.0, label="Output Gain") generate_button = gr.Button("Generate", elem_id="generate-btn") midi_file = gr.File(label="MIDI Output") audio_file = gr.Audio(label="Audio Output", type="filepath") soundfont = gr.File(label="Optional: Upload SoundFont (preset=0, bank=0)") generate_button.click( fn=generate_and_return_files, inputs=[bpm, temperature, top_k, soundfont, output_level], outputs=[midi_file, audio_file] ) gr.Markdown("Developed by Patchbanks
") iface.launch(share=True)