pop2piano / app.py
sweetcocoa's picture
initial test
88490a8
raw
history blame
No virus
2 kB
import streamlit as st
import os
from transformer_wrapper import TransformerWrapper
from omegaconf import OmegaConf
@st.cache(show_spinner=False)
def get_file_content_as_string(path):
    """Return the full text of the UTF-8 encoded file at *path*.

    Memoized with ``st.cache`` (spinner hidden), so calls with the same
    *path* are served from Streamlit's cache instead of re-reading the file.

    Args:
        path: Filesystem path of the text file to read.

    Returns:
        The decoded file contents as a single ``str``.
    """
    # A context manager closes the handle deterministically; the previous
    # `open(...).read()` left closing the file to the garbage collector.
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read()
@st.cache(show_spinner=True, allow_output_mutation=True)
def model_load():
    """Load the Pop2Piano checkpoint and return it ready for inference.

    ``allow_output_mutation=True`` tells ``st.cache`` not to hash the
    returned model/config on every rerun to detect mutation. Hashing a large
    model object is slow and may fail. The cached objects are only read by
    the caller.

    Returns:
        tuple: ``(wrapper, model_id, config)``:
            - ``wrapper``: the ``TransformerWrapper`` restored from the
              checkpoint, moved to CUDA and switched to eval mode.
            - ``model_id``: the string passed as ``model=`` to
              ``wrapper.generate``.
            - ``config``: the OmegaConf config loaded from ``config.yaml``.
    """
    config = OmegaConf.load("config.yaml")
    # NOTE(review): assumes TransformerWrapper is a PyTorch Lightning module,
    # whose `load_from_checkpoint` is a classmethod. Calling it on the class
    # avoids building a throwaway model first. Recent Lightning releases also
    # raise if it is called on an instance.
    wrapper = TransformerWrapper.load_from_checkpoint(
        # Hugging Face's `resolve/` endpoint serves the actual file contents.
        # `raw/` returns only the Git LFS pointer for large binaries such as
        # this checkpoint.
        "https://huggingface.co/sweetcocoa/pop2piano/resolve/main/model-1999-val_0.67311615.ckpt",
        config=config,
    ).cuda()  # NOTE(review): hard-coded CUDA; fails on CPU-only hosts.
    model_id = "dpipqxiy"
    wrapper.eval()
    return wrapper, model_id, config
def main():
    """Render the Pop2Piano Streamlit UI and convert an uploaded song on demand.

    The user picks an arranger (one of the keys of
    ``config.composer_to_feature_token``) and uploads an mp3/wav file.
    When "convert" is clicked:

    1. The upload is saved under ``ytsamples/``.
    2. It is passed to ``wrapper.generate``.
    3. The generated MIDI file is offered for download.
    4. The generated mix is played back in the page.
    """
    wrapper, model_id, config = model_load()
    composers = list(config.composer_to_feature_token.keys())
    dest_dir = "ytsamples"

    composer = st.selectbox(label="Arranger", options=composers)
    file_up = st.file_uploader("Upload an audio", type=["mp3", "wav"])

    if not st.button("convert"):
        return
    if file_up is None:
        # Previously a click without an upload silently did nothing.
        st.warning("Please upload an audio file first.")
        return

    # The upload directory is not guaranteed to exist on a fresh deployment.
    os.makedirs(dest_dir, exist_ok=True)
    # `file_up.name` comes from the client; keep only the final path component
    # so a crafted name such as "../x.mp3" cannot write outside `dest_dir`.
    target_file = os.path.join(dest_dir, os.path.basename(file_up.name))
    with open(target_file, "wb") as f:
        f.write(file_up.getvalue())

    with st.spinner("Wait for it..."):
        # Only the saved output paths are used below; the returned MIDI
        # object and arranger value are discarded.
        _midi, _arranger, mix_path, midi_path = wrapper.generate(
            audio_path=target_file,
            composer=composer,
            model=model_id,
            ignore_duplicate=True,
            show_plot=False,
            save_midi=True,
            save_mix=True,
            vqvae=None,
        )

    with open(midi_path, "rb") as midi_f:
        st.download_button(
            "Download midi",
            data=midi_f,
            file_name=os.path.basename(midi_path),
        )
    # NOTE(review): assumes generate() writes the mix as WAV -- confirm.
    with open(mix_path, "rb") as audio_f:
        st.audio(audio_f.read(), format="audio/wav")
# Entry point: build the UI only when this file is executed as the main
# module, not when it is imported from elsewhere.
if __name__ == "__main__":
    main()