File size: 3,017 Bytes
493ccb0
1e02332
 
 
5fca855
1e02332
332cbcf
 
 
 
 
 
 
 
493ccb0
332cbcf
1e02332
 
 
5ebbc0c
1e02332
 
948403f
1e02332
 
 
 
 
 
 
 
3125456
332cbcf
948403f
332cbcf
 
 
1e02332
 
 
 
ec1b64f
332cbcf
1e02332
ec1b64f
1e02332
 
 
 
 
 
 
332cbcf
1e02332
948403f
1e02332
 
 
 
 
 
 
 
 
 
 
 
 
 
5ebbc0c
84248c7
 
 
332cbcf
 
 
 
 
 
 
 
84248c7
 
 
 
332cbcf
948403f
 
332cbcf
 
84248c7
 
 
948403f
84248c7
332cbcf
948403f
332cbcf
84248c7
 
948403f
84248c7
332cbcf
1e02332
282313c
1e02332
948403f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
import torch
import torchaudio
import os
import numpy as np
import base64
from audiocraft.models import MusicGen

# Batch size setting (previously 64, lowered to 32). Nothing in the visible
# part of this file reads it -- presumably kept for memory tuning; TODO confirm
# whether it is still needed.
batch_size = 32

# Ask PyTorch to release cached, unused CUDA memory when the script starts.
torch.cuda.empty_cache()

# Genre choices offered in the "Select Genre" dropdown.
genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", "Lofi", "Chillpop"]

@st.cache_resource()
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    The function is decorated with ``st.cache_resource`` so Streamlit can
    reuse the loaded model instead of loading it again on every call.
    """
    return MusicGen.get_pretrained('facebook/musicgen-small')

def generate_music_tensors(description, duration: int):
    """Generate music from a text prompt with the MusicGen model.

    Args:
        description: The text prompt. A single string is wrapped into a
            one-element list; any other iterable of strings is passed on as a
            batch of prompts.
        duration: Value forwarded as ``duration`` to
            ``model.set_generation_params`` (length of the clip in seconds).

    Returns:
        The value returned by ``model.generate(..., return_tokens=True)``,
        unchanged. Presumably an ``(audio, tokens)`` pair -- confirm against
        the installed audiocraft version.
    """
    model = load_model()

    # MusicGen.generate takes a list of prompts and produces one clip per
    # entry. A bare string would be iterated character by character, so wrap
    # it to get exactly one clip for the whole prompt.
    if isinstance(description, str):
        descriptions = [description]
    else:
        descriptions = list(description)

    model.set_generation_params(
        use_sampling=True,
        top_k=250,
        duration=duration
    )

    with st.spinner("Generating Music..."):
        output = model.generate(
            descriptions=descriptions,
            progress=True,
            return_tokens=True
        )

    st.success("Music Generation Complete!")
    return output


def save_audio(samples: torch.Tensor, sample_rate: int = 32000, save_path: str = "audio_output"):
    """Write audio sample(s) to WAV files and return the first file's path.

    Args:
        samples: A 2-D tensor (one clip) or a 3-D tensor (a batch of clips).
            A 2-D tensor is given a leading batch dimension of size 1.
        sample_rate: Sample rate written to the WAV files. Defaults to
            32000 Hz, the output rate of the MusicGen models (the previous
            hard-coded 30000 made clips play back too slowly).
        save_path: Directory the files are written to; created if missing.

    Returns:
        The path of the file written for the first clip
        (``<save_path>/audio_0.wav``).

    Raises:
        ValueError: If ``samples`` is not 2-D or 3-D, or holds no clips.
    """
    # Validate with an explicit exception: ``assert`` is stripped under -O.
    if samples.dim() not in (2, 3):
        raise ValueError(
            f"samples must be a 2-D or 3-D tensor, got {samples.dim()} dimension(s)"
        )

    samples = samples.detach().cpu()
    if samples.dim() == 2:
        samples = samples[None, ...]

    if samples.shape[0] == 0:
        raise ValueError("samples contains no audio clips to save")

    # torchaudio.save does not create missing directories.
    os.makedirs(save_path, exist_ok=True)

    audio_paths = []
    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
        audio_paths.append(audio_path)

    # Callers expect a single path; return the first clip's file.
    return audio_paths[0]

def get_binary_file_downloader_html(bin_file, file_label='File'):
    """Build an HTML download link that embeds a file as a base64 data URI.

    Args:
        bin_file: Path of the file to embed; it is read in binary mode.
        file_label: Text shown after "Download" in the link. It is inserted
            into the HTML as-is, without escaping.

    Returns:
        An ``<a>`` tag string whose ``href`` holds the base64-encoded file
        content and whose ``download`` attribute is the file's base name.
    """
    with open(bin_file, 'rb') as handle:
        encoded = base64.b64encode(handle.read()).decode()
    download_name = os.path.basename(bin_file)
    return (
        f'<a href="data:application/octet-stream;base64,{encoded}" '
        f'download="{download_name}">Download {file_label}</a>'
    )

# Configure the browser tab. Streamlit emoji shortcodes must be wrapped in
# colons; without them the string is not treated as an emoji icon.
st.set_page_config(
    page_icon=":musical_note:",
    page_title="Music Gen"
)

def main():
    """Render the Streamlit page and generate music when the button is pressed.

    Collects a BPM value, a free-text description, a genre and a duration,
    combines them into one prompt, generates audio, saves it to disk, and
    shows an audio player plus a download link.
    """
    st.title("🎧 AI Composer Small-Model 🎧")

    st.subheader("Craft your perfect melody!")
    bpm = st.number_input("Enter Speed in BPM", min_value=60)

    text_area = st.text_area('Ex : 80s rock song with guitar and drums')
    st.text('')
    # Dropdown for genres
    selected_genre = st.selectbox("Select Genre", genres)

    st.subheader("2. Select time duration (In Seconds)")
    # NOTE(review): the slider allows 0 seconds -- confirm the model handles a
    # zero-length request, or raise the minimum.
    time_slider = st.slider("Select time duration (In Seconds)", 0, 30, 10)

    if st.button('Let\'s Generate 🎶'):
        st.text('\n\n')
        st.subheader("Generated Music")

        # Fold all user inputs into a single text prompt.
        description = f"{text_area} {selected_genre} {bpm} BPM"

        # Clear CUDA memory cache before generating music
        torch.cuda.empty_cache()

        music_tensors = generate_music_tensors(description, time_slider)

        # Only play the full audio for index 0
        idx = 0
        music_tensor = music_tensors[idx]
        audio_filepath = save_audio(music_tensor)

        # Use a context manager so the file handle is closed after reading.
        with open(audio_filepath, 'rb') as audio_file:
            audio_bytes = audio_file.read()

        # Play the full audio
        st.audio(audio_bytes, format='audio/wav')
        st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)


# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()