File size: 3,753 Bytes
225b9d3
 
893807d
 
ec033ea
f38a7f2
225b9d3
893807d
225b9d3
893807d
225b9d3
893807d
225b9d3
 
 
 
 
893807d
225b9d3
 
 
893807d
 
 
225b9d3
 
 
 
893807d
225b9d3
 
 
bf7634c
225b9d3
 
 
893807d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225b9d3
ec033ea
225b9d3
 
 
61de3b8
ec033ea
893807d
ec033ea
225b9d3
ec033ea
893807d
 
ec033ea
 
225b9d3
 
ec033ea
 
 
 
 
7245754
 
ec033ea
 
893807d
 
ec033ea
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from diffusers import DiffusionPipeline
import os
import sys
from huggingface_hub import HfApi, hf_hub_download
# from .tools import build_dataset_json_from_list
import torch

class MOSDiffusionPipeline(DiffusionPipeline):
    """MOS (mean-opinion-score)-conditioned audio diffusion pipeline.

    Wraps the qa-mdt inference code: optionally downloads the required
    folders from the Hugging Face Hub, then runs the project's `infer`
    entry point on a text prompt.
    """

    def __init__(self, config_yaml, list_inference, reload_from_ckpt=None, base_folder=None):
        """
        Initialize the MOS Diffusion pipeline.

        Args:
            config_yaml (str): Path to the YAML configuration file.
            list_inference (str): Path to the file containing inference prompts.
            reload_from_ckpt (str, optional): Checkpoint path to reload from.
            base_folder (str, optional): Base folder to store downloaded files.
                Defaults to the current working directory.
        """
        super().__init__()

        # BUGFIX: the module-level import of this helper is commented out at the
        # top of the file, so the call below raised NameError. Import it locally,
        # matching the relative-import style used in __call__.
        from .tools import build_dataset_json_from_list

        self.base_folder = base_folder if base_folder else os.getcwd()
        self.repo_id = "jadechoghari/qa-mdt"
        self.config_yaml = config_yaml
        self.list_inference = list_inference
        self.reload_from_ckpt = reload_from_ckpt

        # os.path.join() with a single argument is a no-op; use the path directly.
        self.configs = self.load_yaml(self.config_yaml)
        if self.reload_from_ckpt is not None:
            self.configs["reload_from_ckpt"] = self.reload_from_ckpt

        self.dataset_key = build_dataset_json_from_list(self.list_inference)
        # "path/to/exp_name.yaml" -> exp_name; its parent directory is the group.
        self.exp_name = os.path.basename(self.config_yaml.split(".")[0])
        self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))

    def download_required_folders(self):
        """
        Download the folders needed for inference from the Hugging Face Hub,
        skipping files that already exist locally, then append the base folder
        to sys.path so the downloaded packages are importable.
        """
        import shutil  # local import, consistent with load_yaml's local `import yaml`

        api = HfApi()
        files = api.list_repo_files(repo_id=self.repo_id)

        required_folders = ["audioldm_train", "checkpoints", "infer", "log", "taming", "test_prompts"]
        files_to_download = [f for f in files if any(f.startswith(folder) for folder in required_folders)]

        for file in files_to_download:
            local_file_path = os.path.join(self.base_folder, file)
            if not os.path.exists(local_file_path):
                downloaded_file = hf_hub_download(repo_id=self.repo_id, filename=file)
                os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
                # BUGFIX: os.rename raises OSError when the HF cache lives on a
                # different filesystem than base_folder; shutil.move handles the
                # cross-device case by copying.
                shutil.move(downloaded_file, local_file_path)

        # Make the downloaded packages importable (e.g. `infer.infer_mos5`).
        sys.path.append(self.base_folder)

    def load_yaml(self, yaml_path):
        """
        Load and return the YAML configuration at `yaml_path`.
        """
        import yaml
        with open(yaml_path, "r") as f:
            return yaml.safe_load(f)

    @torch.no_grad()
    def __call__(self, prompt: str):
        """
        Run the MOS Diffusion pipeline on `prompt` by delegating to the
        `infer` function from infer_mos5.py. Returns nothing; inference
        output is handled inside `infer`.
        """
        from .infer.infer_mos5 import infer
        dataset_key = self.build_dataset_json_from_prompt(prompt)

        # Run inference with the prompt, the loaded configs, and fixed
        # experiment/group names expected by the qa-mdt checkpoints.
        infer(
            dataset_key=dataset_key,
            configs=self.configs,
            config_yaml_path=self.config_yaml,
            exp_group_name="qa_mdt",
            exp_name="mos_as_token"
        )

    def build_dataset_json_from_prompt(self, prompt: str):
        """
        Build a minimal dataset_key dict from a single prompt: one entry with
        no wav file and the prompt as the caption.
        """
        data = [{"wav": "", "caption": prompt}]  # no wav file, just the caption (prompt)
        return {"data": data}
        

# Example of how to use the pipeline
if __name__ == "__main__":
    # BUGFIX: __init__ requires config_yaml and list_inference; the original
    # example passed no arguments and raised TypeError. Substitute real paths.
    pipe = MOSDiffusionPipeline(
        config_yaml="config.yaml",
        list_inference="test_prompts/prompts.txt",
    )
    # __call__ performs inference for its side effects and returns None,
    # so there is no result value to print.
    pipe("Generate a description of a sunny day.")