# pop_k / app.py
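"""Pop-K: a Gradio app that samples pop melodies from a small RWKV model,
renders them to audio with FluidSynth, and applies Pedalboard effects."""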
import copy
import os
import random
import subprocess
import time
import types

import gradio as gr
import pandas as pd
import pretty_midi
import spaces
import torch
from pedalboard import Pedalboard, Reverb, Compressor
from pedalboard.io import AudioFile

from utils import TOKENIZER
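# Enable cuDNN autotuning and TF32 matmuls for faster GPU inference.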
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
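# True when running inside a Hugging Face Space.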
in_space = os.getenv("SYSTEM") == "spaces"
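# RWKV hyperparameters; these must match the trained checkpoint.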
n_layer = 12
n_embd = 768
ctx_len = 2048
os.environ['RWKV_FLOAT_MODE'] = 'fp32'
os.environ['RWKV_RUN_DEVICE'] = 'cpu'
model_type = 'RWKV'
MODEL_NAME = 'model'
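# Generate in multiples of 13 tokens: one event line is 12 digits plus a newline.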
LENGTH_PER_TRIAL = round((2048) / 13) * 13
TEMPERATURE = 1.0
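# model_run presumably reads the RWKV_* environment variables at import time,
# so it is imported only after they are set above.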
from model_run import RWKV_RNN
model = RWKV_RNN(MODEL_NAME, os.environ['RWKV_RUN_DEVICE'], model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER()
temp_dir = 'temp'
os.makedirs(temp_dir, exist_ok=True)
def clear_midi(directory):
    # Remove leftover MIDI files from previous runs.
    for file in os.listdir(directory):
        if file.endswith('.mid'):
            os.remove(os.path.join(directory, file))
clear_midi(temp_dir)
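# Seed context: a single all-zero event line to prompt generation.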
ctx_seed = "000000000000\n"
ctx = tokenizer.encode(ctx_seed)
src_len = len(ctx)
src_ctx = ctx.copy()
def humanize_notes(midi_events):
    # Add slight random jitter to note timings. Note that int() truncates
    # toward zero, so +0.20 leaves a tick unchanged while -0.20 shifts it
    # down by one; the net effect is a one-tick shift about half the time.
    def humanize(value):
        if value != 0:
            humanize_offset = random.choice([-0.20, 0.20])
            return max(0, int(value + humanize_offset))
        return value
    midi_events['start'] = midi_events['start'].apply(humanize)
    midi_events['end'] = midi_events['end'].apply(humanize)
    # Clip note ends to 8 bars (384 ticks per bar at 96 PPQ).
    max_tick = 8 * 384
    midi_events['end'] = midi_events['end'].clip(upper=max_tick)
    return midi_events
@spaces.GPU(duration=60)
def generate_midi(LENGTH_PER_TRIAL, src_ctx, model, src_len, ctx_len, TEMPERATURE, top_k, tokenizer, ctx_seed, bpm):
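    """Sample MIDI event tokens from the RWKV model and write them to a MIDI file."""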
midi_seq = []
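    # A single trial is generated per call; the TRIAL > 0 branches are
    # retained from a multi-trial variant of this loop.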
for TRIAL in range(1):
if TRIAL > 0:
midi_seq.append("\n")
ctx = src_ctx.copy()
model.clear()
midi_tokens = []
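        # Prime the model on the seed context and cache the state and logits
        # so later trials could resume without re-running the prompt.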
if TRIAL == 0:
init_state = types.SimpleNamespace()
for i in range(src_len):
x = ctx[:i+1]
if i == src_len - 1:
init_state.out = model.run(x)
else:
model.run(x)
model.save(init_state)
else:
model.load(init_state)
midi_seq.append(ctx_seed)
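        # Autoregressive sampling loop; the context window is capped at ctx_len tokens.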
for i in range(src_len, src_len + LENGTH_PER_TRIAL):
x = ctx[:i+1]
x = x[-ctx_len:]
if i == src_len:
out = copy.deepcopy(init_state.out)
else:
out = model.run(x)
char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, top_k=top_k).item()
midi_tokens.append(char)
if len(midi_tokens) > 2:
midi_tokens.pop(0)
if midi_tokens == [11, 10]: # stop token pattern
break
            midi_seq.append(tokenizer.decode([int(char)]))
            # The stop pattern breaks out above, so every sampled token is appended.
            ctx += [char]
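    # Join the decoded tokens into text and split into per-event lines;
    # blank lines mark sequence boundaries.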
trim_seq = "".join(midi_seq)
events = trim_seq.split("\n")
midi_events = []
sequence = []
rndm_num = 895645
    for event in events:
        if event.strip() == "":
            # A blank line ends the current sequence; start a new one
            # under a fresh random file name.
            midi_events.append(sequence)
            sequence = []
            rndm_num = random.randint(100000, 999999)
            continue
        try:
            # Each event line encodes pitch (2 digits), velocity (2 digits),
            # start tick (4 digits) and end tick (4 digits).
            pitch = int(event[0:2])
            velocity = int(event[2:4])
            start = int(event[4:8])
            end = int(event[8:12])
        except ValueError:
            pitch = velocity = start = end = 0
        sequence.append({'file_name': f'rwkv_{rndm_num}', 'pitch': pitch, 'velocity': velocity, 'start': start, 'end': end})
if sequence:
midi_events.append(sequence)
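    # Flatten the parsed events into a DataFrame, jitter timings,
    # and drop anything beyond 8 bars (3072 ticks).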
midi_events = pd.DataFrame([pd.Series(event) for sequence in midi_events for event in sequence])
midi_events = midi_events[['file_name', 'pitch', 'velocity', 'start', 'end']]
midi_events = humanize_notes(midi_events)
midi_events = midi_events.sort_values(by=['file_name', 'start']).reset_index(drop=True)
midi_events = midi_events[(midi_events['start'] < 3072) & (midi_events['end'] <= 3072)]
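    # Write each sequence to a single-track MIDI file at 96 PPQ, converting
    # ticks to seconds at the requested BPM. Each group overwrites output.mid;
    # with a single trial there is only one group.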
for file_name, events in midi_events.groupby('file_name'):
midi_obj = pretty_midi.PrettyMIDI(initial_tempo=bpm, resolution=96)
instrument = pretty_midi.Instrument(0)
midi_obj.instruments.append(instrument)
for _, event in events.iterrows():
note = pretty_midi.Note(
pitch=event['pitch'],
velocity=event['velocity'],
start=midi_obj.tick_to_time(event['start']),
end=midi_obj.tick_to_time(event['end'])
)
instrument.notes.append(note)
midi_path = os.path.join(temp_dir, 'output.mid')
midi_obj.write(midi_path)
return midi_path
def render_wav(midi_file, uploaded_sf2=None, output_level='2.0'):
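    """Render a MIDI file to WAV with FluidSynth, using the uploaded SoundFont
    or a random one from the sf2 directory."""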
sf2_dir = 'sf2'
audio_format = 's16'
sample_rate = '44100'
gain = str(output_level)
if uploaded_sf2:
sf2_file = uploaded_sf2
else:
        sf2_files = [f for f in os.listdir(sf2_dir) if f.endswith('.sf2')]
if not sf2_files:
raise ValueError("No SoundFont (.sf2) file found in directory.")
sf2_file = os.path.join(sf2_dir, random.choice(sf2_files))
output_wav = os.path.join(temp_dir, 'output.wav')
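    # FluidSynth flags: -n (no MIDI input driver), -i (non-interactive),
    # -F (render to file), -r (sample rate), -g (output gain).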
with open(os.devnull, 'w') as devnull:
command = [
'fluidsynth', '-ni', sf2_file, midi_file, '-F', output_wav, '-r', str(sample_rate),
'-o', f'audio.file.format={audio_format}', '-g', str(gain)
]
subprocess.call(command, stdout=devnull, stderr=devnull)
return output_wav
def generate_and_return_files(bpm, temperature, top_k, uploaded_sf2=None, output_level='2.0'):
    midi_file = generate_midi(
        LENGTH_PER_TRIAL, src_ctx, model, src_len, ctx_len, temperature, top_k,
        tokenizer, ctx_seed, bpm
    )
wav_raw = render_wav(midi_file, uploaded_sf2, output_level)
wav_fx = os.path.join(temp_dir, 'output_fx.wav')
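    # Post-processing chain: light reverb followed by compression.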
sfx_settings = [
{
'board': Pedalboard([
Reverb(room_size=0.50, wet_level=0.30, dry_level=0.75, width=1.0),
Compressor(threshold_db=-4.0, ratio=4.0, attack_ms=0.0, release_ms=300.0),
])
}
]
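    # Stream the rendered WAV through the effects board in one-second chunks,
    # keeping effect state across chunks (reset=False).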
for setting in sfx_settings:
board = setting['board']
with AudioFile(wav_raw) as f:
with AudioFile(wav_fx, 'w', f.samplerate, f.num_channels) as o:
while f.tell() < f.frames:
chunk = f.read(int(f.samplerate))
effected = board(chunk, f.samplerate, reset=False)
o.write(effected)
return midi_file, wav_fx
custom_css = """
#container {
max-width: 1200px !important;
margin: 0 auto !important;
}
#generate-btn {
font-size: 18px;
color: white;
padding: 10px 20px;
border: none;
border-radius: 5px;
cursor: pointer;
background: linear-gradient(90deg, hsla(268, 90%, 70%, 1) 0%, hsla(260, 72%, 74%, 1) 50%, hsla(247, 73%, 65%, 1) 100%);
transition: background 1s ease;
}
#generate-btn:hover {
color: white;
background: linear-gradient(90deg, hsla(268, 90%, 62%, 1) 0%, hsla(260, 70%, 70%, 1) 50%, hsla(247, 73%, 55%, 1) 100%);
}
#container .prose {
text-align: center !important;
}
#container h1 {
font-weight: bold;
font-size: 40px;
margin: 0px;
}
#container p {
font-size: 18px;
text-align: center;
}
"""
with gr.Blocks(
css=custom_css,
theme=gr.themes.Default(
font=[gr.themes.GoogleFont("Roboto"), "sans-serif"],
primary_hue="violet",
secondary_hue="violet"
)
) as iface:
with gr.Column(elem_id="container"):
gr.Markdown("<h1>Pop-K</h1>")
gr.Markdown("<p>Pop-K is a small RWKV model that generates pop melodies in C major and A minor.</p>")
bpm = gr.Slider(minimum=50, maximum=200, step=1, value=100, label="BPM")
temperature = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature")
top_k = gr.Slider(minimum=4, maximum=32, step=1, value=20, label="Top-k")
output_level = gr.Slider(minimum=0, maximum=3, step=0.10, value=2.0, label="Output Gain")
generate_button = gr.Button("Generate", elem_id="generate-btn")
midi_file = gr.File(label="MIDI Output")
audio_file = gr.Audio(label="Audio Output", type="filepath")
soundfont = gr.File(label="Optional: Upload SoundFont (preset=0, bank=0)")
generate_button.click(
fn=generate_and_return_files,
inputs=[bpm, temperature, top_k, soundfont, output_level],
outputs=[midi_file, audio_file]
)
gr.Markdown("<p style='font-size: 16px;'>Developed by <a href='https://www.patchbanks.com/' target='_blank'><strong>Patchbanks</strong></a></p>")
iface.launch(share=True)