Spaces:
Sleeping
Sleeping
File size: 3,017 Bytes
493ccb0 1e02332 5fca855 1e02332 332cbcf 493ccb0 332cbcf 1e02332 5ebbc0c 1e02332 948403f 1e02332 3125456 332cbcf 948403f 332cbcf 1e02332 ec1b64f 332cbcf 1e02332 ec1b64f 1e02332 332cbcf 1e02332 948403f 1e02332 5ebbc0c 84248c7 332cbcf 84248c7 332cbcf 948403f 332cbcf 84248c7 948403f 84248c7 332cbcf 948403f 332cbcf 84248c7 948403f 84248c7 332cbcf 1e02332 282313c 1e02332 948403f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import streamlit as st
import torch
import torchaudio
import os
import numpy as np
import base64
from audiocraft.models import MusicGen
# Batch size setting. It was previously 64 and was lowered to 32 (presumably
# to reduce GPU memory pressure). NOTE(review): nothing in this file reads it;
# confirm no other module does before removing it.
batch_size = 32

# Release any cached-but-unused CUDA memory held by PyTorch at startup.
torch.cuda.empty_cache()

# Genre options shown in the UI dropdown; the selected one is appended to the
# text prompt sent to the model.
genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", "Lofi", "Chillpop"]
@st.cache_resource()
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    The ``st.cache_resource`` decorator lets Streamlit reuse one loaded
    instance across script reruns instead of fetching the weights each time.
    """
    return MusicGen.get_pretrained('facebook/musicgen-small')
def generate_music_tensors(description, duration: int):
    """Generate audio for a text prompt using the cached MusicGen model.

    A spinner is shown while generating and a success message afterwards.

    Args:
        description: Text prompt. A single ``str`` is wrapped in a
            one-element list; a list of prompts is passed through unchanged.
        duration: Length of the generated clip(s), in seconds.

    Returns:
        The value of ``model.generate(..., return_tokens=True)``. In
        audiocraft this is presumably an ``(audio, tokens)`` tuple; confirm
        against the installed audiocraft version.
    """
    # MusicGen.generate builds one conditioning entry per element of
    # ``descriptions``. A bare str would be iterated character by character,
    # producing one sample per character (a huge, unintended batch).
    descriptions = [description] if isinstance(description, str) else description

    model = load_model()
    model.set_generation_params(
        use_sampling=True,
        top_k=250,
        duration=duration,
    )
    with st.spinner("Generating Music..."):
        output = model.generate(
            descriptions=descriptions,
            progress=True,
            return_tokens=True,
        )
    st.success("Music Generation Complete!")
    return output
def save_audio(samples: torch.Tensor, sample_rate: int = 32000, save_path: str = "audio_output"):
    """Write generated audio to WAV files and return the path of the last one.

    Args:
        samples: Audio tensor, either 2-D (a single clip) or 3-D (a batch of
            clips). Presumably ``(channels, time)`` / ``(batch, channels,
            time)``; confirm against the model output. The tensor is
            detached and moved to CPU before saving.
        sample_rate: Sample rate written to the WAV files. Defaults to 32000,
            the output rate of MusicGen. The previous hard-coded 30000 made
            playback slower and lower-pitched than generated.
        save_path: Directory that receives ``audio_<idx>.wav`` files; it is
            created if it does not exist.

    Returns:
        Path of the last file written (``audio_0.wav`` for a single clip).

    Raises:
        ValueError: If ``samples`` is not 2-D or 3-D, or holds no clips.
    """
    if samples.dim() not in (2, 3):
        raise ValueError(f"expected a 2-D or 3-D audio tensor, got {samples.dim()}-D")
    samples = samples.detach().cpu()
    if samples.dim() == 2:
        # Promote a single clip to a batch of one so the loop below is uniform.
        samples = samples[None, ...]
    if samples.shape[0] == 0:
        # Otherwise the loop never runs and there is no path to return.
        raise ValueError("no audio clips to save")

    # torchaudio.save does not create missing parent directories.
    os.makedirs(save_path, exist_ok=True)
    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
    return audio_path
def get_binary_file_downloader_html(bin_file, file_label='File'):
    """Return an HTML ``<a>`` tag that downloads *bin_file* from a data URI.

    The whole file is read into memory and base64-encoded inline, so the link
    works without the file being served separately.

    Args:
        bin_file: Path of the file to embed.
        file_label: Text shown after "Download " in the link.

    Returns:
        The anchor tag as a string; its ``download`` attribute is the file's
        base name.
    """
    with open(bin_file, 'rb') as handle:
        encoded = base64.b64encode(handle.read()).decode()
    file_name = os.path.basename(bin_file)
    return (
        f'<a href="data:application/octet-stream;base64,{encoded}" '
        f'download="{file_name}">Download {file_label}</a>'
    )
# Browser tab title and favicon. The original value "musical_note" is neither
# an emoji nor a ":shortcode:", so the intended note icon did not show; pass
# the emoji itself.
st.set_page_config(
    page_icon="🎵",
    page_title="Music Gen",
)
def main():
    """Render the Streamlit UI and, when requested, generate and play a clip.

    The prompt sent to the model joins the free-text description, the
    selected genre and the BPM value. The generated audio is saved by
    ``save_audio``, played back with ``st.audio``, and offered as a
    download link.
    """
    st.title("🎧 AI Composer Small-Model 🎧")
    st.subheader("Craft your perfect melody!")
    bpm = st.number_input("Enter Speed in BPM", min_value=60)
    text_area = st.text_area('Ex : 80s rock song with guitar and drums')
    st.text('')
    # Dropdown for genres
    selected_genre = st.selectbox("Select Genre", genres)
    st.subheader("2. Select time duration (In Seconds)")
    # Minimum of 1 second: a 0-second request leaves the model nothing to
    # generate.
    time_slider = st.slider("Select time duration (In Seconds)", 1, 30, 10)
    if st.button('Let\'s Generate 🎶'):
        st.text('\n\n')
        st.subheader("Generated Music")
        description = f"{text_area} {selected_genre} {bpm} BPM"
        # Clear CUDA memory cache before generating music
        torch.cuda.empty_cache()
        music_tensors = generate_music_tensors(description, time_slider)
        # generate_music_tensors requests return_tokens=True, so element 0 is
        # presumably the audio of an (audio, tokens) pair; confirm against
        # the installed audiocraft version.
        idx = 0
        music_tensor = music_tensors[idx]
        audio_filepath = save_audio(music_tensor)
        # Read the saved WAV back with a context manager so the handle is
        # closed (the original left it open).
        with open(audio_filepath, 'rb') as audio_file:
            audio_bytes = audio_file.read()
        # Play the full audio
        st.audio(audio_bytes, format='audio/wav')
        st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)
# Entry point: build the app UI when this file is run as the main script.
if __name__ == "__main__":
    main()
|