Spaces:
Runtime error
Runtime error
import base64
import html
import io
import wave

import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen
@st.cache_resource
def load_model():
    """Load the pretrained ``facebook/musicgen-small`` MusicGen model.

    ``st.cache_resource`` keeps the loaded model for the lifetime of the
    Streamlit server process, so it is not re-loaded on every generation
    request and every script rerun.

    NOTE(review): the cached instance is shared by all sessions, and callers
    mutate it via ``set_generation_params`` before ``generate``; concurrent
    sessions could interleave those two calls. Fine for a single-user demo —
    revisit if the app is served to many users at once.

    Returns:
        The object returned by ``MusicGen.get_pretrained``.
    """
    return MusicGen.get_pretrained('facebook/musicgen-small')
def generate_music_tensors(description, duration: int):
    """Generate a single audio clip for a text prompt using MusicGen.

    Echoes the prompt and requested length to the page, fetches the model via
    ``load_model()``, enables top-k sampling (k = 250) with the requested
    clip length, and generates for a one-prompt batch.

    Args:
        description: Text prompt describing the desired music.
        duration: Clip length in seconds, forwarded to
            ``set_generation_params``.

    Returns:
        The first (and only) element of ``model.generate``'s output.
    """
    st.write(f"Generating music for: '{description}' (Duration: {duration}s)")
    musicgen = load_model()
    sampling_options = dict(use_sampling=True, top_k=250, duration=duration)
    musicgen.set_generation_params(**sampling_options)
    batch = musicgen.generate(descriptions=[description], progress=True)
    first_clip = batch[0]
    return first_clip
def create_audio_buffer(samples: torch.Tensor, sample_rate: int = 32000) -> io.BytesIO:
    """Encode generated audio as an in-memory 16-bit PCM WAV file.

    The WAV is written with the standard-library ``wave`` module, so encoding
    does not depend on a torchaudio I/O backend being installed.

    Args:
        samples: Audio tensor shaped ``(channels, frames)``, or
            ``(batch, channels, frames)`` (only the first item is used), or
            ``(frames,)`` (treated as mono). Values are assumed to be floats
            in ``[-1, 1]``; anything outside that range is clipped.
        sample_rate: Sampling rate written to the WAV header. Defaults to
            32000 Hz, the rate assumed for musicgen-small output
            (TODO confirm against ``model.sample_rate``).

    Returns:
        A ``BytesIO`` holding the complete WAV file, rewound to position 0.

    Raises:
        ValueError: If ``samples`` is 0-D or has more than 3 dimensions.
    """
    audio = samples.detach().cpu()
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)  # mono: (frames,) -> (1, frames)
    elif audio.dim() == 3:
        audio = audio[0]  # batched output: keep the first item
    elif audio.dim() != 2:
        raise ValueError(
            f"Expected a 1-, 2- or 3-D audio tensor, got shape {tuple(audio.shape)}"
        )

    num_channels = audio.shape[0]
    # Scale to int16 and interleave channels: WAV stores each time step as
    # (ch0, ch1, ...), hence the transpose to (frames, channels).
    pcm = (audio.float().clamp(-1.0, 1.0) * 32767.0).round().to(torch.int16)
    frames = pcm.t().contiguous().numpy().astype("<i2").tobytes()

    buffer = io.BytesIO()
    # wave only closes file objects it opened itself, so buffer stays usable.
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(frames)
    buffer.seek(0)
    return buffer
def generate_download_link(buffer, file_label="Download Music", file_name="generated_music.wav"):
    """Build an HTML anchor that downloads the buffer's bytes as a WAV file.

    The audio is embedded directly in the link as a base64 ``data:`` URI.

    Args:
        buffer: A seekable binary file-like object (e.g. an ``io.BytesIO``).
            It is rewound before reading, so the link always contains the
            whole file even if something already read from the buffer.
        file_label: Visible link text; HTML-escaped before insertion.
        file_name: File name suggested to the browser via the ``download``
            attribute; HTML-escaped before insertion.

    Returns:
        The ``<a>`` element as a string, intended for
        ``st.markdown(..., unsafe_allow_html=True)``.
    """
    # read() starts at the current position; without rewinding, an already
    # consumed buffer would produce an empty download.
    buffer.seek(0)
    encoded = base64.b64encode(buffer.read()).decode("ascii")
    label = html.escape(file_label)
    name = html.escape(file_name)
    return f'<a href="data:audio/wav;base64,{encoded}" download="{name}">{label}</a>'
# Page-wide stylesheet injected at module level, before main() runs:
# `.title` styles the centered heading and `.footer` pins the credit line
# to the bottom of the viewport.
_PAGE_STYLE = """
<style>
.title {
font-size: 3em;
text-align: center;
color: #4A90E2;
margin-top: 0;
}
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: #f1f1f1;
text-align: center;
padding: 10px;
font-size: 0.8em;
color: #555;
}
</style>
"""
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)
def main():
    """Render the music-generator page.

    Shows the title, a prompt text area and a duration slider. Generation runs
    only when the user clicks "Generate" with a non-blank prompt. Ordinary
    reruns, such as moving the slider, therefore no longer trigger the
    expensive model. The result is shown in an audio player plus a base64
    download link. The footer is rendered on every run.
    """
    st.markdown('<h1 class="title">Theaimart: Music Generator 🎵</h1>', unsafe_allow_html=True)
    st.write("Generate music based on your text input using Meta's Audiocraft library!")
    description = st.text_area("Enter a description:")
    # Slider minimum is 1, so duration is always a positive int.
    duration = st.slider("Select duration (seconds)", 1, 20, 10)

    # Streamlit re-executes the whole script on every widget interaction;
    # gate the expensive generation behind an explicit button press.
    if st.button("Generate"):
        prompt = description.strip()
        if not prompt:
            st.warning("Please enter a description first.")
        else:
            with st.spinner("Generating music..."):
                music_tensors = generate_music_tensors(prompt, duration)
            audio_buffer = create_audio_buffer(music_tensors)
            st.audio(audio_buffer, format="audio/wav")
            # Rewind in case st.audio advanced the buffer's read position.
            audio_buffer.seek(0)
            st.markdown(generate_download_link(audio_buffer), unsafe_allow_html=True)

    # Footer (positioned by the .footer CSS rule) is shown on every run.
    st.markdown('<div class="footer">Made with ❤️ by Theaimart</div>', unsafe_allow_html=True)


if __name__ == "__main__":
    main()