# Model management and generation helpers for a Gradio MusicGen/AudioGen app.
import gc
from typing import Optional

import gradio
import torch
from audiocraft.models import MusicGen
from audiocraft.models import AudioGen
# Currently loaded model instance (MusicGen or AudioGen); None until create_model() succeeds.
model: Optional["MusicGen"] = None
# True while `model` holds a usable checkpoint.
loaded = False
# Pretrained identifier of the checkpoint currently in `model` ('' when none).
used_model = ''
# Device the model was loaded onto ('cuda'/'cpu'); None when no model is loaded.
device: Optional[str] = None
# Checkpoints that support melody conditioning (generate_with_chroma).
melody_models = ['facebook/musicgen-melody']
# Checkpoints that must be loaded through AudioGen rather than MusicGen.
audiogen_models = ['facebook/audiogen-medium']
# Every checkpoint selectable in the UI.
models = ['facebook/musicgen-small', 'facebook/musicgen-medium', 'facebook/musicgen-large'] + melody_models + audiogen_models
def supports_melody():
    """Return True when the currently loaded checkpoint accepts melody conditioning."""
    return any(name == used_model for name in melody_models)
def create_model(pretrained='medium', map_device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load a pretrained checkpoint into the module-level `model`.

    Any previously loaded model is unloaded first. AudioGen checkpoints are
    dispatched to `AudioGen.get_pretrained`; everything else to MusicGen.

    :param pretrained: checkpoint identifier (see `models`).
    :param map_device: target device; defaults to CUDA when available.
    :raises gradio.Error: when the checkpoint cannot be loaded.
    """
    if is_loaded():
        delete_model()
    global model, loaded, device, used_model
    try:
        if pretrained in audiogen_models:
            model = AudioGen.get_pretrained(pretrained, device=map_device)
        else:
            model = MusicGen.get_pretrained(pretrained, device=map_device)
        device = map_device
        used_model = pretrained
        loaded = True
    except Exception as err:
        # Narrow except (the old bare `except:` also swallowed KeyboardInterrupt)
        # and chain the cause so the original failure is visible in logs.
        raise gradio.Error('Could not load model!') from err
def delete_model():
    """Unload the current model and release its memory.

    Resets all module-level state (`model`, `loaded`, `device`, `used_model`)
    and frees the CUDA cache when CUDA is present. Safe to call repeatedly:
    rebinding `model = None` (instead of `del model`) means a second call is a
    no-op rather than a NameError.

    :raises gradio.Error: if cleanup itself fails.
    """
    global model, loaded, device, used_model
    try:
        model = None  # drop the only strong reference so gc can reclaim the weights
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        loaded = False
        device = None
        used_model = ''
    except Exception as err:
        raise gradio.Error('Could not unload model!') from err
def is_loaded():
    """Report whether a generation model is currently resident in memory."""
    return loaded
def generate(prompt='', input_audio=None, use_sample=True, top_k=250, top_p=0.0, temp=1, duration=8, cfg_coef=3, progress=gradio.Progress()):
    """Generate audio with the loaded model, dispatching on the available inputs.

    Dispatch order: melody conditioning (audio + melody-capable model) ->
    continuation (audio only) -> unconditional (no prompt) -> text-to-audio.

    :param prompt: text description; '' means unconditional when no audio is given.
    :param input_audio: optional (sample_rate, numpy_wav) pair as produced by a
        gradio Audio component — TODO confirm channel layout against the caller.
    :param use_sample, top_k, top_p, temp, duration, cfg_coef: sampling params
        forwarded to `model.set_generation_params`.
    :param progress: gradio progress tracker (injected by gradio).
    :returns: (sample_rate, flat numpy waveform).
    :raises gradio.Error: when no model is loaded.
    """
    if not is_loaded():
        raise gradio.Error('No model loaded! Please load a model first.')
    # Set the params once up front; the old code redundantly set them a second
    # time in the continuation branch.
    model.set_generation_params(use_sample, top_k, top_p, temp, duration, cfg_coef)
    progress(0, desc='Generating')

    def progress_callback(generated, total):
        progress((generated, total), desc='Generating')

    model.set_custom_progress_callback(progress_callback)
    descriptions = [prompt if prompt else None]
    have_audio = input_audio is not None
    sr, wav = 0, None
    if have_audio:
        sr, wav = input_audio
        wav = torch.tensor(wav)
        if wav.dtype == torch.int16:
            # gradio hands over 16-bit PCM; rescale to floats in [-1, 1]
            wav = wav.float() / 32767.0
        if wav.dim() == 2 and wav.shape[1] == 2:
            # downmix stereo (samples, 2) to mono
            wav = wav.mean(dim=1)
    if have_audio and supports_melody():
        # wav[None].expand(1, -1, -1) -> (batch=1, channels=1, samples)
        wav = model.generate_with_chroma(descriptions, wav[None].expand(1, -1, -1), sr, True)
    elif have_audio:
        wav = model.generate_continuation(wav[None].expand(1, -1, -1), sr, descriptions, True)
    elif not prompt:
        wav = model.generate_unconditional(1, True)
    else:
        wav = model.generate([prompt], True)
    wav = wav.cpu().flatten().numpy()
    return model.sample_rate, wav