# app.py by THEAIMART
# Commit 7dbe588 (verified): "Create app.py" (2.57 kB)
import base64
import html
import io

import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen
@st.cache_resource
def load_model(model_name: str = "facebook/musicgen-small"):
    """Load a pretrained MusicGen checkpoint, cached across Streamlit reruns.

    ``st.cache_resource`` keeps one model object per distinct ``model_name``
    until the cache is cleared, so the weights are not reloaded each time
    Streamlit re-executes the script.

    Args:
        model_name: Checkpoint identifier passed to
            ``MusicGen.get_pretrained``. Defaults to the small checkpoint
            this app was written for.

    Returns:
        The loaded MusicGen model. The cached object is shared between
        callers, so generation parameters set on it (e.g. via
        ``set_generation_params``) persist for later callers.
    """
    return MusicGen.get_pretrained(model_name)
def generate_music_tensors(description, duration: int):
    """Generate one audio clip for ``description`` using the cached MusicGen model.

    Writes a status line to the page, configures the model for top-k
    sampling (k=250) over ``duration`` seconds, and returns the first item
    of the generated batch.
    """
    st.write(f"Generating music for: '{description}' (Duration: {duration}s)")
    musicgen = load_model()
    sampling_config = dict(use_sampling=True, top_k=250, duration=duration)
    musicgen.set_generation_params(**sampling_config)
    batch = musicgen.generate(descriptions=[description], progress=True)
    # A single prompt was submitted, so only the first batch item is needed.
    clip = batch[0]
    return clip
def create_audio_buffer(samples: torch.Tensor, sample_rate: int = 32000) -> io.BytesIO:
    """Encode ``samples`` as a WAV file held in an in-memory buffer.

    Args:
        samples: Audio tensor.
            - 1-D: treated as a single (mono) channel.
            - 2-D: written as-is. Assumed ``(channels, time)``, the layout
              ``torchaudio.save`` expects with its default
              ``channels_first=True``.
            - 3-D: only the first item along dim 0 is written.
        sample_rate: Sample rate recorded in the WAV header. The default of
            32000 presumably matches the MusicGen model's output rate;
            confirm against ``model.sample_rate`` when using other
            checkpoints.

    Returns:
        A ``BytesIO`` rewound to offset 0 that contains the WAV bytes.

    Raises:
        ValueError: If ``samples`` is not 1-, 2- or 3-dimensional.
    """
    audio = samples.detach().cpu()
    if audio.dim() == 1:
        # torchaudio.save needs 2-D input; promote a bare waveform to mono.
        audio = audio.unsqueeze(0)
    elif audio.dim() == 3:
        # Batched tensor: keep only the first clip.
        audio = audio[0]
    elif audio.dim() != 2:
        raise ValueError(
            f"expected a 1-D, 2-D or 3-D audio tensor, got {audio.dim()}-D"
        )
    buffer = io.BytesIO()
    torchaudio.save(buffer, audio, sample_rate, format="wav")
    # Rewind so consumers read from the start of the encoded file.
    buffer.seek(0)
    return buffer
def generate_download_link(buffer, file_label="Download Music", file_name="generated_music.wav"):
    """Build an HTML anchor that downloads the buffer's bytes as a WAV file.

    The audio is inlined as a base64 ``data:`` URI.

    Args:
        buffer: Binary file-like object holding WAV data. An ``io.BytesIO``
            is read in full via ``getvalue()`` regardless of its current
            position. Any other object is read from its current position.
        file_label: Visible link text. HTML-escaped before insertion.
        file_name: Suggested download file name. HTML-escaped (including
            quotes) before insertion into the ``download`` attribute.

    Returns:
        The ``<a>`` element as a string, suitable for
        ``st.markdown(..., unsafe_allow_html=True)``.
    """
    if isinstance(buffer, io.BytesIO):
        # getvalue() ignores the read position, so a buffer that another
        # consumer has already read to EOF still yields the full payload.
        data = buffer.getvalue()
    else:
        data = buffer.read()
    b64 = base64.b64encode(data).decode("ascii")
    # The result is rendered as raw HTML, so escape caller-supplied text.
    label = html.escape(file_label)
    name = html.escape(file_name, quote=True)
    return f'<a href="data:audio/wav;base64,{b64}" download="{name}">{label}</a>'
# Page-wide styles: a centred blue title and a footer pinned to the bottom
# of the viewport. Injected as raw HTML once at module load.
_APP_CSS = """
<style>
.title {
font-size: 3em;
text-align: center;
color: #4A90E2;
margin-top: 0;
}
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: #f1f1f1;
text-align: center;
padding: 10px;
font-size: 0.8em;
color: #555;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
def main():
    """Render the music-generator page.

    Streamlit re-executes the whole script on every widget interaction.
    Generation is therefore gated behind an explicit "Generate" button
    rather than running on every rerun while the inputs are non-empty.
    The last result is kept in ``st.session_state`` so it survives later
    reruns.
    """
    st.markdown('<h1 class="title">Theaimart: Music Generator 🎵</h1>', unsafe_allow_html=True)
    st.write("Generate music based on your text input using Meta's Audiocraft library!")
    description = st.text_area("Enter a description:")
    # The slider's minimum is 1, so duration is always a usable positive value.
    duration = st.slider("Select duration (seconds)", 1, 20, 10)
    # Whitespace-only input counts as empty.
    prompt = description.strip()
    if st.button("Generate", disabled=not prompt):
        music_tensors = generate_music_tensors(prompt, duration)
        # Store raw bytes rather than the BytesIO, so each rerun builds a
        # fresh buffer whose read position has not been consumed.
        st.session_state["generated_wav"] = create_audio_buffer(music_tensors).getvalue()
    wav_bytes = st.session_state.get("generated_wav")
    if wav_bytes:
        st.audio(wav_bytes, format="audio/wav")
        st.markdown(generate_download_link(io.BytesIO(wav_bytes)), unsafe_allow_html=True)
    # Add footer message
    st.markdown('<div class="footer">Made with ❤️ by Theaimart</div>', unsafe_allow_html=True)
if __name__ == "__main__":
    # `streamlit run` executes this script as __main__, so the page is
    # rendered there. Merely importing the module does not render it.
    main()