# Hugging Face Space: running on T4
import tempfile

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import InferenceClient
from transformers import pipeline
def _grab_best_device(use_gpu=True):
    """Pick the torch device string to run on.

    Args:
        use_gpu: When falsy, always select the CPU even if CUDA devices
            are visible.

    Returns:
        ``"cuda"`` if at least one CUDA device is reported by torch and
        ``use_gpu`` is truthy, otherwise ``"cpu"``.
    """
    cuda_visible = torch.cuda.device_count() > 0
    return "cuda" if cuda_visible and use_gpu else "cpu"
# Device string ("cuda" or "cpu") chosen once at import time.
device = _grab_best_device()

# Markdown heading rendered at the top of the demo page.
title = """# MusicGen Prompt Upsampling"""

# MusicGen text-to-audio pipeline. It is placed on the detected device rather
# than a hard-coded "cuda", so the CPU fallback above is actually honoured.
# Half precision is only requested on CUDA; on CPU fall back to float32,
# where fp16 kernels are not reliably available.
vibes = pipeline(
    "text-to-audio",
    "facebook/musicgen-stereo-medium",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device=device,
)

# Hosted LLM client used to enrich the user's prompt before music generation.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")
# Inference | |
def generate_audio(text,):
    """Enrich a short prompt with the LLM, then synthesise music from it.

    The user's text is wrapped in an instruction and sent to ``client`` for
    expansion (at most 100 new tokens). The expanded text is fed to the
    ``vibes`` text-to-audio pipeline and the result is written to a
    temporary WAV file.

    Args:
        text: The raw prompt typed by the user.

    Returns:
        Filesystem path of the generated ``.wav`` file. The file is created
        with ``delete=False`` and is not removed by this function.
    """
    prompt = f"Take the next sentence and enrich it with details. Keep it compact. {text}"
    output = client.text_generation(prompt, max_new_tokens=100)
    out = vibes(output)

    # Reserve a unique path and close the handle *before* writing: reopening
    # a still-open NamedTemporaryFile by name fails on some platforms
    # (notably Windows).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        wav_path = f.name

    # NOTE(review): the transpose presumably turns (channels, samples) into
    # the (samples, channels) layout soundfile expects - confirm against the
    # pipeline's output shape.
    sf.write(wav_path, out["audio"][0].T, out["sampling_rate"])
    return wav_path
# Page-level CSS: centre the layout and the intro heading.
css = """
#container{
    margin: 0 auto;
    max-width: 80rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""

# Markdown shown inside the "Run MusicGen with Transformers" accordion.
# It is wrapped in a code fence so it renders as a code block instead of a
# single run-on paragraph, and kept flush-left so the rendered snippet does
# not depend on any dedenting by the Markdown component.
usage_snippet = """
```python
import torch
import soundfile as sf
from transformers import pipeline

synthesiser = pipeline("text-to-audio",
                       "facebook/musicgen-stereo-medium",
                       device="cuda:0",
                       torch_dtype=torch.float16)

music = synthesiser("lo-fi music with a soothing melody", forward_params={"max_new_tokens": 256})

sf.write("musicgen_out.wav", music["audio"][0].T, music["sampling_rate"])
```
"""

# Gradio blocks demo
with gr.Blocks(css=css) as demo_blocks:
    gr.Markdown(title, elem_id="intro")

    with gr.Row():
        # Left column: prompt entry and trigger button.
        with gr.Column():
            inp_text = gr.Textbox(label="Input Prompt", info="What would you like MusicGen to synthesise?")
            btn = gr.Button("Generate Music!🎶")

        # Right column: player for the generated clip.
        with gr.Column():
            out = gr.Audio(autoplay=False, label="Generated Music", show_label=True)

    with gr.Accordion("Run MusicGen with Transformers 🤗", open=False):
        gr.Markdown(usage_snippet)

    # Clicking the button runs generate_audio on the textbox contents and
    # loads the returned file path into the audio player.
    btn.click(generate_audio, inp_text, out)

demo_blocks.queue().launch()