"""Streamlit demo: arrange an uploaded audio file into piano MIDI with Pop2Piano."""
import os

import streamlit as st
from omegaconf import OmegaConf

from transformer_wrapper import TransformerWrapper

# `resolve` (not `raw`) makes the Hugging Face Hub serve the actual file
# contents; `raw` returns the git blob, which for a Git-LFS-tracked checkpoint
# is only the small LFS pointer text.
CHECKPOINT_URL = (
    "https://huggingface.co/sweetcocoa/pop2piano/resolve/main/"
    "model-1999-val_0.67311615.ckpt"
)
# Model identifier passed through to `wrapper.generate(model=...)`.
MODEL_ID = "dpipqxiy"
# Directory where uploaded audio is saved before conversion.
DEST_DIR = "ytsamples"


@st.cache(show_spinner=False)
def get_file_content_as_string(path):
    """Return the full text of the UTF-8 file at *path* (memoized by Streamlit)."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


# allow_output_mutation=True: without it st.cache hashes the returned objects on
# every call to detect mutation, which is slow for a large model (and may not
# work for GPU tensors). The model is treated as read-only here anyway.
@st.cache(show_spinner=True, allow_output_mutation=True)
def model_load():
    """Load the Pop2Piano model, memoized by Streamlit across reruns.

    Returns:
        A tuple ``(wrapper, model_id, config)``:
        the checkpoint-loaded model moved to CUDA and put in eval mode, the
        model identifier passed to ``wrapper.generate``, and the OmegaConf
        configuration read from ``config.yaml``.
    """
    config = OmegaConf.load("config.yaml")
    wrapper = TransformerWrapper(config)
    # NOTE(review): if load_from_checkpoint is a (Lightning) classmethod, the
    # instance built above is discarded and the model is constructed twice;
    # calling it on the class would avoid that — confirm before changing.
    wrapper = wrapper.load_from_checkpoint(
        CHECKPOINT_URL,
        config=config,
    ).cuda()
    wrapper.eval()
    return wrapper, MODEL_ID, config


def main():
    """Render the UI: choose an arranger, upload audio, convert, and show results.

    On "convert", the uploaded file is saved under ``DEST_DIR``, passed to
    ``wrapper.generate``, and the resulting MIDI is offered for download while
    the mixed audio is played back in the page.
    """
    wrapper, model_id, config = model_load()
    composers = list(config.composer_to_feature_token.keys())

    composer = st.selectbox(label="Arranger", options=composers)
    file_up = st.file_uploader("Upload an audio", type=["mp3", "wav"])

    if not st.button("convert"):
        return
    if file_up is None:
        st.warning("Please upload an audio file first.")
        return

    # The upload name comes from the client: keep only its final component so
    # a crafted name (e.g. "../x") cannot write outside DEST_DIR, and make
    # sure the directory exists on a fresh checkout.
    os.makedirs(DEST_DIR, exist_ok=True)
    target_file = os.path.join(DEST_DIR, os.path.basename(file_up.name))
    with open(target_file, "wb") as f:
        f.write(file_up.getvalue())

    with st.spinner("Wait for it..."):
        _midi, _arranger, mix_path, midi_path = wrapper.generate(
            audio_path=target_file,
            composer=composer,
            model=model_id,
            ignore_duplicate=True,
            show_plot=False,
            save_midi=True,
            save_mix=True,
            vqvae=None,
        )

    with open(midi_path, "rb") as midi_f:
        st.download_button(
            "Download midi",
            data=midi_f,
            file_name=os.path.basename(midi_path),
        )
    with open(mix_path, "rb") as audio_f:
        st.audio(audio_f.read(), format="audio/wav")


if __name__ == "__main__":
    main()