import os

import torch
import yaml

from audioldm import LatentDiffusion, seed_everything
from audioldm.utils import default_audioldm_config

def make_batch_for_text_to_audio(text, batchsize=1):
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1
    text = [text] * batchsize
    # The placeholder tensors below are never read during text-to-audio
    # generation; they only keep the batch in the format the model expects.
    fbank = torch.zeros((batchsize, 1024, 64))  # Not used
    stft = torch.zeros((batchsize, 1024, 512))  # Not used
    waveform = torch.zeros((batchsize, 160000))  # Not used
    fname = [""] * batchsize  # Not used
    batch = (
        fbank,
        stft,
        None,
        fname,
        waveform,
        text,
    )
    return batch



def build_model(
    ckpt_path=None,
    config=None,
    model_name="audioldm-s-full",
):
    print("Load AudioLDM: %s" % model_name)

    # Fall back to the conventional local checkpoint location when no
    # explicit path is given.
    if ckpt_path is None:
        ckpt_path = "ckpt/%s.ckpt" % model_name
    assert os.path.exists(ckpt_path), "Checkpoint not found: %s" % ckpt_path

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if config is not None:
        assert type(config) is str
        config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # Use text as condition instead of using waveform during training
    config["model"]["params"]["device"] = device
    config["model"]["params"]["cond_stage_key"] = "text"

    # No normalization here
    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    latent_diffusion.cond_stage_model.embed_mode = "text"
    return latent_diffusion

def duration_to_latent_t_size(duration):
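    # The latent time axis runs at 25.6 steps per second of audio, so e.g.
    # the default 10-second clip maps to int(10 * 25.6) == 256 latent frames.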
    return int(duration * 25.6) 

def text_to_audio(
    latent_diffusion,
    text,
    seed=42,
    duration=10,
    batchsize=1,
    guidance_scale=2.5,
    n_candidate_gen_per_text=3,
    config=None,  # Kept for interface consistency; unused in this function
):
    seed_everything(int(seed))
    batch = make_batch_for_text_to_audio(text, batchsize=batchsize)

    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
    with torch.no_grad():
        waveform = latent_diffusion.generate_sample(
            [batch],
            unconditional_guidance_scale=guidance_scale,
            n_candidate_gen_per_text=n_candidate_gen_per_text,
            duration=duration
        )
    return waveform
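

# Minimal usage sketch (illustrative, not part of the original module): it
# assumes a checkpoint exists at ckpt/audioldm-s-full.ckpt and that the
# soundfile package is installed. The prompt, output filename, 16 kHz sample
# rate, and the shape of the returned waveform are assumptions; adjust them
# to match your setup.
if __name__ == "__main__":
    import soundfile as sf

    model = build_model(model_name="audioldm-s-full")
    waveform = text_to_audio(
        model,
        text="A hammer is hitting a wooden surface",
        seed=42,
        duration=10,
        guidance_scale=2.5,
    )
    # Assumes generate_sample returns a NumPy array of shape (batch, samples);
    # squeeze to 1-D before writing a mono WAV file.
    sf.write("output.wav", waveform.squeeze(), samplerate=16000)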