Spaces:
Sleeping
Sleeping
sadafwalliyani
commited on
Commit
•
948403f
1
Parent(s):
5ebbc0c
Update app.py
Browse files
app.py
CHANGED
@@ -20,10 +20,8 @@ def load_model():
|
|
20 |
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
21 |
return model
|
22 |
|
23 |
-
def generate_music_tensors(
|
24 |
model = load_model()
|
25 |
-
# model = load_model().to('cpu')
|
26 |
-
|
27 |
|
28 |
model.set_generation_params(
|
29 |
use_sampling=True,
|
@@ -33,7 +31,7 @@ def generate_music_tensors(descriptions, duration: int):
|
|
33 |
|
34 |
with st.spinner("Generating Music..."):
|
35 |
output = model.generate(
|
36 |
-
descriptions=
|
37 |
progress=True,
|
38 |
return_tokens=True
|
39 |
)
|
@@ -54,6 +52,7 @@ def save_audio(samples: torch.Tensor):
|
|
54 |
for idx, audio in enumerate(samples):
|
55 |
audio_path = os.path.join(save_path, f"audio_{idx}.wav")
|
56 |
torchaudio.save(audio_path, audio, sample_rate)
|
|
|
57 |
|
58 |
def get_binary_file_downloader_html(bin_file, file_label='File'):
|
59 |
with open(bin_file, 'rb') as f:
|
@@ -80,52 +79,29 @@ def main():
|
|
80 |
|
81 |
st.subheader("2. Select time duration (In Seconds)")
|
82 |
time_slider = st.slider("Select time duration (In Seconds)", 0, 30, 10)
|
83 |
-
# mood = st.selectbox("Select Mood (Optional)", ["Happy", "Sad", "Angry", "Relaxed", "Energetic"], None)
|
84 |
-
# instrument = st.selectbox("Select Instrument (Optional)", ["Piano", "Guitar", "Flute", "Violin", "Drums"], None)
|
85 |
-
# tempo = st.selectbox("Select Tempo (Optional)", ["Slow", "Moderate", "Fast"], None)
|
86 |
-
# melody = st.text_input("Enter Melody or Chord Progression (Optional)", "e.g: C D:min G:7 C, Twinkle Twinkle Little Star")
|
87 |
|
88 |
if st.button('Let\'s Generate 🎶'):
|
89 |
st.text('\n\n')
|
90 |
st.subheader("Generated Music")
|
91 |
-
descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(1)] # Change the batch size to 1
|
92 |
-
|
93 |
-
# Generate audio
|
94 |
-
# description = text_area # Initialize description with text_area
|
95 |
-
# if selected_genre:
|
96 |
-
# description += f" {selected_genre}"
|
97 |
-
# st.empty() # Hide the selected_genre selectbox after selecting one option
|
98 |
-
# if bpm:
|
99 |
-
# description += f" {bpm} BPM"
|
100 |
-
# if mood:
|
101 |
-
# description += f" {mood}"
|
102 |
-
# st.empty() # Hide the mood selectbox after selecting one option
|
103 |
-
# if instrument:
|
104 |
-
# description += f" {instrument}"
|
105 |
-
# st.empty() # Hide the instrument selectbox after selecting one option
|
106 |
-
# if tempo:
|
107 |
-
# description += f" {tempo}"
|
108 |
-
# st.empty() # Hide the tempo selectbox after selecting one option
|
109 |
-
# if melody:
|
110 |
-
# description += f" {melody}"
|
111 |
|
|
|
|
|
112 |
# Clear CUDA memory cache before generating music
|
113 |
torch.cuda.empty_cache()
|
114 |
|
115 |
music_tensors = generate_music_tensors(description, time_slider)
|
116 |
|
117 |
-
|
118 |
idx = 0
|
119 |
music_tensor = music_tensors[idx]
|
120 |
-
|
121 |
-
audio_filepath = f'output/audio_{idx}.wav'
|
122 |
audio_file = open(audio_filepath, 'rb')
|
123 |
audio_bytes = audio_file.read()
|
124 |
|
125 |
-
|
126 |
st.audio(audio_bytes, format='audio/wav')
|
127 |
st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)
|
128 |
|
129 |
|
130 |
if __name__ == "__main__":
|
131 |
-
main()
|
|
|
20 |
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
21 |
return model
|
22 |
|
23 |
+
def generate_music_tensors(description, duration: int):
|
24 |
model = load_model()
|
|
|
|
|
25 |
|
26 |
model.set_generation_params(
|
27 |
use_sampling=True,
|
|
|
31 |
|
32 |
with st.spinner("Generating Music..."):
|
33 |
output = model.generate(
|
34 |
+
descriptions=description,
|
35 |
progress=True,
|
36 |
return_tokens=True
|
37 |
)
|
|
|
52 |
for idx, audio in enumerate(samples):
|
53 |
audio_path = os.path.join(save_path, f"audio_{idx}.wav")
|
54 |
torchaudio.save(audio_path, audio, sample_rate)
|
55 |
+
return audio_path
|
56 |
|
57 |
def get_binary_file_downloader_html(bin_file, file_label='File'):
|
58 |
with open(bin_file, 'rb') as f:
|
|
|
79 |
|
80 |
st.subheader("2. Select time duration (In Seconds)")
|
81 |
time_slider = st.slider("Select time duration (In Seconds)", 0, 30, 10)
|
|
|
|
|
|
|
|
|
82 |
|
83 |
if st.button('Let\'s Generate 🎶'):
|
84 |
st.text('\n\n')
|
85 |
st.subheader("Generated Music")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
+
description = f"{text_area} {selected_genre} {bpm} BPM"
|
88 |
+
|
89 |
# Clear CUDA memory cache before generating music
|
90 |
torch.cuda.empty_cache()
|
91 |
|
92 |
music_tensors = generate_music_tensors(description, time_slider)
|
93 |
|
94 |
+
# Only play the full audio for index 0
|
95 |
idx = 0
|
96 |
music_tensor = music_tensors[idx]
|
97 |
+
audio_filepath = save_audio(music_tensor)
|
|
|
98 |
audio_file = open(audio_filepath, 'rb')
|
99 |
audio_bytes = audio_file.read()
|
100 |
|
101 |
+
# Play the full audio
|
102 |
st.audio(audio_bytes, format='audio/wav')
|
103 |
st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)
|
104 |
|
105 |
|
106 |
if __name__ == "__main__":
|
107 |
+
main()
|