from pedalboard import Pedalboard, Reverb, Compressor
from pedalboard.io import AudioFile
from utils import TOKENIZER
import pandas as pd
import subprocess
import pretty_midi
import gradio as gr
import time
import copy
import types
import torch
import random
import spaces
import os

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

in_space = os.getenv("SYSTEM") == "spaces"

# Model hyperparameters; these must match the checkpoint.
n_layer = 12
n_embd = 768
ctx_len = 2048

# These environment variables must be set before model_run is imported.
os.environ['RWKV_FLOAT_MODE'] = 'fp32'
os.environ['RWKV_RUN_DEVICE'] = 'cpu'
model_type = 'RWKV'
MODEL_NAME = 'model'

# Each note event is a 13-character line (12 digits plus a newline), so the
# generation budget is rounded to a whole number of events.
LENGTH_PER_TRIAL = round(2048 / 13) * 13
TEMPERATURE = 1.0

from model_run import RWKV_RNN

model = RWKV_RNN(MODEL_NAME, os.environ['RWKV_RUN_DEVICE'], model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER()

temp_dir = 'temp'
os.makedirs(temp_dir, exist_ok=True)


def clear_midi(dir):
    """Remove leftover .mid files from a previous run."""
    for file in os.listdir(dir):
        if file.endswith('.mid'):
            os.remove(os.path.join(dir, file))


clear_midi(temp_dir)

# Seed context: a single all-zero event line primes the model.
ctx_seed = "000000000000\n"
ctx = tokenizer.encode(ctx_seed)
src_len = len(ctx)
src_ctx = ctx.copy()


def humanize_notes(midi_events):
    """Add subtle timing jitter so notes feel less rigidly quantized.

    For a nonzero integer tick, int(value + 0.20) leaves it unchanged while
    int(value - 0.20) truncates it down by one, so each onset and offset is
    randomly kept in place or pulled one tick earlier.
    """
    def humanize(value):
        if value != 0:
            humanize_offset = random.choice([-0.20, 0.20])
            return max(0, int(value + humanize_offset))
        return value

    midi_events['start'] = midi_events['start'].apply(humanize)
    midi_events['end'] = midi_events['end'].apply(humanize)
    max_tick = 8 * 384  # 8 bars of 4/4 at 96 ticks per quarter note
    midi_events['end'] = midi_events['end'].clip(upper=max_tick)
    return midi_events


@spaces.GPU(duration=60)
def generate_midi(LENGTH_PER_TRIAL, src_ctx, model, src_len, ctx_len,
                  TEMPERATURE, top_k, tokenizer, ctx_seed, bpm):
    midi_seq = []
    for TRIAL in range(1):  # fixed to a single trial
        t_begin = time.time_ns()
        if TRIAL > 0:
            midi_seq.append("\n")
        ctx = src_ctx.copy()
        model.clear()
        midi_tokens = []
        if TRIAL == 0:
            # Feed the seed context once and cache the resulting state.
            init_state = types.SimpleNamespace()
            for i in range(src_len):
                x = ctx[:i + 1]
                if i == src_len - 1:
                    init_state.out = model.run(x)
                else:
                    model.run(x)
            model.save(init_state)
        else:
            model.load(init_state)
        midi_seq.append(ctx_seed)
        for i in range(src_len, src_len + LENGTH_PER_TRIAL):
            x = ctx[:i + 1]
            x = x[-ctx_len:]
            if i == src_len:
                out = copy.deepcopy(init_state.out)
            else:
                out = model.run(x)
            char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, top_k=top_k).item()
            # Keep a sliding window of the last two tokens to detect the
            # stop pattern.
            midi_tokens.append(char)
            if len(midi_tokens) > 2:
                midi_tokens.pop(0)
            if midi_tokens == [11, 10]:  # stop token pattern
                break
            midi_seq.append(tokenizer.decode([int(char)]))
            ctx += [char]
        t_end = time.time_ns()

    # Parse the generated text back into note events. Each line encodes one
    # note as 12 digits: pitch (2), velocity (2), start tick (4), end tick (4).
    trim_seq = "".join(midi_seq)
    events = trim_seq.split("\n")
    midi_events = []
    sequence = []
    rndm_num = 895645
    for event in events:
        if event.strip() == "":
            # A blank line separates sequences: flush the current one and
            # start a new file name.
            midi_events.append(sequence)
            sequence = []
            rndm_num = random.randint(100000, 999999)
            continue
        try:
            pitch = int(event[0:2])
            velocity = int(event[2:4])
            start = int(event[4:8])
            end = int(event[8:12])
        except ValueError:
            pitch = velocity = start = end = 0
        sequence.append({'file_name': f'rwkv_{rndm_num}', 'pitch': pitch,
                         'velocity': velocity, 'start': start, 'end': end})
    if sequence:
        midi_events.append(sequence)

    midi_events = pd.DataFrame([event for sequence in midi_events for event in sequence])
    midi_events = midi_events[['file_name', 'pitch', 'velocity', 'start', 'end']]
    midi_events = humanize_notes(midi_events)
    midi_events = midi_events.sort_values(by=['file_name', 'start']).reset_index(drop=True)
    # Drop notes that fall outside the 8-bar window (3072 ticks).
    midi_events = midi_events[(midi_events['start'] < 3072) & (midi_events['end'] <= 3072)]
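    # Worked example of the tick math, added for clarity: with resolution=96
    # ticks per quarter note, tick_to_time() below converts ticks to seconds
    # at the requested BPM. At 120 BPM a quarter note lasts 0.5 s, so the
    # 3072-tick window above corresponds to 16 seconds of music.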
    # Render each sequence to a single-track MIDI file. With a single trial
    # there is only one group; later groups would overwrite the same path.
    midi_path = os.path.join(temp_dir, 'output.mid')
    for file_name, events in midi_events.groupby('file_name'):
        midi_obj = pretty_midi.PrettyMIDI(initial_tempo=bpm, resolution=96)
        instrument = pretty_midi.Instrument(0)
        midi_obj.instruments.append(instrument)
        for _, event in events.iterrows():
            note = pretty_midi.Note(
                pitch=int(event['pitch']),
                velocity=int(event['velocity']),
                start=midi_obj.tick_to_time(event['start']),
                end=midi_obj.tick_to_time(event['end'])
            )
            instrument.notes.append(note)
        midi_obj.write(midi_path)
    return midi_path


def render_wav(midi_file, uploaded_sf2=None, output_level='2.0'):
    """Render a MIDI file to WAV with FluidSynth and a SoundFont."""
    sf2_dir = 'sf2'
    audio_format = 's16'
    sample_rate = '44100'
    gain = str(output_level)
    if uploaded_sf2:
        sf2_file = uploaded_sf2
    else:
        # Pick a random bundled SoundFont when none is uploaded.
        sf2_files = [f for f in os.listdir(sf2_dir) if f.endswith('.sf2')]
        if not sf2_files:
            raise ValueError("No SoundFont (.sf2) file found in directory.")
        sf2_file = os.path.join(sf2_dir, random.choice(sf2_files))
        # print(f"Using SoundFont: {sf2_file}")
    output_wav = os.path.join(temp_dir, 'output.wav')
    with open(os.devnull, 'w') as devnull:
        command = [
            'fluidsynth', '-ni', sf2_file, midi_file,
            '-F', output_wav,
            '-r', sample_rate,
            '-o', f'audio.file.format={audio_format}',
            '-g', gain
        ]
        subprocess.call(command, stdout=devnull, stderr=devnull)
    return output_wav


def generate_and_return_files(bpm, temperature, top_k, uploaded_sf2=None, output_level='2.0'):
    midi_file = generate_midi(
        LENGTH_PER_TRIAL, src_ctx, model, src_len, ctx_len,
        temperature, top_k, tokenizer, ctx_seed, bpm
    )
    wav_raw = render_wav(midi_file, uploaded_sf2, output_level)
    wav_fx = os.path.join(temp_dir, 'output_fx.wav')

    # Light mastering chain: a touch of reverb followed by compression.
    board = Pedalboard([
        Reverb(room_size=0.50, wet_level=0.30, dry_level=0.75, width=1.0),
        Compressor(threshold_db=-4.0, ratio=4.0, attack_ms=0.0, release_ms=300.0),
    ])
    # Process the rendered audio in one-second chunks, keeping effect state
    # across chunks (reset=False) so reverb tails stay continuous.
    with AudioFile(wav_raw) as f:
        with AudioFile(wav_fx, 'w', f.samplerate, f.num_channels) as o:
            while f.tell() < f.frames:
                chunk = f.read(int(f.samplerate))
                effected = board(chunk, f.samplerate, reset=False)
                o.write(effected)
    return midi_file, wav_fx


custom_css = """
#container {
    max-width: 1200px !important;
    margin: 0 auto !important;
}
#generate-btn {
    font-size: 18px;
    color: white;
    padding: 10px 20px;
    border: none;
    border-radius: 5px;
    cursor: pointer;
    background: linear-gradient(90deg, hsla(268, 90%, 70%, 1) 0%, hsla(260, 72%, 74%, 1) 50%, hsla(247, 73%, 65%, 1) 100%);
    transition: background 1s ease;
}
#generate-btn:hover {
    color: white;
    background: linear-gradient(90deg, hsla(268, 90%, 62%, 1) 0%, hsla(260, 70%, 70%, 1) 50%, hsla(247, 73%, 55%, 1) 100%);
}
#container .prose {
    text-align: center !important;
}
#container h1 {
    font-weight: bold;
    font-size: 40px;
    margin: 0px;
}
#container p {
    font-size: 18px;
    text-align: center;
}
"""

with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Default(
        font=[gr.themes.GoogleFont("Roboto"), "sans-serif"],
        primary_hue="violet",
        secondary_hue="violet"
    )
) as iface:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            "Pop-K is a small RWKV model that generates pop melodies in C major and A minor."
") bpm = gr.Slider(minimum=50, maximum=200, step=1, value=100, label="BPM") temperature = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature") top_k = gr.Slider(minimum=4, maximum=32, step=1, value=20, label="Top-k") output_level = gr.Slider(minimum=0, maximum=3, step=0.10, value=2.0, label="Output Gain") generate_button = gr.Button("Generate", elem_id="generate-btn") midi_file = gr.File(label="MIDI Output") audio_file = gr.Audio(label="Audio Output", type="filepath") soundfont = gr.File(label="Optional: Upload SoundFont (preset=0, bank=0)") generate_button.click( fn=generate_and_return_files, inputs=[bpm, temperature, top_k, soundfont, output_level], outputs=[midi_file, audio_file] ) gr.Markdown("Developed by Patchbanks
") iface.launch(share=True)