Update audioldm/pipeline.py
Browse files- audioldm/pipeline.py +18 -4
audioldm/pipeline.py
CHANGED
@@ -30,7 +30,23 @@ def make_batch_for_text_to_audio(text, batchsize=1):
|
|
30 |
)
|
31 |
return batch
|
32 |
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
if(torch.cuda.is_available()):
|
35 |
device = torch.device("cuda:0")
|
36 |
else:
|
@@ -40,7 +56,7 @@ def build_model(config=None):
|
|
40 |
assert type(config) is str
|
41 |
config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
|
42 |
else:
|
43 |
-
config = default_audioldm_config()
|
44 |
|
45 |
# Use text as condition instead of using waveform during training
|
46 |
config["model"]["params"]["device"] = device
|
@@ -49,8 +65,6 @@ def build_model(config=None):
|
|
49 |
# No normalization here
|
50 |
latent_diffusion = LatentDiffusion(**config["model"]["params"])
|
51 |
|
52 |
-
resume_from_checkpoint = "./ckpt/ldm_trimmed.ckpt"
|
53 |
-
|
54 |
checkpoint = torch.load(resume_from_checkpoint, map_location=device)
|
55 |
latent_diffusion.load_state_dict(checkpoint["state_dict"])
|
56 |
|
|
|
30 |
)
|
31 |
return batch
|
32 |
|
33 |
+
|
34 |
+
|
35 |
+
def build_model(
|
36 |
+
ckpt_path=None,
|
37 |
+
config=None,
|
38 |
+
model_name="audioldm-s-full"
|
39 |
+
):
|
40 |
+
print("Load AudioLDM: %s" % model_name)
|
41 |
+
|
42 |
+
resume_from_checkpoint = "ckpt/%s.ckpt" % model_name
|
43 |
+
|
44 |
+
# if(ckpt_path is None):
|
45 |
+
# ckpt_path = get_metadata()[model_name]["path"]
|
46 |
+
|
47 |
+
# if(not os.path.exists(ckpt_path)):
|
48 |
+
# download_checkpoint(model_name)
|
49 |
+
|
50 |
if(torch.cuda.is_available()):
|
51 |
device = torch.device("cuda:0")
|
52 |
else:
|
|
|
56 |
assert type(config) is str
|
57 |
config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
|
58 |
else:
|
59 |
+
config = default_audioldm_config(model_name)
|
60 |
|
61 |
# Use text as condition instead of using waveform during training
|
62 |
config["model"]["params"]["device"] = device
|
|
|
65 |
# No normalization here
|
66 |
latent_diffusion = LatentDiffusion(**config["model"]["params"])
|
67 |
|
|
|
|
|
68 |
checkpoint = torch.load(resume_from_checkpoint, map_location=device)
|
69 |
latent_diffusion.load_state_dict(checkpoint["state_dict"])
|
70 |
|