File size: 4,997 Bytes
44d964a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import urllib

import os
from typing import List
from urllib.parse import urlparse

from tqdm import tqdm

from src.conversion.hf_converter import convert_hf_whisper

class ModelConfig:
    """Describes one model and knows how to resolve it to a loadable path or name."""

    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
        """
        Initialize a model configuration.

        name: Name of the model
        url: URL to download the model from
        path: Path to the model file. If not set, the model will be downloaded from the URL.
        type: Type of model. Can be whisper or huggingface.
        """
        self.name = name
        self.url = url
        self.path = path
        self.type = type

    def download_url(self, root_dir: str):
        """
        Resolve this model to a path (or built-in Whisper model name) and cache it in self.path.

        root_dir: Directory used for downloaded/converted files. If None, ~/.cache/whisper is used.

        Returns the resolved path. self.path is only assigned once resolution has succeeded,
        so a failed download or conversion can be retried by calling this method again.

        Raises ValueError if the model type is unknown, and RuntimeError if the download
        target exists but is not a regular file.
        """
        # See if path is already set
        if self.path is not None:
            return self.path

        if root_dir is None:
            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

        model_type = self.type.lower() if self.type is not None else "whisper"

        if model_type in ["huggingface", "hf"]:
            destination_target = os.path.join(root_dir, self.name + ".pt")

            # Convert from HuggingFace format to Whisper format
            if os.path.exists(destination_target):
                print(f"File {destination_target} already exists, skipping conversion")
            else:
                print("Saving HuggingFace model in Whisper format to " + destination_target)
                # Make sure the cache directory exists before anything is written into it
                os.makedirs(root_dir, exist_ok=True)
                convert_hf_whisper(self.url, destination_target)

            resolved_path = destination_target

        elif model_type in ["whisper", "w"]:
            # Imported lazily: only this branch needs the list of built-in model names,
            # so already-resolved and HuggingFace models work without importing whisper here.
            import whisper

            # See if URL is just a file
            if self.url in whisper._MODELS:
                # No need to download anything - Whisper will handle it
                resolved_path = self.url
            elif self.url.startswith("file://"):
                # Get file path
                resolved_path = urlparse(self.url).path
            # See if it is an URL
            elif self.url.startswith(("http://", "https://")):
                # Extension (or file name)
                extension = os.path.splitext(self.url)[-1]
                download_target = os.path.join(root_dir, self.name + extension)

                if os.path.exists(download_target) and not os.path.isfile(download_target):
                    raise RuntimeError(f"{download_target} exists and is not a regular file")

                if not os.path.isfile(download_target):
                    os.makedirs(root_dir, exist_ok=True)
                    self._download_file(self.url, download_target)
                else:
                    print(f"File {download_target} already exists, skipping download")

                resolved_path = download_target
            # Must be a local file
            else:
                resolved_path = self.url

        else:
            raise ValueError(f"Unknown model type {model_type}")

        # Only remember the result after everything above succeeded
        self.path = resolved_path
        return self.path

    def _download_file(self, url: str, destination: str):
        """
        Download url to destination, showing a progress bar.

        The data is first written to "<destination>.part" and moved into place once the
        transfer has completed, so an interrupted download never leaves a truncated file
        at destination (which would otherwise be mistaken for a finished download).
        """
        # "import urllib" alone does not guarantee that the request submodule is loaded
        import urllib.request

        partial_target = destination + ".part"

        try:
            with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
                # Content-Length is optional; without it tqdm shows progress with no total
                try:
                    total_size = int(source.info().get("Content-Length"))
                except (TypeError, ValueError):
                    total_size = None

                with tqdm(
                    total=total_size,
                    ncols=80,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as loop:
                    while True:
                        buffer = source.read(8192)
                        if not buffer:
                            break

                        output.write(buffer)
                        loop.update(len(buffer))

            os.replace(partial_target, destination)
        finally:
            # Still present only if the download did not complete
            if os.path.exists(partial_target):
                os.remove(partial_target)

class ApplicationConfig:
    """Holds the application settings, including the list of configured models."""

    def __init__(self, models: List["ModelConfig"] = None, input_audio_max_duration: int = 600, 
                 share: bool = False, server_name: str = None, server_port: int = 7860, default_model_name: str = "medium", 
                 default_vad: str = "silero-vad", vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800, 
                 auto_parallel: bool = False, output_dir: str = None):
        """
        Initialize the application configuration.

        models: List of model configurations. If not set, a new empty list is used.

        All other arguments are stored unchanged as attributes of the same name.
        """
        # A fresh list per instance - a mutable default would be shared by all instances
        self.models = models if models is not None else []
        self.input_audio_max_duration = input_audio_max_duration
        self.share = share
        self.server_name = server_name
        self.server_port = server_port
        self.default_model_name = default_model_name
        self.default_vad = default_vad
        self.vad_parallel_devices = vad_parallel_devices
        self.vad_cpu_cores = vad_cpu_cores
        self.vad_process_timeout = vad_process_timeout
        self.auto_parallel = auto_parallel
        self.output_dir = output_dir

    def get_model_names(self):
        """Return the names of all configured models, in configuration order."""
        return [ x.name for x in self.models ]

    @staticmethod
    def parse_file(config_path: str):
        """
        Create an ApplicationConfig from a JSON5 file.

        config_path: Path to the JSON5 configuration file.

        Each entry of the "models" key is passed as keyword arguments to ModelConfig; all
        remaining keys are passed as keyword arguments to ApplicationConfig, so a key that
        is not a constructor parameter raises a TypeError.
        """
        import json5

        # Read as UTF-8 explicitly instead of relying on the platform default encoding
        with open(config_path, "r", encoding="utf-8") as f:
            # Load using json5
            data = json5.load(f)

        data_models = data.pop("models", [])
        models = [ ModelConfig(**x) for x in data_models ]

        return ApplicationConfig(models, **data)