File size: 1,995 Bytes
88490a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import streamlit as st
import os
from transformer_wrapper import TransformerWrapper
from omegaconf import OmegaConf


@st.cache(show_spinner=False)
def get_file_content_as_string(path):
    """Return the entire contents of the UTF-8 text file at *path*.

    The file is opened in a ``with`` block so its handle is closed as soon as
    the read finishes, rather than lingering until garbage collection.
    """
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


@st.cache(show_spinner=True)
def model_load():
    """Build the Pop2Piano transformer from ``config.yaml`` and a pretrained checkpoint.

    Returns:
        tuple: ``(wrapper, model_id, config)`` where ``wrapper`` is the model
        moved to CUDA and switched to eval mode, ``model_id`` is the identifier
        string later passed as ``wrapper.generate(model=...)``, and ``config``
        is the loaded OmegaConf configuration.
    """
    config = OmegaConf.load("config.yaml")
    # Hugging Face serves the real contents of large (Git LFS) files from the
    # ``/resolve/`` endpoint; ``/raw/`` returns only the LFS pointer text,
    # which is not a loadable checkpoint.
    checkpoint_url = (
        "https://huggingface.co/sweetcocoa/pop2piano/resolve/main/"
        "model-1999-val_0.67311615.ckpt"
    )
    # load_from_checkpoint is called on the class: it builds its own instance
    # from ``config``, so constructing TransformerWrapper(config) beforehand
    # only duplicated the model construction. (Assumes TransformerWrapper is a
    # LightningModule, whose load_from_checkpoint is a classmethod — confirm.)
    wrapper = TransformerWrapper.load_from_checkpoint(
        checkpoint_url,
        config=config,
    ).cuda()
    wrapper.eval()
    model_id = "dpipqxiy"
    return wrapper, model_id, config


def _save_upload(file_up, dest_dir):
    """Write the bytes of an uploaded file into *dest_dir* and return the new path."""
    os.makedirs(dest_dir, exist_ok=True)
    # The filename comes from the client; keep only its final component so a
    # name such as "../evil.wav" cannot place the file outside dest_dir.
    safe_name = os.path.basename(file_up.name)
    target_file = os.path.join(dest_dir, safe_name)
    with open(target_file, "wb") as f:
        f.write(file_up.getvalue())
    return target_file


def main():
    """Render the Pop2Piano Streamlit page and run a conversion on demand.

    The user picks an arranger (a key of ``config.composer_to_feature_token``)
    and uploads an mp3/wav file. When "convert" is pressed, the upload is saved
    under ``ytsamples/``, ``wrapper.generate`` writes a MIDI file and a mixed
    audio file, and the page offers the MIDI for download and plays the mix.
    """
    wrapper, model_id, config = model_load()
    composers = list(config.composer_to_feature_token.keys())
    dest_dir = "ytsamples"
    composer = st.selectbox(label="Arranger", options=composers)
    file_up = st.file_uploader("Upload an audio", type=["mp3", "wav"])

    if not st.button("convert"):
        return
    if file_up is None:
        st.warning("Please upload an audio file first.")
        return

    target_file = _save_upload(file_up, dest_dir)

    with st.spinner("Wait for it..."):
        # Only the two output paths are used below.
        _midi, _arranger, mix_path, midi_path = wrapper.generate(
            audio_path=target_file,
            composer=composer,
            model=model_id,
            ignore_duplicate=True,
            show_plot=False,
            save_midi=True,
            save_mix=True,
            vqvae=None,
        )

    with open(midi_path, "rb") as midi_f:
        midi_bytes = midi_f.read()
    st.download_button(
        "Download midi",
        data=midi_bytes,
        file_name=os.path.basename(midi_path),
    )
    with open(mix_path, "rb") as audio_f:
        st.audio(audio_f.read(), format="audio/wav")


if __name__ == "__main__":
    # Entry point: build the page only when this file is run as the main script.
    main()