File size: 4,675 Bytes
5085882 0d7e2ec e3b4eb8 5085882 ab52fda 0780f67 e3b4eb8 5085882 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import shutil
import os
import argparse
import yaml
import torch
import sys
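# Colab setup: add the local qa-mdt checkout (which holds the audioldm_train
# library) to the module search path.
# NOTE(review): the audioldm_train imports below are relative (".audioldm_train"),
# which resolve against this module's package rather than sys.path -- confirm
# this append is still needed.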
sys.path.append('/content/qa-mdt')
from .audioldm_train.utilities.data.dataset_original_mos5 import AudioDataset as AudioDataset
from .audioldm_train.utilities.tools import build_dataset_json_from_list
from torch.utils.data import DataLoader
from pytorch_lightning import seed_everything
from .audioldm_train.utilities.tools import get_restore_step
def instantiate_from_config(config):
    """Build an object from a ``{"target": ..., "params": ...}`` config.

    ``config["target"]`` is a dotted import path such as
    ``"package.module.ClassName"``. The named object is imported and called
    with ``config["params"]`` as keyword arguments. A missing or null
    ``params`` means no arguments.

    Args:
        config: A mapping containing a ``"target"`` key, or one of the sentinel
            strings ``"__is_first_stage__"`` / ``"__is_unconditional__"``.

    Returns:
        The constructed object, or ``None`` for either sentinel string.

    Raises:
        KeyError: If ``config`` is not a sentinel and has no ``"target"`` key.
        ValueError: If ``"target"`` is not a dotted ``module.attr`` path.
    """
    # Handle strings first: a membership test on a str is a substring search,
    # so a string containing "target" must not fall through to config["target"].
    if isinstance(config, str):
        if config in ("__is_first_stage__", "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    # `params:` with no value in YAML parses to None; treat that as "no kwargs".
    params = config.get("params") or {}
    return _get_obj_from_str(config["target"])(**params)


def _get_obj_from_str(string):
    """Import and return the object named by a dotted path like ``"pkg.mod.Name"``."""
    import importlib

    if "." not in string:
        raise ValueError(
            "Expected a dotted 'module.attribute' path, got %r" % (string,)
        )
    module_name, attr_name = string.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr_name)
def infer(dataset_key, configs, config_yaml_path, exp_group_name, exp_name):
    """Generate samples for a test-split dataset with a trained model.

    Seeds everything with 0 and builds an ``AudioDataset`` (split ``"test"``)
    with a batch-size-1 ``DataLoader``. It then instantiates the model from
    ``configs["model"]``, loads checkpoint weights, moves the model to CUDA and
    calls its ``generate_sample`` method.

    Args:
        dataset_key: Passed to ``AudioDataset`` as ``dataset_json``. In this
            file it is the value returned by ``build_dataset_json_from_list``.
        configs: Parsed YAML config. Reads ``log_directory``, ``data``,
            ``model`` (including ``params.evaluation_params``), and optionally
            ``precision``, ``data.dataloader_add_ons`` and ``reload_from_ckpt``.
        config_yaml_path: Path of the YAML file. It is copied into the
            experiment directory.
        exp_group_name: Experiment-group sub-directory under ``log_directory``.
        exp_name: Experiment sub-directory under ``exp_group_name``.

    Raises:
        FileNotFoundError: If the experiment's ``checkpoints`` directory is
            empty and ``reload_from_ckpt`` is not set, so there are no
            weights to load.
    """
    seed_everything(0)

    if "precision" in configs:
        torch.set_float32_matmul_precision(configs["precision"])

    log_path = configs["log_directory"]
    dataloader_add_ons = configs["data"].get("dataloader_add_ons", [])

    val_dataset = AudioDataset(
        configs, split="test", add_ons=dataloader_add_ons, dataset_json=dataset_key
    )
    val_loader = DataLoader(val_dataset, batch_size=1)

    exp_dir = os.path.join(log_path, exp_group_name, exp_name)
    checkpoint_path = os.path.join(exp_dir, "checkpoints")
    os.makedirs(checkpoint_path, exist_ok=True)
    # Keep a copy of the config next to the experiment outputs.
    shutil.copy(config_yaml_path, exp_dir)

    # Resolve weights before building the model so a missing checkpoint fails
    # fast instead of crashing inside torch.load(None).
    resume_from_checkpoint = _resolve_checkpoint(
        checkpoint_path, configs.get("reload_from_ckpt")
    )

    latent_diffusion = instantiate_from_config(configs["model"])
    latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name)

    evaluation_params = configs["model"]["params"]["evaluation_params"]
    guidance_scale = evaluation_params["unconditional_guidance_scale"]
    ddim_sampling_steps = evaluation_params["ddim_sampling_steps"]
    n_candidates_per_samples = evaluation_params["n_candidates_per_samples"]

    # Load on CPU so checkpoint tensors are not held on the GPU alongside the
    # model's own parameters; the model is moved to CUDA after loading.
    checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
    load_result = latent_diffusion.load_state_dict(
        checkpoint["state_dict"], strict=False
    )
    del checkpoint  # release the CPU copy of the weights before sampling
    # strict=False ignores key mismatches; surface them so a wrong checkpoint
    # is noticed instead of silently leaving parameters uninitialised.
    missing = getattr(load_result, "missing_keys", [])
    unexpected = getattr(load_result, "unexpected_keys", [])
    if missing or unexpected:
        print(
            "Checkpoint key mismatch: %d missing, %d unexpected"
            % (len(missing), len(unexpected))
        )

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.cuda()
    latent_diffusion.generate_sample(
        val_loader,
        unconditional_guidance_scale=guidance_scale,
        ddim_steps=ddim_sampling_steps,
        n_gen=n_candidates_per_samples,
    )


def _resolve_checkpoint(checkpoint_path, config_reload_from_ckpt):
    """Return the checkpoint file to load, or raise if none is available.

    A non-empty ``checkpoint_path`` directory wins: the entry chosen by
    ``get_restore_step`` is used. Otherwise ``config_reload_from_ckpt`` is
    used.
    NOTE(review): a directory checkpoint therefore overrides a
    ``--reload_from_ckpt`` given on the command line -- confirm that is intended.
    """
    if os.listdir(checkpoint_path):
        print("Load checkpoint from path: %s" % checkpoint_path)
        restore_step, _ = get_restore_step(checkpoint_path)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("Resume from checkpoint", resume_from_checkpoint)
        return resume_from_checkpoint
    if config_reload_from_ckpt is not None:
        print("Reload ckpt specified in the config file %s" % config_reload_from_ckpt)
        return config_reload_from_ckpt
    raise FileNotFoundError(
        "No checkpoint found in %r and 'reload_from_ckpt' is not set; "
        "inference needs trained weights." % (checkpoint_path,)
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Both paths are dereferenced unconditionally below, so argparse should
    # reject a missing value with a usage message instead of a later crash.
    parser.add_argument(
        "-c",
        "--config_yaml",
        type=str,
        required=True,
        help="path to config .yaml file",
    )
    parser.add_argument(
        "-l",
        "--list_inference",
        type=str,
        required=True,
        help="The filelist that contain captions (and optionally filenames)",
    )
    parser.add_argument(
        "-reload_from_ckpt",
        "--reload_from_ckpt",
        type=str,
        required=False,
        default=None,
        help="the checkpoint path for the model",
    )
    args = parser.parse_args()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")

    config_yaml_path = args.config_yaml
    dataset_key = build_dataset_json_from_list(args.list_inference)
    # Derive the experiment name from the file name only. Splitting the whole
    # path on "." yielded "" for paths like "./configs/exp.yaml" and truncated
    # at dots in directory names.
    exp_name = os.path.basename(config_yaml_path).split(".")[0]
    exp_group_name = os.path.basename(os.path.dirname(config_yaml_path))

    with open(config_yaml_path, "r") as config_file:
        config_yaml = yaml.load(config_file, Loader=yaml.FullLoader)

    if args.reload_from_ckpt is not None:
        config_yaml["reload_from_ckpt"] = args.reload_from_ckpt

    infer(dataset_key, config_yaml, config_yaml_path, exp_group_name, exp_name)
|