File size: 2,781 Bytes
3883c60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gc

import gradio
import torch
from audiocraft.models import MusicGen
from audiocraft.models import AudioGen

# Module-level state for the single model slot shared by the functions in this
# module. ``model`` may hold either a MusicGen or an AudioGen instance (see
# create_model), or None when nothing has been loaded yet. The annotations are
# quoted so they are not evaluated at import time.
model: 'MusicGen | AudioGen | None' = None
loaded = False  # set True by create_model() on success, False by delete_model()
used_model = ''  # checkpoint name passed to the last successful create_model()
device: 'str | None' = None  # device the current model was loaded onto

# Checkpoints that accept a melody (chroma) conditioning input.
melody_models = ['facebook/musicgen-melody']
# Checkpoints that are loaded with AudioGen instead of MusicGen.
audiogen_models = ['facebook/audiogen-medium']
# All known checkpoint names: MusicGen text-only, melody and AudioGen models.
models = ['facebook/musicgen-small', 'facebook/musicgen-medium', 'facebook/musicgen-large'] + melody_models + audiogen_models


def supports_melody():
    """Tell whether the most recently loaded checkpoint accepts melody input.

    Returns:
        True when ``used_model`` is one of the names in ``melody_models``.
    """
    return any(candidate == used_model for candidate in melody_models)


def create_model(pretrained='medium', map_device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load a pretrained checkpoint into the module-level model slot.

    Any model that is already loaded is unloaded first via ``delete_model``.
    Names listed in ``audiogen_models`` are loaded with ``AudioGen``; every
    other name is loaded with ``MusicGen``.

    Args:
        pretrained: checkpoint name passed to ``get_pretrained``.
        map_device: device the model is loaded onto. The default is chosen once,
            when this function is defined: ``'cuda'`` if CUDA is available,
            otherwise ``'cpu'``.

    Raises:
        gradio.Error: if ``get_pretrained`` fails. The underlying exception is
            chained as ``__cause__`` so the real reason is not lost.
    """
    global model, loaded, device, used_model
    if is_loaded():
        delete_model()

    loader = AudioGen if pretrained in audiogen_models else MusicGen
    try:
        new_model = loader.get_pretrained(pretrained, device=map_device)
    except Exception as err:
        # Narrow catch: a bare except would also swallow KeyboardInterrupt.
        raise gradio.Error('Could not load model!') from err

    # Publish the new state only after loading succeeded, so a failure leaves
    # ``loaded`` False.
    model = new_model
    device = map_device
    used_model = pretrained
    loaded = True


def delete_model():
    """Release the currently loaded model and reset the load state.

    The module-level ``model`` is rebound to ``None`` rather than ``del``-eted.
    The global therefore stays defined, and calling this function again (for
    example when nothing is loaded) is harmless. ``gc.collect()`` and
    ``torch.cuda.empty_cache()`` then run to help reclaim the memory the
    dropped model used. ``used_model`` is not reset here.

    Raises:
        gradio.Error: if the cleanup calls fail. The underlying exception is
            chained as ``__cause__``.
    """
    global model, loaded, device
    # Drop the reference and mark the slot empty first, so the state stays
    # consistent ("nothing loaded") even if the cleanup below raises.
    model = None
    loaded = False
    device = None
    try:
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as err:
        raise gradio.Error('Could not unload model!') from err


def is_loaded():
    """Return the module-level ``loaded`` flag.

    ``create_model`` sets the flag after a successful load, and
    ``delete_model`` clears it.
    """
    global loaded  # read-only here; declared to make the module-state access explicit
    return loaded


def _prepare_input_audio(input_audio):
    """Convert a ``(sample_rate, samples)`` pair into ``(sample_rate, float32 tensor)``.

    Integer PCM (int16 / int32) is scaled by the dtype's maximum value, which is
    32767 for int16. Other floating dtypes are cast to float32. A 2-D input whose
    second axis has length 2 (treated as stereo) is averaged down to mono.
    """
    sr, samples = input_audio
    wav = torch.tensor(samples)
    if wav.dtype in (torch.int16, torch.int32):
        scale = torch.iinfo(wav.dtype).max
        wav = wav.float() / scale
    elif wav.is_floating_point():
        # NOTE(review): presumably the model expects float32 input; float64
        # would otherwise be passed straight through.
        wav = wav.float()
    if wav.dim() == 2 and wav.shape[1] == 2:
        wav = wav.mean(dim=1)
    return sr, wav


def generate(prompt='', input_audio=None, use_sample=True, top_k=250, top_p=0.0, temp=1, duration=8, cfg_coef=3, progress=gradio.Progress()):
    """Run the loaded model and return the generated audio.

    The model method used depends on the inputs:

    * ``input_audio`` given and ``supports_melody()``: ``generate_with_chroma``,
      conditioned on the prompt (or no text when ``prompt`` is empty) and the
      input audio as melody.
    * ``input_audio`` given otherwise: ``generate_continuation`` of the input.
    * no ``input_audio`` and an empty ``prompt``: ``generate_unconditional``.
    * otherwise: ``generate`` from the text prompt.

    Args:
        prompt: text description; an empty string means "no text".
        input_audio: optional ``(sample_rate, samples)`` pair. It is assumed to
            be the numpy value of a gradio Audio component, with ``samples``
            shaped ``(n,)`` or ``(n, channels)``. TODO confirm against the caller.
        use_sample, top_k, top_p, temp, duration, cfg_coef: forwarded to
            ``model.set_generation_params`` as ``use_sampling``, ``top_k``,
            ``top_p``, ``temperature``, ``duration`` and ``cfg_coef``.
        progress: gradio progress tracker updated while generating.

    Returns:
        ``(model.sample_rate, samples)``, where ``samples`` is the generated
        audio moved to CPU and flattened into a 1-D numpy array.

    Raises:
        gradio.Error: if no model is loaded.
    """
    if not is_loaded():
        raise gradio.Error('No model loaded! Please load a model first.')

    model.set_generation_params(
        use_sampling=use_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temp,
        duration=duration,
        cfg_coef=cfg_coef,
    )
    progress(0, desc='Generating')

    def progress_callback(p, t):
        # Forward the model's (done, total) step counts to the gradio tracker.
        progress((p, t), desc='Generating')

    model.set_custom_progress_callback(progress_callback)

    descriptions = [prompt if prompt else None]

    if input_audio is not None:
        sr, wav = _prepare_input_audio(input_audio)
        # Add a leading batch dimension: (T,) -> (1, 1, T).
        batched = wav[None].expand(1, -1, -1)
        if supports_melody():
            wav = model.generate_with_chroma(descriptions, batched, sr, progress=True)
        else:
            wav = model.generate_continuation(batched, sr, descriptions, progress=True)
    elif not prompt:
        wav = model.generate_unconditional(1, progress=True)
    else:
        wav = model.generate(descriptions, progress=True)

    return model.sample_rate, wav.cpu().flatten().numpy()