MusicGen / app.py
Tirath5504's picture
Update app.py
6c42479 verified
raw
history blame contribute delete
No virus
2.39 kB
import streamlit as st
import numpy as np
import torch
import transformers
from packaging.version import parse
import sys
import io
import importlib.metadata as importlib_metadata
import soundfile as sf
import importlib.metadata as importlib_metadata
# Extra keyword arguments forwarded to `from_pretrained` in `generate`.
# When the installed transformers release is 4.40.0 or newer, the attention
# implementation is pinned to "eager"; older releases get no extra arguments.
loading_kwargs = (
    {"attn_implementation": "eager"}
    if parse(importlib_metadata.version("transformers")) >= parse("4.40.0")
    else {}
)
@st.cache_resource(show_spinner=False)
def _load_musicgen():
    """Load the MusicGen model and its processor, cached across Streamlit reruns.

    Streamlit re-executes the whole script on every interaction, so without
    caching the checkpoint would be re-read on each button press.

    Returns:
        A ``(model, processor)`` tuple for "facebook/musicgen-small".
    """
    # NOTE(review): the cached objects are shared by every session of the app;
    # confirm concurrent `model.generate` calls are acceptable for this deployment.
    model = transformers.MusicgenForConditionalGeneration.from_pretrained(
        "facebook/musicgen-small", torchscript=True, return_dict=False, **loading_kwargs
    )
    processor = transformers.AutoProcessor.from_pretrained("facebook/musicgen-small")
    return model, processor


def generate(prompt, sample_length=8):
    """Generate a short music clip from a text prompt and return it as WAV data.

    Args:
        prompt: Text description passed to the MusicGen processor.
        sample_length: Value multiplied by the audio encoder's ``frame_rate``
            to size the generation budget (presumably seconds of audio —
            TODO confirm). Defaults to 8, the previously hard-coded value.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 that holds 16-bit PCM WAV data.
    """
    model, processor = _load_musicgen()
    audio_config = model.config.audio_encoder

    # The "+ 3" margin is carried over unchanged from the original code.
    n_tokens = int(sample_length * audio_config.frame_rate) + 3
    sampling_rate = audio_config.sampling_rate

    inputs = processor(text=[prompt], padding=True, return_tensors="pt")
    audio_values = model.generate(
        **inputs, do_sample=True, guidance_scale=3, max_new_tokens=n_tokens
    )

    # Take the first (only) item of the batch and drop singleton dimensions.
    waveform = audio_values[0].cpu().squeeze().numpy()

    # Scale to the int16 range and clip first: a sample at or above +1.0 would
    # otherwise map to 32768, which does not fit in int16.
    pcm = np.clip(waveform * 2**15, -(2**15), 2**15 - 1).astype(np.int16)

    audio_buffer = io.BytesIO()
    sf.write(audio_buffer, pcm, sampling_rate, format='WAV')
    audio_buffer.seek(0)  # rewind so the caller reads from the start
    return audio_buffer
# --- Streamlit page: pick or type a prompt, generate audio, optional debug view ---
st.title("Music Generator")
st.subheader("Select an example or write a text prompt")

# Strip surrounding whitespace so an entry made only of spaces counts as empty
# instead of overriding the selected example with a blank prompt.
text_prompt = st.text_input("Text Prompt", "").strip()

# "None" is a sentinel option meaning "no example selected".
examples = [
    "80s pop track with bassy drums and synth",
    "Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves",
    "90s rock song with loud guitars and heavy drums",
    "Heartful EDM with beautiful synths and chords",
    "None",
]
st.subheader("Examples")
selected_example = st.radio("Select an example", examples)

if st.button("Generate Audio"):
    if text_prompt or selected_example != "None":
        # A typed prompt takes precedence over the selected example.
        prompt = text_prompt if text_prompt else selected_example
        with st.spinner("Generating audio..."):
            audio_output = generate(prompt)
        st.audio(audio_output, format='audio/wav')
    else:
        st.warning("Please select or enter a text prompt.")

if st.checkbox("Show debug info"):
    # Show the same value the generation branch would use as its prompt.
    st.write("Prompt:", text_prompt if text_prompt else selected_example)