# Byte-Beats-demo / app.py
# Last commit 55ab5a8 (verified) by qichenhuang: "Update app.py" — 2.11 kB
import functools
import io
import wave

import gradio as gr
import scipy
import scipy.io.wavfile
import torch
from transformers import pipeline
from transformers import AutoProcessor, MusicgenForConditionalGeneration
_MUSICGEN_CHECKPOINT = "facebook/musicgen-small"


@functools.lru_cache(maxsize=1)
def _load_musicgen():
    """Load the MusicGen processor and model once; later calls reuse the cached pair."""
    processor = AutoProcessor.from_pretrained(_MUSICGEN_CHECKPOINT)
    model = MusicgenForConditionalGeneration.from_pretrained(_MUSICGEN_CHECKPOINT)
    return processor, model


def generate_music(text, *, guidance_scale=3, max_new_tokens=1350):
    """Generate a music clip from a text prompt with MusicGen.

    Args:
        text: Prompt passed to the MusicGen processor.
        guidance_scale: Forwarded to ``model.generate`` as ``guidance_scale``.
        max_new_tokens: Forwarded to ``model.generate``; more tokens produce
            a longer clip.

    Returns:
        A ``gr.Audio`` component whose value is the path of the WAV file
        written to ``musicgen_out.wav`` in the current working directory.

    Raises:
        gr.Error: If the prompt is empty, or if loading the model, generating
            audio, or writing the WAV file fails. Gradio shows the message in
            the UI instead of passing an error string to the audio output.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter a text prompt first.")

    file_path = "musicgen_out.wav"
    try:
        # Cached after the first call, so each request no longer re-reads the checkpoint.
        processor, model = _load_musicgen()
        inputs = processor(
            text=text,
            padding=True,
            return_tensors="pt",
        )
        audio_values = model.generate(
            **inputs,
            do_sample=True,
            guidance_scale=guidance_scale,
            max_new_tokens=max_new_tokens,
        )
        sampling_rate = model.config.audio_encoder.sampling_rate
        # First batch item, first channel. .cpu() keeps .numpy() valid even if
        # the model is ever moved to an accelerator.
        waveform = audio_values[0, 0].cpu().numpy()
        scipy.io.wavfile.write(file_path, rate=sampling_rate, data=waveform)
    except Exception as e:
        # Top-level UI boundary: surface the failure to the user with its cause attached.
        raise gr.Error(f"Music generation failed: {e}") from e
    return gr.Audio(file_path)
theme = 'Taithrah/Minimal'

# Build the Gradio Blocks UI. The CSS points the page background at
# background.jpg through Gradio's "file=" route.
# NOTE(review): assumes background.jpg sits next to this script — confirm it ships with the Space.
with gr.Blocks(
    theme=theme,
    css=".gradio-container {background: url('file=background.jpg')}",
) as demo:
    # One row holding two side-by-side columns (left scale 1, right scale 3).
    with gr.Row():
        # Left column: the text prompt plus informational Markdown.
        with gr.Column(scale=1):
            input_text = gr.Textbox(lines=2, placeholder="Please enter your lyric here...")
            example_demo = gr.Markdown("## Example Demo\n\nHere is an example demo for you to try.")
            company_intro = gr.Markdown("## Company Introduction\n\nHere you can add information about your company.")
        # Right column: the audio output and the button that triggers generation.
        with gr.Column(scale=3):
            output_audio = gr.Audio(label="Generated Music")
            generate_button = gr.Button("Generate Music")
    # Event listeners must be registered inside the Blocks context.
    generate_button.click(fn=generate_music, inputs=input_text, outputs=output_audio)
    # Add the company logo at the top-left of the page (currently disabled):
    # with gr.Row():
    #     gr.Image("logo.jpg", width=1)

# The Blocks object was originally bound to both names; keep `iface` available.
iface = demo

if __name__ == "__main__":
    # Launch the interface; queue(max_size=4) limits how many requests may wait in the queue.
    demo.queue(max_size=4).launch(share=True)
# demo = gr.Interface(
# fn=generate_music,
# inputs='text',
# outputs='audio',
# )
# demo.queue(max_size=4).launch(share = True)