# Byte_Beats_Demo / app.py
# Author: qichenhuang
# Last commit: d7a040f "change file path" (verified)
import functools
import io
import tempfile
import wave

import gradio as gr
import scipy
import torch
from scipy.io import wavfile
from transformers import pipeline
from transformers import AutoProcessor, MusicgenForConditionalGeneration
# Hugging Face checkpoint used for both the processor and the model.
MUSICGEN_CHECKPOINT = "facebook/musicgen-small"


@functools.lru_cache(maxsize=1)
def _load_musicgen():
    """Load the MusicGen processor and model once and reuse them for every request."""
    processor = AutoProcessor.from_pretrained(MUSICGEN_CHECKPOINT)
    model = MusicgenForConditionalGeneration.from_pretrained(MUSICGEN_CHECKPOINT)
    return processor, model


def generate_music(text):
    """Generate a music clip from a text prompt and save it as a WAV file.

    Args:
        text: Free-form description of the desired music.

    Returns:
        Path to a newly created ``.wav`` file containing the generated audio.
        Each call writes to its own temporary file, so concurrent requests
        cannot overwrite one another's output.

    Raises:
        Any exception from model loading, generation, or WAV writing is
        propagated unchanged to the caller (Gradio reports it in the UI).
    """
    # Cached after the first call; avoids re-loading weights per request.
    processor, model = _load_musicgen()
    inputs = processor(
        text=text,
        padding=True,
        return_tensors="pt",
    )
    # do_sample=True samples instead of decoding greedily; guidance_scale sets
    # the strength of classifier-free guidance; max_new_tokens caps how many
    # audio tokens (and therefore how much audio) are generated.
    audio_values = model.generate(
        **inputs, do_sample=True, guidance_scale=3, max_new_tokens=256
    )
    sampling_rate = model.config.audio_encoder.sampling_rate

    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        # [0, 0] selects the first prompt's first channel; assumes audio_values
        # is (batch, channels, samples) as documented for MusicGen.
        wavfile.write(tmp, rate=sampling_rate, data=audio_values[0, 0].numpy())
    return tmp.name
# Web UI: a single text box in, a downloadable file out.
demo = gr.Interface(
    fn=generate_music,
    inputs="text",
    outputs="file",
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)