|
from diffusers import DiffusionPipeline |
|
import torch |
|
import os |
|
import yaml |
|
from audioldm_train.utilities.tools import build_dataset_json_from_list |
|
from infer_mos5 import infer |
|
|
|
class MOSDiffusionPipeline(DiffusionPipeline):
    """Thin ``DiffusionPipeline`` wrapper around ``infer`` from ``infer_mos5``.

    The constructor loads a YAML configuration, builds a dataset key from a
    prompt list, and derives the experiment name / group name from the
    configuration path. Calling the pipeline forwards those values to
    ``infer``.
    """

    def __init__(self, config_yaml, list_inference, reload_from_ckpt=None):
        """Initialize the MOS Diffusion pipeline.

        Args:
            config_yaml (str): Path to the YAML configuration file.
            list_inference (str): Path to the file containing inference
                prompts; passed unchanged to ``build_dataset_json_from_list``.
            reload_from_ckpt (str, optional): Checkpoint path. When not
                ``None`` it is stored in the loaded configuration under the
                ``"reload_from_ckpt"`` key, replacing any value from the file.
        """
        super().__init__()

        self.config_yaml = config_yaml
        self.list_inference = list_inference
        self.reload_from_ckpt = reload_from_ckpt

        # Use a context manager so the configuration file handle is closed
        # deterministically instead of being left to the garbage collector.
        # NOTE(review): FullLoader is kept for compatibility with existing
        # configs; prefer yaml.safe_load if configs can come from untrusted
        # sources.
        with open(self.config_yaml, "r") as config_file:
            self.configs = yaml.load(config_file, Loader=yaml.FullLoader)

        if self.reload_from_ckpt is not None:
            self.configs["reload_from_ckpt"] = self.reload_from_ckpt

        self.dataset_key = build_dataset_json_from_list(self.list_inference)

        # Take the basename *before* cutting at the first dot. Splitting the
        # full path first breaks on any dot in a directory component
        # (e.g. "./cfg/foo.yaml" used to yield an empty experiment name).
        self.exp_name = os.path.basename(self.config_yaml).split(".")[0]
        # The experiment group is the name of the directory holding the config.
        self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """Run the MOS Diffusion pipeline by calling ``infer``.

        Args:
            *args: Accepted for call-signature compatibility; currently unused.
            **kwargs: Accepted for call-signature compatibility; currently
                unused (no configuration overrides are applied here).

        Returns:
            None. The return value of ``infer`` is not propagated.
        """
        infer(
            dataset_key=self.dataset_key,
            configs=self.configs,
            config_yaml_path=self.config_yaml,
            exp_group_name=self.exp_group_name,
            exp_name=self.exp_name,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|