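"""Gradio demo for MusicGen: generates music from a text prompt, optionally
conditioned on the melody of a reference audio clip."""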
import time
import torch
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
import gradio as gr
from audiocraft.models import MusicGen
from tempfile import NamedTemporaryFile
from pathlib import Path
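
# Interrupt flag read by the progress callback in `predict`. Presumably an
# interrupt button elsewhere in the app sets this to True; initializing it
# here keeps the module self-contained.
INTERRUPTING = False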


def load_model(version='facebook/musicgen-melody'):
    return MusicGen.get_pretrained(version)
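
# Besides the default 'facebook/musicgen-melody', `load_model` accepts other
# published MusicGen checkpoints, e.g. 'facebook/musicgen-small' or
# 'facebook/musicgen-large'.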


def _do_predictions(model, texts, melodies, duration, progress=False, gradio_progress=None,
                    target_sr=32000, target_ac=1, **gen_kwargs):
    print("new batch", len(texts), texts,
          [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    # Apply duration and sampling parameters (top_k, top_p, temperature) before
    # generating; without this call the model would keep its defaults.
    model.set_generation_params(duration=duration, **gen_kwargs)
    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            # Gradio delivers audio as (sample_rate, np.ndarray); move it to
            # the model device and transpose to (channels, time).
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(model.device).float().t()
            if melody.dim() == 1:
                melody = melody[None]
            # Trim the reference to the requested duration, then resample.
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)
    try:
        if any(m is not None for m in processed_melodies):
            # Melody-conditioned generation.
            outputs = model.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress,
                return_tokens=False,
            )
        else:
            # Text-only generation.
            outputs = model.generate(texts, progress=progress, return_tokens=False)
    except RuntimeError as e:
        raise gr.Error("Error while generating " + e.args[0])
    outputs = outputs.detach().cpu().float()
    out_wavs = []
    for output in outputs:
        # Write each generated waveform to a temp .wav with loudness
        # normalization; delete=False keeps the file around for Gradio to serve.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, model.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_wavs.append(file.name)
    print("generation finished", len(texts), time.time() - be)
    return out_wavs


def predict(model_path, text, melody, duration, topk, topp, temperature, target_sr,
            progress=gr.Progress()):
    global INTERRUPTING
    INTERRUPTING = False
    progress(0, desc="Loading model...")
    model_path = model_path.strip()
    if model_path:
        if not Path(model_path).exists():
            raise gr.Error(f"Model path {model_path} doesn't exist.")
        if not Path(model_path).is_dir():
            raise gr.Error(f"Model path {model_path} must be a folder containing "
                           "state_dict.bin and compression_state_dict_.bin.")
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")
    topk = int(topk)
    # Fall back to the default checkpoint when no local path is given.
    model = load_model(model_path) if model_path else load_model()
    max_generated = 0

    def _progress(generated, to_generate):
        # Keep the reported progress monotonic and honor interrupt requests.
        nonlocal max_generated
        max_generated = max(generated, max_generated)
        progress((min(max_generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    model.set_custom_progress_callback(_progress)
    wavs = _do_predictions(
        model,
        [text],
        [melody],
        duration,
        progress=True,
        target_ac=1,
        target_sr=target_sr,
        top_k=topk,
        top_p=topp,
        temperature=temperature,
        gradio_progress=progress)
    return wavs
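

if __name__ == "__main__":
    # Minimal standalone launcher, a sketch rather than the Space's actual UI:
    # the `_predict_one` wrapper name and the component choices, labels, and
    # defaults below are illustrative assumptions.
    def _predict_one(model_path, text, melody, duration, topk, topp, temperature,
                     target_sr, progress=gr.Progress()):
        # `predict` returns a list of wav paths; unwrap it for a single prompt.
        return predict(model_path, text, melody, duration, topk, topp,
                       temperature, target_sr, progress=progress)[0]

    demo = gr.Interface(
        fn=_predict_one,
        inputs=[
            gr.Text(label="Model path (blank = facebook/musicgen-melody)"),
            gr.Text(label="Prompt"),
            gr.Audio(label="Melody (optional)"),
            gr.Slider(1, 30, value=10, step=1, label="Duration (s)"),
            gr.Number(value=250, label="Top-k"),
            gr.Number(value=0.0, label="Top-p"),
            gr.Number(value=1.0, label="Temperature"),
            gr.Number(value=32000, label="Target sample rate"),
        ],
        outputs=gr.Audio(label="Generated audio", type="filepath"),
    )
    demo.launch()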